# coding=utf-8
'''
Twisted reactor TCP server that handles client connections with the Acrel PZ96L adapter.
'''
import os
import sys
from twisted.internet import reactor, task
from twisted.internet.endpoints import TCP4ServerEndpoint
from twisted.internet.protocol import Factory, Protocol
from twisted.enterprise import adbapi
from pymysql import cursors

# Put the parent of this file's directory on sys.path so the project-local
# modules imported below (exceptions, utils, adapters) can be resolved,
# presumably when this file is run directly as a script. append() adds it at
# the end, so entries already on sys.path take precedence.
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)

# Project-local imports; these rely on the sys.path entry added above.
from exceptions import ParseHeaderError
from utils import getLoggerForlbzy
from adapters.acrel_pz96l import AcrelPZ96L

# Module-level logger obtained from the project's logging helper.
logger = getLoggerForlbzy(__name__)

# Production database connection settings.
# NOTE(security): these credentials are committed to source control. Every
# value can be supplied through environment variables; the literal fallbacks
# only keep existing deployments working unchanged. Rotate the password and
# remove the fallbacks once the variables are set everywhere.
dbparmas = {
    'host': os.environ.get('MONITOR_DB_HOST', 'cdb-noc7c8is.bj.tencentcdb.com'),
    'user': os.environ.get('MONITOR_DB_USER', 'root'),
    'password': os.environ.get('MONITOR_DB_PASSWORD', 'Ubuntu123$'),
    'database': os.environ.get('MONITOR_DB_NAME', 'monitor'),
    'port': int(os.environ.get('MONITOR_DB_PORT', '10152')),
    'charset': 'utf8',
    'cursorclass': cursors.DictCursor,
}
# Payment database settings: currently the same server and database as above.
# Kept as a separate copy so that editing one dict never changes the other.
payment_dbparam = dict(dbparmas)


class MysqlDBEngine(object):
    """Thin wrapper around a Twisted ``adbapi.ConnectionPool`` backed by pymysql.

    Every query method returns a Deferred; the blocking database calls run in
    the pool's worker threads.
    """

    def __init__(self, **dbconfig):
        """Create the connection pool.

        :param dbconfig: keyword arguments forwarded to
            ``adbapi.ConnectionPool`` (and from there to ``pymysql.connect``).
        """
        self.dbpool = adbapi.ConnectionPool('pymysql', **dbconfig)

    def _createDB(self, cursor):
        # Placeholder; not implemented.
        pass

    def _sql(self, cursor, sql, args=None):
        """Execute ``sql`` on ``cursor``; runs inside a pool interaction.

        ``args`` is only passed to the driver when given, so the no-argument
        call is exactly ``cursor.execute(sql)``.
        """
        if args is None:
            cursor.execute(sql)
        else:
            cursor.execute(sql, args)

    def sql(self, sql, args=None):
        """Run a statement in its own interaction (transaction).

        :param sql: SQL text. Prefer ``%s`` placeholders plus ``args`` over
            formatting values into the string, so the driver escapes them.
        :param args: optional sequence or mapping of parameter values.
        :returns: Deferred that fires with ``None`` once the statement has run.
        """
        if args is None:
            return self.dbpool.runInteraction(self._sql, sql)
        return self.dbpool.runInteraction(self._sql, sql, args)

    def fetch(self, sql, args=None):
        """Run a query and return all result rows.

        :param sql: SQL text, optionally with ``%s`` placeholders.
        :param args: optional sequence or mapping of parameter values.
        :returns: Deferred that fires with the rows from ``fetchall`` (dicts
            when the pool was configured with ``cursors.DictCursor``).
        """
        if args is None:
            return self.dbpool.runQuery(sql)
        return self.dbpool.runQuery(sql, args)

    def queryAll(self):
        # Placeholder; not implemented.
        pass


class TSServProtocol(Protocol):
    """One protocol instance per TCP connection.

    Each chunk of received bytes is handed to a fresh ``AcrelPZ96L`` adapter.
    Per-client state lives on the factory in dicts keyed by ``self.clientID``
    (``"host:port"``), so it outlives this protocol instance.
    """

    # Connection pools shared by every connection. They are created once, when
    # this class body runs at import time. They are not used in this file;
    # presumably the adapter reads them through the protocol (TODO confirm).
    db = MysqlDBEngine(**dbparmas)
    payment_db = MysqlDBEngine(**payment_dbparam)

    def __init__(self, factory):
        """:param factory: the ``SpreaderFactory`` that built this protocol."""
        self.factory = factory

    def connectionMade(self):
        """Called by Twisted when a client connects.

        Disables Nagle's algorithm and counts the connection. It then
        registers this protocol under ``host:port`` and initialises that
        client's state on the factory the first time the ID is seen.
        """
        self.transport.setTcpNoDelay(True)
        self.factory.numProtocols = self.factory.numProtocols + 1
        peer = self.transport.getPeer()
        self.client_host = peer.host
        self.client_port = peer.port
        self.clientID = ':'.join([str(self.client_host), str(self.client_port)])
        logger.info("...来自<{}>的链接:".format(self.clientID))
        logger.info("当前连接数：%d" % self.factory.numProtocols)

        factory = self.factory
        cid = self.clientID
        factory.clients[cid] = self
        # connectionLost does not remove this state, so it is only initialised
        # the first time this host:port is seen.
        if cid not in factory.clients_message:
            factory.clients_message[cid] = b''
            factory.not_enough[cid] = False
            factory.headers[cid] = ''
            factory.clients_content[cid] = {}
            factory.clients_recv_length[cid] = 0
            factory.model_list[cid] = []
            factory.recharge_list[cid] = []
            factory.last_balance_hour[cid] = {}
            factory.adapter_type[cid] = ''
            factory.loop_task_list[cid] = []

    def connectionLost(self, reason):
        """Called by Twisted when the connection closes; reverses connectionMade.

        Decrements the connection count, unregisters this protocol and stops
        the periodic tasks started with ``loop_task``. The other per-client
        dicts on the factory are not removed here.
        """
        logger.debug(reason)
        logger.debug("关闭连接：<{}>".format(self.clientID))
        self.factory.numProtocols = self.factory.numProtocols - 1
        logger.debug("当前连接数：{}".format(self.factory.numProtocols))

        self.factory.clients.pop(self.clientID, None)
        tasks = self.factory.loop_task_list.get(self.clientID, [])
        for t in tasks:
            # LoopingCall.stop() raises AssertionError when the call is not
            # running. That happens after its callable raised (which stops it),
            # or when it was already stopped by an earlier connection with the
            # same clientID.
            if t.running:
                t.stop()
        # Forget the stopped tasks so a later connection reusing this clientID
        # does not try to stop them again.
        del tasks[:]

    def dataReceived(self, data):
        """Handle bytes received from the client.

        A new ``AcrelPZ96L`` adapter is created for every chunk. It is given
        the data, this protocol and the reactor, and then runs its operations.
        On any exception, including ``ParseHeaderError``, the error is logged
        and ``responses_ack(status=False, add_header=False)`` is sent
        (presumably a negative acknowledgement).

        :param data: raw bytes as delivered by Twisted. This is not
            necessarily one complete message.
        """
        logger.info("来自客户端<{}:{}>': {}".format(self.client_host, self.client_port, data))
        adapter_class = AcrelPZ96L()
        try:
            adapter_class.initial(data, self, reactor)
            adapter_class.run_operations()
        except ParseHeaderError as e:
            logger.error(e)
            adapter_class.responses_ack(status=False, add_header=False)
        except Exception as e:
            logger.error(e)
            adapter_class.responses_ack(status=False, add_header=False)

    def loop_task(self, func, tm=60):
        """Call ``func`` now and then every ``tm`` seconds until the
        connection is lost.

        :param func: zero-argument callable.
        :param tm: interval in seconds (default 60).
        """
        t = task.LoopingCall(func)
        d = t.start(tm)
        # If func raises, the LoopingCall stops and this Deferred errbacks.
        # Log the failure instead of leaving an unhandled Deferred error.
        d.addErrback(self._log_loop_task_failure, func)
        self.factory.loop_task_list[self.clientID].append(t)

    def _log_loop_task_failure(self, failure, func):
        """Errback for ``loop_task``: log why a periodic task stopped."""
        logger.error("loop task {!r} for <{}> stopped: {}".format(
            func, self.clientID, failure.getTraceback()))

    def write(self, data):
        """Send ``data`` to the client and try to flush it immediately."""
        logger.debug("返回给客户端<{}:{}>的数据: {}".format(self.client_host,
                                                   self.client_port, data))
        self.transport.write(data)
        # doWrite() is a FileDescriptor method, not part of ITransport. It
        # pushes the buffer to the socket now instead of on the reactor's next
        # write pass, which assumes a plain TCP transport.
        self.transport.doWrite()


class SpreaderFactory(Factory):
    """Builds one ``TSServProtocol`` per connection and holds the state those
    protocols share, mostly as dicts keyed by clientID (``"host:port"``)."""

    def __init__(self):
        # Number of currently open connections.
        self.numProtocols = 0
        # clientID -> live protocol instance.
        self.clients = {}
        # clientID -> message bytes received so far.
        self.clients_message = {}
        # clientID -> flag about whether the message has reached its required length.
        self.not_enough = {}
        # clientID -> message header.
        self.headers = {}
        # clientID -> building_id / gateway_id / order_id details.
        self.clients_content = {}
        # The next three were marked "not used" by the original author.
        self.clients_recv_length = {}
        self.model_list = {}
        self.recharge_list = {}
        # clientID -> hour of the last balance update.
        self.last_balance_hour = {}
        # clientID -> meter (adapter) type.
        self.adapter_type = {}
        # clientID -> LoopingCall objects started via TSServProtocol.loop_task.
        self.loop_task_list = {}

    def buildProtocol(self, addr):
        """Return a new protocol bound to this factory; ``addr`` is unused."""
        protocol = TSServProtocol(self)
        return protocol


if __name__ == '__main__':
    port = 9000
    endpoint = TCP4ServerEndpoint(reactor, port)
    listening = endpoint.listen(SpreaderFactory())

    def _on_listen_failed(failure):
        """Log a failed bind (e.g. port in use) and make the reactor exit
        instead of idling with no listening socket."""
        logger.error("listen on port {} failed: {}".format(
            port, failure.getErrorMessage()))
        # callWhenRunning works both before and after reactor.run() starts.
        reactor.callWhenRunning(reactor.stop)

    listening.addErrback(_on_listen_failed)

    logger.info(u"....等待链接..")

    reactor.run()
