import logging
import datetime
import pandas as pd
import ujson
from MySQLdb import escape_string

from addict import Dict
from django.db import models
from django.db import connections
from django.db.models.fields import AutoField
from django.db.models.fields import BooleanField
from django.db.models.fields import CharField
from django.db.models.fields import DateTimeField
from django.db.models.fields import DecimalField
from django.db.models.fields import IntegerField
from django.db.models.fields.files import FileField
from django.db.models.fields import FloatField
from django.utils.functional import Promise
from common.utils.gettimes import GetRunTime
from common.utils.exceptions import ValidationError

# Module-level logger used by the helpers below. logging.getLogger("") returns
# the root logger (an empty name is treated like no name), so messages are
# handled by whatever handlers are configured on the root logger.
logger = logging.getLogger("")


class BaseConnection(object):
    """Helper for running raw SQL against one Django database alias.

    Every query helper opens a cursor on ``connections[self.db_alias]`` and
    closes both the cursor and the connection once the statement is done.
    It also offers small string builders used to assemble ad-hoc SQL
    (pagination, filters, ordering, INSERT statements).
    """

    def __init__(self, db_alias='default', *args, **kwargs):
        # Extra positional/keyword arguments are accepted but ignored.
        self.db_alias = db_alias

    def _execute(self, sql, fetch):
        """Execute *sql* on this alias and return ``fetch(cursor)``.

        The cursor and the connection are closed afterwards even when
        execution or fetching raises, so a failing statement does not leak
        an open cursor/connection.
        """
        connection = connections[self.db_alias]
        try:
            cursor = connection.cursor()
            try:
                cursor.execute(sql)
                return fetch(cursor)
            finally:
                cursor.close()
        finally:
            connection.close()

    @GetRunTime
    def _sql_query(self, sql):
        """Execute *sql* and return ``(column_names, rows)``."""
        logger.debug(sql)

        def _fetch(cursor):
            desc = cursor.description
            rows = cursor.fetchall()
            return [col[0] for col in desc], rows

        return self._execute(sql, _fetch)

    def _query(self, sql):
        """Execute *sql* and return the result as a pandas DataFrame.

        Null cells are replaced with ``None`` via ``DataFrame.where``.
        NOTE(review): pandas may keep NaN instead of None in numeric
        (float) columns — confirm whether callers rely on that.
        """
        cols, rows = self._sql_query(sql)
        df = pd.DataFrame(columns=cols, data=rows)
        return df.where(df.notnull(), None)

    def _dict_fetch_all(self, cursor):
        """Return all remaining rows of *cursor* as ``{column: value}`` dicts."""
        # Column names are computed once instead of once per row; a missing
        # description (statement without a result set) yields no columns.
        columns = [col[0] for col in cursor.description or ()]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def _list_fetch_all(self, cursor):
        """Return all remaining rows of *cursor* as returned by the driver."""
        return cursor.fetchall()

    def query(self, sql):
        """Execute *sql* and return the rows as a list of dicts."""
        return self._execute(sql, self._dict_fetch_all)

    def query_list(self, sql):
        """Execute *sql* and return the raw rows from ``cursor.fetchall()``."""
        return self._execute(sql, self._list_fetch_all)

    @staticmethod
    def _rows_to_values(rows):
        """Render DB *rows* as ``(v1,v2,...)`` strings for an INSERT VALUES list.

        ``datetime`` values are formatted as ``'%Y-%m-%d %H:%M:%S'``; each row
        is JSON-encoded with ujson and its surrounding brackets dropped.
        NOTE(review): JSON escaping is not SQL escaping — only use this with
        data read from a trusted database.
        """
        row_list = []
        for row in rows:
            values = [
                v.strftime("%Y-%m-%d %H:%M:%S")
                if isinstance(v, datetime.datetime) else v
                for v in row
            ]
            # Drop the enclosing "[" and "]" of the JSON array.
            row_list.append(
                "(%s)" % ujson.dumps(values, ensure_ascii=False)[1:-1])
        return row_list

    def get_insert_sql(self, table, sql):
        """Run *sql* and build an ``INSERT INTO table`` statement for its rows.

        NOTE(review): with no rows the VALUES list is empty and the returned
        SQL is not valid (``get_insert_or_update_sql`` returns None instead).
        """
        cols, rows = self._sql_query(sql)
        row_list = self._rows_to_values(rows)
        insert_sql = """
        INSERT INTO `{}`
        ({})
        VALUES
        {}
        """.format(
            table, ",".join(["`%s`" % i for i in cols]),
            ",".join(row_list))
        return insert_sql

    def get_insert_or_update_sql(self, table, sql):
        """Run *sql* and build an upsert (``ON DUPLICATE KEY UPDATE``) for its rows.

        Every selected column is updated from ``values(col)`` on conflict.
        Returns None when *sql* yields no rows.
        """
        cols, rows = self._sql_query(sql)
        duplicate_to_update = """
        ON DUPLICATE KEY UPDATE %s
        """ % ",".join(["`%s`=values(`%s`)" % (i, i) for i in cols])
        row_list = self._rows_to_values(rows)
        if not row_list:
            return None
        insert_sql = """
        INSERT INTO `{}`
        ({})
        VALUES
        {} {}
        """.format(
            table, ",".join(["`%s`" % i for i in cols]),
            ",".join(row_list), duplicate_to_update)
        return insert_sql

    def structure_tree(self, dt, id, index="id", pid="pid", sub="sub"):
        """Build a nested tree from a flat list of row dicts.

        :param dt: list of row dicts
        :param id: starting parent id, or a tuple/list of parent ids
        :param index: column holding each row's own id
        :param pid: column linking a row to its parent
        :param sub: key under which each node's children are stored
        :return: ``(tree, tree_map)``. ``tree`` is the list of rows whose
            ``pid`` is in *id*, each given a ``sub`` list of children
            (recursively). ``tree_map`` maps each matched parent id to its
            direct children. The row dicts are mutated in place.
        :raises ValidationError: if *index* or *pid* is not a key of the rows.
        """
        next_dt = []
        tree = []
        next_pids = []
        tree_map = {}
        if dt and (index not in dt[0] or pid not in dt[0]):
            raise ValidationError("index 或 pid 传入的列名不存在, 构建树状结构失败")
        if isinstance(id, (int, str)):
            id = (id,)
        if not isinstance(id, (tuple, list)):
            return tree, tree_map
        # Set lookup instead of a linear scan of the id tuple per row.
        parent_ids = set(id)
        for row in dt:
            if row[pid] in parent_ids:
                tree_map.setdefault(row[pid], []).append(row)
                next_pids.append(row[index])
                tree.append(row)
            else:
                next_dt.append(row)
        if tree:
            # Resolve the next level from the rows not consumed here.
            _, next_tree_map = self.structure_tree(next_dt, next_pids,
                                                   index, pid, sub)
            for node in tree:
                node[sub] = next_tree_map.get(node[index], [])
        return tree, tree_map

    @GetRunTime
    def df_query(self, sql):
        """Execute *sql* and return the rows as a list of dicts (via pandas)."""
        df = self._query(sql)
        return df.to_dict(orient="records")

    @GetRunTime
    def query_df(self, sql):
        """Execute *sql* and return the result as a pandas DataFrame."""
        df = self._query(sql)
        return df

    def paginator(cls, page, page_size, get_all=False):
        """Return a MySQL LIMIT argument ``"offset,size"`` for a 1-based *page*.

        List arguments (e.g. query-string values) use their first element.
        Page numbers below 1 are clamped to the first page. When either value
        is missing, returns ``""`` if *get_all* else ``"0,10"``.
        """
        if isinstance(page, list):
            page = page[0] if page else None
        if isinstance(page_size, list):
            page_size = page_size[0] if page_size else None
        if page and page_size:
            pre_page_size = int(page_size)
            # Clamp so page <= 0 cannot produce a negative (invalid) offset.
            offset = max(int(page) - 1, 0) * pre_page_size
            return "{page},{page_size}".format(page=offset, page_size=pre_page_size)
        return "" if get_all else "0,10"

    def filter(self, key, value, func=None, multi=False):
        """Build escaped WHERE fragments from the template *value*.

        :param key: a value, or a list of values (only the first is used
            unless *multi* is true)
        :param value: ``str.format`` template receiving the escaped value
        :param func: optional transform applied before escaping
        :param multi: when true and *key* is a list, format every element and
            join the fragments with `` AND ``
        :return: a one-element list with the condition, or ``[]`` when *key*
            is falsy
        """
        if isinstance(key, list) and not multi:
            key = key[0] if key else None
        if not key:
            return []

        def _escape(raw):
            # escape_string needs text; stringify other values (e.g. ints in
            # a list) the same way the scalar branch always has.
            if not isinstance(raw, (str, bytes)):
                raw = str(raw)
            return escape_string(raw).decode()

        if isinstance(key, list):
            conditions = [
                value.format(_escape(func(item) if func else item))
                for item in key
            ]
            return [" AND ".join(conditions)]
        raw = str(key)
        return [value.format(_escape(func(raw) if func else raw))]

    def order_by(self, sort_list, map=None, default=None):
        """Translate sort keys into an ORDER BY expression.

        Each key is looked up in *map* (public name -> column); a key containing
        "-" (e.g. ``"-name"``) sorts DESC. Keys without a string mapping are
        skipped; nested lists are processed recursively.
        Returns *default* when *sort_list* is empty.
        """
        if not sort_list:
            return default
        mapping = map or {}
        order_list = []
        for item in sort_list:
            if isinstance(item, list):
                nested = self.order_by(item, mapping)
                # Skip empty nested results; they would produce ",," or a
                # TypeError when joined.
                if nested:
                    order_list.append(nested)
                continue
            parts = item.split("-")
            column = mapping.get(parts[-1])
            if isinstance(column, str):
                if len(parts) > 1:
                    order_list.append(" %s DESC " % column)
                else:
                    order_list.append(" %s " % column)
        return ",".join(order_list)

    def sql_splice(self,
                   _sql,
                   select,
                   where=None,
                   count_sql=None,
                   group_by=None,
                   order_by=None,
                   limit_sql=None,
                   limit=None,
                   limit_idx=None):
        """Assemble a (paginated) SELECT statement from a template.

        :param _sql: SQL template with one ``%s`` placeholder for the select list
        :param select: select list substituted into ``_sql``
        :param where: list of WHERE conditions, joined with `` and ``
        :param count_sql: SQL returning the total count
        :param group_by: GROUP BY expression
        :param order_by: ORDER BY expression
        :param limit_sql: template for the paging sub-query (defaults to ``_sql``)
        :param limit: pagination string ``"offset,size"`` (e.g. from ``paginator``)
        :param limit_idx: primary-key / indexed column used to speed up paging
        :return: ``(sql, count)``. ``count`` is the ``count`` column when
            *count_sql* returns exactly one row, the number of rows when it
            returns several, and None otherwise.

        NOTE(review): *limit* is only applied when *limit_idx* is also given.
        """
        count = None
        _limit = None
        sql = _sql % select
        if not limit_sql:
            limit_sql = _sql
        if limit and limit_idx:
            # Substitute the placeholder before appending WHERE so '%' inside
            # conditions (e.g. LIKE '%x%') is not read as a format directive;
            # the caller's limit_sql is honoured with or without a WHERE.
            limit_sql = limit_sql % limit_idx
            if where:
                limit_sql = ' WHERE '.join([limit_sql, ' and '.join(where)])
            _limit = limit.split(",")
        if count_sql:
            logger.debug(count_sql)
            # Run the count query once (it used to be executed twice).
            ct = self.query(count_sql)
            if ct:
                count = ct[0]["count"] if len(ct) == 1 else len(ct)
        if where:
            sql = ' WHERE '.join([sql, ' and '.join(where)])
        if _limit and limit_idx:
            if group_by:
                limit_sql = ' GROUP BY '.join([limit_sql, group_by])
            if order_by:
                limit_sql = ' ORDER BY '.join([limit_sql, order_by])
            if not order_by or order_by == limit_idx:
                # Seek to the first row of the page through the indexed
                # column; only the page size is applied as LIMIT below.
                # NOTE(review): the condition is always joined with AND, so
                # the statement must already contain a WHERE clause.
                sql = """%s AND %s >=( %s limit %s,1) """ % (
                    sql, limit_idx, limit_sql, _limit[0])
        if group_by:
            sql = ' GROUP BY '.join([sql, group_by])
        if order_by and order_by != limit_idx:
            sql = ' ORDER BY '.join([sql, order_by])
            if _limit:
                sql = ' LIMIT '.join([sql, ','.join(_limit)])
        elif _limit:
            sql = ' LIMIT '.join([sql, _limit[-1]])
        return sql, count

    def json_loads_f(self, x):
        """Parse JSON text *x* with ujson; return ``[]`` for empty input.

        Parse errors are logged together with the offending input and
        re-raised with their original traceback.
        """
        if not x:
            return []
        try:
            return ujson.loads(x)
        except Exception:
            logger.error("json_loads_f: failed to parse %r", x)
            raise


class PowerOAConnection(BaseConnection):
    """Raw-SQL helper bound to the ``power_oa`` database alias."""

    sql = ""

    def __init__(self, *args, **kwargs):
        # NOTE: a positional argument would bind to ``db_alias`` as well and
        # raise TypeError; only keyword arguments can be forwarded safely.
        super(PowerOAConnection, self).__init__(db_alias='power_oa',
                                                *args,
                                                **kwargs)

    def paginator(cls, page, page_size):
        """Return a MySQL LIMIT argument ``"offset, size"`` for a 1-based *page*.

        Returns ``""`` (no pagination) when either argument is falsy. Page
        numbers below 1 are clamped to the first page instead of producing a
        negative, invalid offset.
        """
        if page and page_size:
            pre_page_size = int(page_size)
            offset = max(int(page) - 1, 0) * pre_page_size
            return "{page}, {page_size}".format(page=offset, page_size=pre_page_size)
        return ""


class TestConnection(BaseConnection):
    """Raw-SQL helper bound to the ``default_test`` database alias."""

    sql = ""

    def __init__(self, *args, **kwargs):
        # NOTE: a positional argument would bind to ``db_alias`` as well and
        # raise TypeError; only keyword arguments can be forwarded safely.
        super(TestConnection, self).__init__(db_alias='default_test',
                                             *args,
                                             **kwargs)

    def paginator(cls, page, page_size):
        """Return a MySQL LIMIT argument ``"offset, size"`` for a 1-based *page*.

        Returns ``""`` (no pagination) when either argument is falsy. The size
        is emitted as the normalized int (matching ``PowerOAConnection``), and
        page numbers below 1 are clamped to the first page.
        """
        if page and page_size:
            pre_page_size = int(page_size)
            offset = max(int(page) - 1, 0) * pre_page_size
            return "{page}, {page_size}".format(page=offset, page_size=pre_page_size)
        return ""


class MonitorOAConnection(BaseConnection):
    """Raw-SQL helper bound to the ``monitor_oa`` database alias."""

    sql = ""

    def __init__(self, *args, **kwargs):
        # NOTE: a positional argument would bind to ``db_alias`` as well and
        # raise TypeError; only keyword arguments can be forwarded safely.
        super(MonitorOAConnection, self).__init__(db_alias='monitor_oa',
                                                  *args,
                                                  **kwargs)

    def paginator(cls, page, page_size):
        """Return a MySQL LIMIT argument ``"offset, size"`` for a 1-based *page*.

        Returns ``""`` (no pagination) when either argument is falsy. The size
        is emitted as the normalized int (matching ``PowerOAConnection``), and
        page numbers below 1 are clamped to the first page.
        """
        if page and page_size:
            pre_page_size = int(page_size)
            offset = max(int(page) - 1, 0) * pre_page_size
            return "{page}, {page_size}".format(page=offset, page_size=pre_page_size)
        return ""


class MonitorConnection(BaseConnection):
    """Raw-SQL helper bound to the ``monitor`` database alias."""

    sql = ""

    def __init__(self, *args, **kwargs):
        # NOTE: a positional argument would bind to ``db_alias`` as well and
        # raise TypeError; only keyword arguments can be forwarded safely.
        super(MonitorConnection, self).__init__(db_alias='monitor',
                                                *args,
                                                **kwargs)

    def paginator(cls, page, page_size):
        """Return a MySQL LIMIT argument ``"offset, size"`` for a 1-based *page*.

        Returns ``""`` (no pagination) when either argument is falsy. The size
        is emitted as the normalized int (matching ``PowerOAConnection``), and
        page numbers below 1 are clamped to the first page.
        """
        if page and page_size:
            pre_page_size = int(page_size)
            offset = max(int(page) - 1, 0) * pre_page_size
            return "{page}, {page_size}".format(page=offset, page_size=pre_page_size)
        return ""


class ScoreConnection(BaseConnection):
    """Raw-SQL helper bound to the ``score`` database alias."""

    sql = ""

    def __init__(self, *args, **kwargs):
        # NOTE: a positional argument would bind to ``db_alias`` as well and
        # raise TypeError; only keyword arguments can be forwarded safely.
        super(ScoreConnection, self).__init__(db_alias='score',
                                              *args,
                                              **kwargs)

    def paginator(cls, page, page_size):
        """Return a MySQL LIMIT argument ``"offset, size"`` for a 1-based *page*.

        Returns ``""`` (no pagination) when either argument is falsy. The size
        is emitted as the normalized int (matching ``PowerOAConnection``), and
        page numbers below 1 are clamped to the first page.
        """
        if page and page_size:
            pre_page_size = int(page_size)
            offset = max(int(page) - 1, 0) * pre_page_size
            return "{page}, {page_size}".format(page=offset, page_size=pre_page_size)
        return ""


class BaseModel(models.Model):
    """Abstract base model: auto id, soft-delete flag and audit timestamps."""

    # NOTE(review): max_length has no effect on an AutoField (Django ignores
    # it for integer fields); kept as-is because removing it would likely
    # produce a new migration.
    id = models.AutoField(primary_key=True, max_length=10)
    deleted = models.BooleanField(default=False)
    created_time = models.DateTimeField(auto_now_add=True, null=True)
    updated_time = models.DateTimeField(auto_now=True, null=True)
    deleted_time = models.DateTimeField(null=True)

    class Meta:
        abstract = True

    def to_dict(self):
        """Return this instance's concrete field values as a plain dict.

        Keys are the fields' ``attname``. ``datetime`` values are formatted
        as ``'%Y-%m-%d %H:%M:%S'``; file fields are replaced by their URL
        (None when empty).
        """
        data = {}
        for field in self._meta.concrete_fields:
            value = field.value_from_object(self)
            # Compare against the datetime *class*: the previous check used
            # type(datetime), i.e. the module type, which never matched.
            if isinstance(value, datetime.datetime):
                value = value.strftime('%Y-%m-%d %H:%M:%S')
            elif isinstance(field, FileField):
                value = value.url if value else None
            data[field.attname] = value
        return data

    def to_json(self):
        """Describe each concrete field as a dict for schema documentation.

        Each entry holds ``table``, ``name``, ``kind`` (see ``parse_kind``),
        ``default`` (see ``parse_default``) and ``desc`` (the help_text).
        """
        opts = self._meta
        return [
            {
                "table": opts.db_table,
                "name": f.name,
                "kind": self.parse_kind(f),
                "default": self.parse_default(f.default),
                "desc": f.help_text,
            }
            for f in opts.concrete_fields
        ]

    def parse_default(self, a):
        """Return the documented default for a field; currently always ``""``."""
        return ""

    def parse_kind(self, a):
        """Map model field *a* to a coarse type name used by ``to_json``.

        isinstance() is used so subclasses are classified like their base
        (e.g. integer field subclasses such as BigIntegerField give "int").
        Anything unmatched is reported as "string".
        """
        kinds = (
            (CharField, "string"),
            (AutoField, "int"),
            (BooleanField, "boolean"),
            (DecimalField, "decimal"),
            (DateTimeField, "datetime"),
            (IntegerField, "int"),
        )
        for field_type, kind in kinds:
            if isinstance(a, field_type):
                return kind
        return "string"


def model_to_dict(self):
    """Return a model instance's concrete field values as an addict ``Dict``.

    :param self: the model instance (named ``self`` for backward
        compatibility with existing callers).
    ``datetime`` values are formatted as ``'%Y-%m-%d %H:%M:%S'``; file fields
    are replaced by their URL (None when empty). Keys are field ``attname``s.
    """
    data = {}
    for field in self._meta.concrete_fields:
        value = field.value_from_object(self)
        # Compare against the datetime *class*: the previous check used
        # type(datetime), i.e. the module type, which never matched.
        if isinstance(value, datetime.datetime):
            value = value.strftime('%Y-%m-%d %H:%M:%S')
        elif isinstance(field, FileField):
            value = value.url if value else None
        data[field.attname] = value
    return Dict(data)


class MyFloatField(FloatField):
    """FloatField that stores empty input as NULL instead of failing.

    ``None`` and ``""`` are saved as NULL; every other value goes through
    ``float()``.
    """

    def get_prep_value(self, value):
        """Convert *value* for the database (None for empty input, else float)."""
        # Resolve lazy objects (e.g. lazy translations) to their real value.
        if isinstance(value, Promise):
            value = value._proxy____cast()
        # Only "empty" input maps to NULL. A plain truthiness test also turned
        # legitimate zeros (0, 0.0, Decimal("0")) into NULL, while the string
        # "0" was stored as 0.0.
        if value is None or value == "":
            return None
        return float(value)
