import os
import sys
from django.core.management import BaseCommand
from django.db import connection
from django.db import models
from django.conf import settings

# Project root directory, read from the Django settings.
ROOT_URL = settings.BASE_DIR

# Folder holding log.sql and the numbered SQL scripts applied by the command.
SQL_URL = os.path.join(settings.BASE_DIR, 'sql')


class InItDBLog(models.Model):
    """Progress record for one SQL script applied by the init-db command.

    Rows are stored in the ``init_db_log`` table. NOTE(review): the table
    appears to be created by ``sql/log.sql`` (run before this model is
    queried) rather than by a migration — confirm.
    """
    id = models.AutoField(primary_key=True)
    # Script file name, e.g. "1.sql"; the command sorts these by the integer
    # before the first ".".
    name = models.CharField(max_length=32, null=False)
    # How many statements of the file have been executed successfully.
    sql_num = models.IntegerField(null=False)
    # Message of the statement that failed, or NULL when the file completed.
    error_msg = models.CharField(max_length=512, null=True)
    # Filled in by Django when the row is inserted. ``auto_now_add`` is the
    # option that does this; ``auto_created`` is an internal Field flag that
    # leaves the value as None, which would be written as NULL.
    run_time = models.DateTimeField(auto_now_add=True, null=False)

    class Meta:
        db_table = "init_db_log"


class Command(BaseCommand):
    """Apply the numbered SQL scripts in ``SQL_URL`` and record progress.

    ``log.sql`` is executed first on every run (presumably it creates the
    ``init_db_log`` table behind ``InItDBLog`` when missing — verify). Every
    other ``*.sql`` file is split into ``;``-terminated statements that are
    executed one by one. Per-file progress is stored in ``InItDBLog``, so a
    run that stopped on a failing statement resumes at that statement.
    """
    help = '初始化数据库'

    def get_sql(self, sql_file):
        """Return the ``;``-terminated statements of *sql_file* in ``SQL_URL``.

        The file is decoded as UTF-8 and split naively on ``;`` (a ``;``
        inside a string literal or comment also splits). Text after the final
        ``;`` is discarded, so every statement must end with ``;``.
        """
        path = os.path.join(SQL_URL, sql_file)
        with open(path, "r", encoding="utf-8") as f:
            return f.read().split(';')[:-1]

    def get_sql_file_list(self):
        """Return the ``*.sql`` file names directly inside ``SQL_URL``, sorted.

        Only the top level is scanned, because ``get_sql`` resolves names
        relative to ``SQL_URL``; non-SQL files (README.md, .gitkeep, ...) are
        skipped. A missing directory yields an empty list.
        """
        _, _, files = next(os.walk(SQL_URL), (SQL_URL, [], []))
        return sorted(name for name in files if name.endswith('.sql'))

    def handle(self, *args, **options):
        """Run ``log.sql``, then every pending script in numeric order.

        Exits with status 1 (via ``sys.exit``) when the log table cannot be
        read, a script contains no statements, or a statement fails.
        """
        sql_file_list = self.get_sql_file_list()
        sql_file_list.remove('log.sql')

        with connection.cursor() as c:
            for sql in self.get_sql('log.sql'):
                c.execute(sql)

        try:
            # Querysets are lazy: list() runs the query here, so a database
            # error is reported by this handler instead of escaping later.
            logs = list(InItDBLog.objects.all())
        except Exception as e:
            self.stdout.write("\033[31m[ERROR] %s\033[0m" % e)
            sys.exit(1)

        done_counts = {log.name: log.sql_num for log in logs}
        for sql_file in self._pending_files(sql_file_list, done_counts):
            self._run_file(sql_file, done_counts.get(sql_file, 0))
        self.stdout.write("\033[32m数据库初始化成功\033[0m")

    @staticmethod
    def _file_number(name):
        """Numeric sort key of a script name: the integer before the first '.'."""
        return int(name.split(".")[0])

    def _pending_files(self, sql_file_list, done_counts):
        """Return the files of *sql_file_list* that still need to run, in order.

        Logged files are skipped, except the highest-numbered logged one: the
        previous run may have stopped inside it, and ``_run_file`` resumes it
        after its recorded statement count. Log entries for files that are no
        longer on disk are ignored.
        """
        logged = sorted(done_counts, key=self._file_number)
        if logged:
            logged.pop()
        return sorted(set(sql_file_list) - set(logged),
                      key=self._file_number)

    def _run_file(self, sql_file, done_count):
        """Execute the statements of *sql_file* after the first *done_count*.

        On success the file's log row records its total statement count. On
        failure the row records how many statements succeeded plus the error,
        and the command exits with status 1.
        """
        sql_list = self.get_sql(sql_file)
        if not sql_list:
            self.stdout.write("\033[31m[ERROR] %s是空文件\033[0m" % sql_file)
            sys.exit(1)
        with connection.cursor() as cursor:
            for sql_num, sql in enumerate(sql_list, start=1):
                if sql_num <= done_count:
                    continue  # already applied by an earlier run
                try:
                    cursor.execute(sql)
                except Exception as e:
                    error_msg = "文件%s第%s段sql %s" % (sql_file, sql_num, e)
                    self._save_log(sql_file, sql_num - 1, error_msg)
                    self.stdout.write("\033[31m[ERROR] %s\033[0m" %
                                      error_msg)
                    sys.exit(1)
        # Persist immediately so completed files stay logged even if a later
        # file aborts the command.
        self._save_log(sql_file, len(sql_list))

    @staticmethod
    def _save_log(sql_file, sql_num, error_msg=None):
        """Replace the log row of *sql_file* with one recording *sql_num*.

        *error_msg* is cut to the column's max_length, since an over-long
        value may be rejected by the database and lose the failure record.
        """
        if error_msg is not None:
            limit = InItDBLog._meta.get_field('error_msg').max_length
            error_msg = error_msg[:limit]
        InItDBLog.objects.filter(name=sql_file).delete()
        InItDBLog.objects.create(name=sql_file, sql_num=sql_num,
                                 error_msg=error_msg)
