from django.db.transaction import atomic
from common.utils.models import BaseConnection

from console.equipment.models import Equipment, EquipmentData
from common.utils.exceptions import ValidationError

# Shared raw-SQL helper connection; the query methods in this module use its
# query/query_df/filter/paginator/sql_splice/structure_tree helpers.
base_db = BaseConnection()


class EquipmentQuerySet(object):
    """Query helpers for equipment records and their attached equipment data.

    Read helpers build raw SQL and run it through the module-level ``base_db``
    connection; the create/update/delete helpers use the Django ORM models.
    """

    def structure_tree(self, dt, id):
        """Nest a flat list of row dicts into a parent/child tree.

        Rows whose ``"pid"`` matches ``id`` form the top level. Each of them
        gets a ``"sub"`` list holding its own children, built recursively from
        the rows not yet placed. The row dicts are mutated in place.

        :param dt: iterable of dicts, each with ``"id"`` and ``"pid"`` keys.
        :param id: a single parent id (int or str) or a tuple/list of them;
            any other type yields an empty result.
        :return: ``(tree, tree_map)``. ``tree`` is the list of top-level rows
            in input order. ``tree_map`` maps each matched parent id to its
            list of rows at this level.
        """
        if isinstance(id, (int, str)):
            parent_ids = {id}
        elif isinstance(id, (tuple, list)):
            # A set keeps the per-row membership test O(1).
            parent_ids = set(id)
        else:
            return [], {}

        tree = []
        tree_map = {}
        next_pids = []
        next_dt = []
        for row in dt:
            if row["pid"] in parent_ids:
                tree_map.setdefault(row["pid"], []).append(row)
                next_pids.append(row["id"])
                tree.append(row)
            else:
                next_dt.append(row)

        if tree:
            # Children of this level can only be among the remaining rows.
            _, next_tree_map = self.structure_tree(next_dt, next_pids)
            for node in tree:
                node["sub"] = next_tree_map.get(node["id"], [])
        return tree, tree_map

    def get_eq(self, electricity_user_id):
        """Return the modular equipment of one electricity user.

        Produces one row per (equipment, monitor point) pair, from the LEFT
        JOIN on ``monitor_points``. ``is_bind`` is 1 when a monitor point
        references the equipment, else 0. Only equipment whose equipment
        data has ``modular=1`` is returned.

        :param electricity_user_id: numeric id of the electricity user.
        :raises ValueError, TypeError: if the id cannot be converted to int.
        """
        # The id is interpolated into raw SQL, so force it to an int to rule
        # out SQL injection.
        electricity_user_id = int(electricity_user_id)
        sql = """
        select
            A.id,
            A.name,
            A.electricity_user_id,
            C.modular,
            C.type,
            C.capacity,
            if(B.equipment_id is null, 0, 1) is_bind,
            B.id point_id,
            B.name point_name
        from equipments A
        left join monitor_points B
        on A.id=B.equipment_id
        left join equipment_data C
        on A.equipment_data_id=C.id
        where A.electricity_user_id={} and C.modular=1
        """.format(electricity_user_id)
        return base_db.query(sql)

    def tree(self, id):
        """Return the rows of the ``tree`` table nested below node ``id``.

        The nesting is delegated to ``base_db.structure_tree``, with children
        stored under the ``"child"`` key.

        :param id: numeric id of the ``tree`` row.
        :raises ValueError, TypeError: if the id cannot be converted to int.
        """
        # NOTE(review): with a LEFT JOIN the group_id sub-select does not
        # filter anything, so the whole table is loaded. An INNER JOIN looks
        # intended, but it would drop rows whose group_id is NULL. Confirm
        # before changing.
        sql = """
        select
            A.* from tree A
        LEFT JOIN (select group_id from tree where id={}) B
        on A.group_id=B.group_id
        """.format(int(id))  # int() guards the raw SQL against injection
        rst = base_db.query(sql)
        tree, _ = base_db.structure_tree(rst,
                                         id,
                                         index="id",
                                         pid="pid",
                                         sub="child")
        return tree

    def get_equipments(self, enterprise_ids, **kwargs):
        """Return one page of equipment rows visible to the given enterprises.

        Rows come from ``equipments`` joined with customer, electricity user
        and equipment data. Only equipment that has ``equipment_data_id`` set
        and whose customer's ``service_enterprise_id`` is in
        ``enterprise_ids`` is included.

        :param enterprise_ids: iterable of numeric enterprise ids (ints or
            digit strings). An empty iterable matches nothing.
        :keyword page, page_size: pagination, passed to ``base_db.paginator``.
        :keyword customer_id, electricity_user_id, modular, type: optional
            equality filters.
        :keyword keyword: optional substring matched against equipment name,
            customer name, electricity-user number, manufacturer and model.
        :return: ``(records, count)``: a list of row dicts and the count
            returned by ``base_db.sql_splice``.
        """
        page = kwargs.get("page")
        page_size = kwargs.get("page_size")
        # The ids go straight into the SQL text. Coerce each one to int to
        # rule out injection; this also accepts int ids, which str.join
        # would reject.
        id_list = ",".join(str(int(e)) for e in enterprise_ids)
        # "in ()" is a SQL syntax error, so an empty id list matches nothing.
        enterprise_cond = ("B.service_enterprise_id in (%s)" % id_list
                           if id_list else "1=0")
        where = ["1=1", "A.equipment_data_id is not null", enterprise_cond]
        where.extend(
            base_db.filter(kwargs.get("customer_id"),
                           value='A.customer_id={}'))
        where.extend(
            base_db.filter(kwargs.get("electricity_user_id"),
                           value='A.electricity_user_id={}'))
        # NOTE(review): the keyword is spliced into a double-quoted LIKE
        # pattern; confirm base_db.filter escapes quotes in it.
        where.extend(
            base_db.filter(kwargs.get("keyword"),
                           value='concat(ifnull(A.name,""),ifnull(B.name,""),'
                                 'ifnull(C.number,""),ifnull(D.manufacturer,""),'
                                 'ifnull(D.model,"")) LIKE "%%{}%%"'))
        where.extend(
            base_db.filter(kwargs.get("modular"),
                           value='D.modular={}'))
        where.extend(
            base_db.filter(kwargs.get("type"),
                           value='D.type={}'))
        limit = base_db.paginator(page, page_size)
        _sql = """
        select
            %s
        from equipments A
        left join customer B
        on A.customer_id=B.id
        left join electricity_user C
        on A.electricity_user_id=C.id
        left join equipment_data D
        on A.equipment_data_id=D.id
        """
        count_select = """
            COUNT(A.id) as `count`
        """
        count_sql, _ = base_db.sql_splice(_sql,
                                          count_select,
                                          where=where)
        select = """
        A.id,
        A.customer_id,
        B.name customer_name,
        A.electricity_user_id,
        C.number,
        A.name,
        D.modular,
        D.type,
        D.manufacturer,
        D.model,
        D.capacity
        """
        sql, count = base_db.sql_splice(_sql,
                                        select,
                                        where=where,
                                        count_sql=count_sql,
                                        limit=limit,
                                        limit_idx="A.id")
        df = base_db.query_df(sql)
        return df.to_dict(orient="records"), count

    def equipment_tree(self, **kwargs):
        """Return the equipment tree of one customer's electricity user.

        :keyword customer_id: required.
        :keyword electricity_user_id: required.
        :keyword id: parent id the tree is built from; defaults to 0
            (presumably the top level; confirm against the data).
        :raises ValidationError: if ``customer_id`` or ``electricity_user_id``
            is missing or falsy.
        """
        id = kwargs.get("id", 0)
        if not kwargs.get("customer_id"):
            raise ValidationError("缺少关键参数: customer_id")
        if not kwargs.get("electricity_user_id"):
            raise ValidationError("缺少关键参数: electricity_user_id")
        where = ["1=1"]
        where.extend(
            base_db.filter(kwargs.get("customer_id"),
                           value='A.customer_id={}'))
        where.extend(
            base_db.filter(kwargs.get("electricity_user_id"),
                           value='A.electricity_user_id={}'))
        # Deliberately no "A.id={id}" filter. The tree is assembled from rows
        # whose pid links back to ``id``, so loading only the row with that
        # id would drop every descendant.
        _sql = """
        select
            %s
        from equipments A
        """
        select = """
        A.id, A.name, A.pid, A.customer_id,
        A.electricity_user_id, A.equipment_data_id
        """
        sql, _ = base_db.sql_splice(_sql,
                                    select=select,
                                    where=where)
        rst = base_db.query(sql)
        tree, _ = base_db.structure_tree(rst,
                                         id,
                                         index="id",
                                         pid="pid")
        return tree

    def get_equipment_data(self, id):
        """Return the equipment data attached to equipment ``id``.

        Returns ``{}`` when no such equipment exists. Otherwise it returns
        ``eq.equipment_data``, which may be ``None`` when nothing is attached.
        """
        eq = Equipment.objects.filter(id=id).first()
        if not eq:
            return {}
        return eq.equipment_data

    @atomic()
    def create_or_update_equipment_data(self, id, eq_data):
        """Update an equipment's fields and create or update its equipment data.

        :param id: primary key of the ``Equipment`` to update.
        :param eq_data: mapping of ``Equipment`` attribute names to new values.
            An optional ``"data"`` entry holds a mapping of ``EquipmentData``
            attributes. The mapping is only read, never mutated.
        :return: the new or updated ``EquipmentData`` when ``"data"`` was
            given and non-empty, otherwise ``None``.
        :raises ValidationError: if no equipment with ``id`` exists.
        """
        data = eq_data.get("data")
        eq = Equipment.objects.filter(id=id).first()
        if not eq:
            raise ValidationError("资产不存在")
        # setattr (instead of writing eq.__dict__) goes through Django's field
        # descriptors, e.g. invalidating a cached relation when *_id changes.
        for attr, value in eq_data.items():
            if attr != "data":
                setattr(eq, attr, value)
        if not data:
            eq.save()
            return None

        equipment_data = eq.equipment_data
        if equipment_data:
            for attr, value in data.items():
                setattr(equipment_data, attr, value)
            equipment_data.save()
        else:
            eq.equipment_data = EquipmentData.objects.create(**data)
        # One save persists both the field updates and any newly linked data.
        eq.save()
        return eq.equipment_data

    @atomic()
    def delete_equipment_data(self, id):
        """Delete equipment ``id`` together with its attached equipment data.

        :raises ValidationError: if the equipment still has child equipment
            (rows with ``pid == id``) or does not exist.
        """
        if Equipment.objects.filter(pid=id).exists():
            raise ValidationError("请先删除子设备后，再进行操作")
        eq = Equipment.objects.filter(id=id).first()
        if not eq:
            raise ValidationError("资产不存在")
        equipment_data = eq.equipment_data
        # Remove the referencing equipment first, so the data row can be
        # deleted whatever on_delete policy the relation uses (e.g. PROTECT).
        eq.delete()
        if equipment_data:
            equipment_data.delete()


class EquipmentInfo(object):
    """Lookups for a single equipment record."""

    def get_equipment_info(self, equipment_id):
        """Return the non-deleted ``Equipment`` with ``equipment_id``, or ``None``."""
        candidates = Equipment.objects.filter(id=equipment_id, deleted=False)
        return candidates.first()


# Module-level instances of the (stateless, no __init__) helper classes;
# presumably imported by callers instead of instantiating the classes.
equipment_queryset = EquipmentQuerySet()
equipment_info = EquipmentInfo()
