from etl.utils import getLogger
from etl.settings import DATABASE
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
from etl.db.transformation import TRANSFORMATION_MAP

# Module logger, named after this module via etl.utils.getLogger.
logger = getLogger(__name__)

# SQLAlchemy URL shape shared by both databases:
#   dialect+driver://user:password@host:port/database
# NOTE(review): user and password are interpolated unescaped; a password
# containing '@', ':' or '/' would produce a malformed URL (SQLAlchemy
# recommends urllib.parse.quote_plus). Confirm the settings values are safe
# or pre-escaped before changing this.
_MYSQL_URL_TEMPLATE = "mysql+mysqlconnector://{}:{}@{}:{}/{}"

SLAVE_MYSQL_URL = _MYSQL_URL_TEMPLATE.format(
    DATABASE.SLAVE_MYSQL_USER,
    DATABASE.SLAVE_MYSQL_PASS,
    DATABASE.SLAVE_MYSQL_HOST,
    DATABASE.SLAVE_MYSQL_PORT,
    DATABASE.SLAVE_MYSQL_DB,
)

TRANSFORMATION_MYSQL_URL = _MYSQL_URL_TEMPLATE.format(
    DATABASE.TRANSFORMATION_MYSQL_USER,
    DATABASE.TRANSFORMATION_MYSQL_PASS,
    DATABASE.TRANSFORMATION_MYSQL_HOST,
    DATABASE.TRANSFORMATION_MYSQL_PORT,
    DATABASE.TRANSFORMATION_MYSQL_DB,
)

# Connection-pool settings shared by both engines. pool_pre_ping makes the
# pool test a connection's liveness on checkout (see create_engine docs).
_POOL_OPTIONS = dict(
    pool_size=5,
    max_overflow=10,
    pool_timeout=30,
    pool_pre_ping=True,
)

# Engine for the slave (replica) database.
slave_engine = create_engine(SLAVE_MYSQL_URL, **_POOL_OPTIONS)

# Engine for the transformation database.
trans_engine = create_engine(TRANSFORMATION_MYSQL_URL, **_POOL_OPTIONS)

# Session factories, one per database.
slaveDBSession = sessionmaker(bind=slave_engine)
# Must bind to trans_engine: binding it to slave_engine would make every
# "trans" session read from and write to the slave database instead.
transDBSession = sessionmaker(bind=trans_engine)

# Module-wide sessions reused by every session_maker() unit of work.
# NOTE(review): a single SQLAlchemy Session is not safe to share across
# threads; confirm this ETL runs these writes from a single thread.
slave_session = slaveDBSession()
trans_session = transDBSession()


@contextmanager
def session_maker(session=slave_session):
    """Run a unit of work on ``session`` (the shared ``slave_session`` by default).

    Yields the session itself. When the ``with`` body finishes normally the
    session is committed. If the body or the commit raises an ``Exception``,
    the exception is logged at debug level, the session is rolled back, and
    the exception is re-raised. In every case the session is closed
    afterwards, so the shared module-level session is left ready for the next
    caller.
    """
    try:
        try:
            yield session
            session.commit()
        except Exception as err:
            logger.debug(err)
            session.rollback()
            raise
    finally:
        session.close()


def update(table, json_data, index):
    """Upsert one row of ``table`` on the slave database.

    Looks up the row whose primary-key column ``index`` equals
    ``json_data[index]``. If it exists, the remaining columns of ``json_data``
    are written to it. Otherwise a new ``table(**json_data)`` row is added.
    Both paths run inside a single ``session_maker()`` unit of work.

    ``json_data`` is not modified.

    Args:
        table: ORM-mapped class to query and instantiate.
        json_data: column name -> value for the row, including ``index``.
        index: name of the primary-key column used for the lookup.

    Raises:
        KeyError: if ``json_data`` has no ``index`` key.
    """
    with session_maker() as db_session:
        key_value = json_data[index]
        # Query.get() with a dict looks the row up by primary key.
        existing = db_session.query(table).get({index: key_value})
        if existing is None:
            # Add within this same session. Calling insert() here would open
            # session_maker() on the same shared slave_session, committing
            # and closing it while this block still holds it.
            db_session.add(table(**json_data))
            return
        # Build the SET clause from a copy instead of popping the key, so
        # the caller's dict keeps its key column.
        changes = {col: val for col, val in json_data.items() if col != index}
        if changes:
            # Skip the UPDATE entirely when only the key column was supplied.
            db_session.query(table).filter_by(**{index: key_value}).update(changes)



def insert(table, json_data, index):
    """Add a new ``table`` row built from ``json_data`` via the slave session.

    ``index`` is unused; it is accepted so the signature matches the other
    handlers dispatched through ``db_map``.
    """
    # Build the row before entering the unit of work, so a constructor error
    # propagates without touching the session.
    row = table(**json_data)
    with session_maker() as session:
        session.add(row)


def delete(table, json_data, index):
    """Delete the ``table`` row(s) whose ``index`` column equals ``json_data[index]``.

    Runs in one ``session_maker()`` unit of work on the slave session. A
    missing row is not an error: the bulk DELETE simply matches nothing, so
    no separate existence SELECT is needed.

    Args:
        table: ORM-mapped class to delete from.
        json_data: column name -> value for the row, including ``index``.
        index: name of the column identifying the row.

    Raises:
        KeyError: if ``json_data`` has no ``index`` key.
    """
    with session_maker() as db_session:
        key_value = json_data[index]
        db_session.query(table).filter_by(**{index: key_value}).delete()


# Operation name (the event's ``tp``) -> slave-database handler.
# Every handler is called as handler(table, json_data, index).
db_map = dict(
    update=update,
    insert=insert,
    delete=delete,
)


def to_slave_sql(table, tp, json_data, old, index, ts):
    """Replay one change event against the slave database.

    Dispatches ``tp`` through ``db_map`` and calls the handler with
    ``(table, json_data, index)``. Operation names not in ``db_map`` are
    ignored. ``old`` and ``ts`` are accepted for signature compatibility
    with ``json_to_sql`` but are not used.

    Raises:
        KeyError: if ``json_data`` has no ``index`` key.
    """
    key_value = json_data[index]
    # Log text: "start <tp> on table <table>, row where <index> = <value>".
    logger.debug("开始{}数据表{}中{}为{}的数据行".format(tp, table, index, key_value))
    handler = db_map.get(tp)
    if not handler:
        return
    handler(table, json_data, index)
    # Log text: "finished <tp> on table <table>, row where <index> = <value>".
    logger.debug("完成{}数据表{}中{}为{}的数据行".format(tp, table, index, key_value))


def to_trans_sql(table, tp, json_data, old, index, ts):
    """Replay one change event against the transformation database.

    Not implemented yet: currently a no-op that returns None.
    """


def json_to_sql(table, tp, json_data, old, index, ts):
    """Fan one change event out to the slave and transformation sinks.

    Each sink receives its own shallow copy of ``json_data``. A sink that
    mutates the row dict in place therefore cannot change what the next sink,
    or the caller, sees.

    Args:
        table: ORM-mapped class the event applies to.
        tp: operation name, e.g. "update" / "insert" / "delete".
        json_data: column name -> value for the row.
        old: passed through unchanged to both sinks.
        index: name of the key column.
        ts: passed through unchanged to both sinks.
    """
    to_slave_sql(table, tp, dict(json_data), old, index, ts)
    to_trans_sql(table, tp, dict(json_data), old, index, ts)


class TransFormationDB(object):
    """Apply a single change event to the transformation database.

    An instance holds one event: the target ORM class, the operation name,
    the row data, the previous values and the key column. ``update`` and
    ``insert`` write that event through ``trans_session``.

    NOTE(review): work in progress. ``delete`` is still a stub, and
    ``get_trans_json`` looks up, but never calls, the operation method.
    """

    def __init__(self, table, tp, json_data, old, idx):
        self.table = table  # ORM-mapped class; queried and called with **json_data
        self.tp = tp  # operation name, e.g. "update" / "insert" / "delete"
        self.json_data = json_data  # column name -> value for the row
        self.old = old  # presumably the pre-change column values; confirm with caller
        self.idx = idx  # name of the key column used for lookups

    def update(self):
        """Upsert ``self.json_data`` into ``self.table`` via ``trans_session``.

        If a row whose primary key ``self.idx`` equals
        ``self.json_data[self.idx]`` exists, its other columns are updated.
        Otherwise a new row is added. ``self.json_data`` is not modified.

        Raises:
            KeyError: if ``self.json_data`` has no ``self.idx`` key.
        """
        with session_maker(trans_session) as db_session:
            key_value = self.json_data[self.idx]
            existing = db_session.query(self.table).get({self.idx: key_value})
            if existing is None:
                # Insert inside this same trans_session unit of work. The
                # module-level insert() would write through the slave session.
                db_session.add(self.table(**self.json_data))
                return
            # Copy rather than pop so self.json_data keeps its key column.
            changes = {col: val for col, val in self.json_data.items()
                       if col != self.idx}
            if changes:
                # Skip the UPDATE entirely when only the key was supplied.
                db_session.query(self.table).filter_by(
                    **{self.idx: key_value}).update(changes)

    def insert(self):
        """Add a new row built from ``self.json_data`` via ``trans_session``."""
        new_row = self.table(**self.json_data)
        # Pass trans_session explicitly; session_maker() defaults to the
        # slave session.
        with session_maker(trans_session) as db_session:
            db_session.add(new_row)

    def delete(self):
        """Not implemented yet; currently a no-op that returns None."""
        # TODO: mirror the module-level delete() using trans_session.

    def get_trans_json(self, table, tp, json_data, old, idx):
        """Feed the changed columns of one event into its transformation.

        Builds ``TRANSFORMATION_MAP[table](idx)``. For every key in ``old``,
        it calls the transformation's method of that name with
        ``json_data[key]``. It then walks the transformation's ``tables``
        mapping. Returns None.

        Args:
            table: table name used as the TRANSFORMATION_MAP key.
            tp: operation name; expected to match a method of this class.
            json_data: column name -> value for the row.
            old: mapping whose keys name the columns to feed in.
            idx: key column name passed to the transformation.
        """
        if table in ("io_customer", "cm_contract"):
            # For these tables the transformation is built with customer_id
            # as its key column instead of the caller-supplied idx.
            idx = "customer_id"
        table_json = TRANSFORMATION_MAP[table](idx)
        for key in old.keys():
            getattr(table_json, key)(json_data[key])
        tables = table_json.tables
        for k, v in tables.items():
            # NOTE(review): the bound method is looked up but never called,
            # and k/v are unused. This loop therefore only raises
            # AttributeError for an unknown ``tp``. It is presumably meant to
            # dispatch the operation per target table; confirm the intended
            # behavior before wiring it up.
            getattr(self, tp)
