import jwt
import datetime

from django.utils.encoding import smart_text

from rest_framework_jwt.authentication import BaseJSONWebTokenAuthentication
from rest_framework_jwt.settings import api_settings
from rest_framework import exceptions
from rest_framework import status
from rest_framework.response import Response
from rest_framework.authentication import get_authorization_header
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.generics import GenericAPIView

from common.accounts.models import Accounts, UserEnterprise, Enterprise, Users
from common.utils.exceptions import (AuthenticationError,
                                     OpenApiServiceHandleException,
                                     ServiceHandleException, NoPermissionError)
from common.utils import perms_conf
from common.utils.pagination import CustomPageNumberPagination
from common.utils.return_structure import general_message
from common.utils.service_code import (INFO_CODE, INIT_ERROR, PARAMETER_ERROR,
                                       AUTHENTICATION_ERROR,
                                       AUTHENTICATION_EXPIRE,
                                       OPEN_API_APP_NOTFOUND)
from console.customer.models import Customer, CustomerAdmin
from openapi.individual_soldier.models import OpenApiToken
from common.utils.gettimes import GetRunTime

from common.accounts.models import UserRoles

# Handler callables looked up once, at import time, from the
# rest_framework_jwt settings. In this module, `jwt_get_username_from_payload`
# and `jwt_decode_handler` are used by JSONWebTokenAuthentication below;
# `jwt_response_payload_handler` is not referenced in the code shown here.
jwt_get_username_from_payload = api_settings.JWT_PAYLOAD_GET_USERNAME_HANDLER
jwt_decode_handler = api_settings.JWT_DECODE_HANDLER
jwt_response_payload_handler = api_settings.JWT_RESPONSE_PAYLOAD_HANDLER


class JSONWebTokenAuthentication(BaseJSONWebTokenAuthentication):
    """
    JWT authentication that reports failures as project ``AuthenticationError``.

    Clients should authenticate by passing the token key in the "Authorization"
    HTTP header, prepended with the string specified in the setting
    `JWT_AUTH_HEADER_PREFIX`. For example:

        Authorization: JWT eyJhbGciOiAiSFMyNTYiLCAidHlwIj

    When no Authorization header is sent and `JWT_AUTH_COOKIE` is configured,
    the token is read from that cookie instead.
    """
    www_authenticate_realm = 'api'

    def get_jwt_value(self, request):
        """
        Extract the raw JWT from the request.

        Returns:
            The token taken from the Authorization header, or from the
            `JWT_AUTH_COOKIE` cookie when no header is present, or ``None``
            when there is nothing to read or the header uses another scheme.

        Raises:
            exceptions.AuthenticationFailed: the header carries the expected
                prefix but not exactly one token after it.
        """
        auth = get_authorization_header(request).split()
        auth_header_prefix = api_settings.JWT_AUTH_HEADER_PREFIX.lower()

        if not auth:
            if api_settings.JWT_AUTH_COOKIE:
                return request.COOKIES.get(api_settings.JWT_AUTH_COOKIE)
            return None

        if smart_text(auth[0].lower()) != auth_header_prefix:
            return None

        if len(auth) == 1:
            msg = '请求头不合法，未提供认证信息'
            raise exceptions.AuthenticationFailed(msg)
        elif len(auth) > 2:
            msg = "请求头不合法"
            raise exceptions.AuthenticationFailed(msg)
        return auth[1]

    def authenticate(self, request):
        """
        Authenticate the request with its JWT.

        Returns a ``(user, token)`` two-tuple when a valid signature has been
        supplied. Unlike DRF's default behaviour, a request without a token is
        not treated as "no authentication attempted": it is rejected.

        Raises:
            AuthenticationError: the token is missing, expired, undecodable
                or otherwise invalid, or the user it names cannot be used
                (see ``authenticate_credentials``).
        """
        jwt_value = self.get_jwt_value(request)
        if jwt_value is None:
            msg = '未提供验证信息'
            raise AuthenticationError(msg)
        try:
            payload = jwt_decode_handler(jwt_value)
        # ``ExpiredSignatureError`` exists in PyJWT 1.x and 2.x; the old
        # ``ExpiredSignature`` alias was dropped in PyJWT 2.0, which made
        # expired tokens fall through to the generic invalid-token branch.
        except jwt.ExpiredSignatureError:
            msg = '认证信息已过期'
            raise AuthenticationError(msg)
        except jwt.DecodeError:
            msg = '认证信息错误'
            raise AuthenticationError(msg)
        except jwt.InvalidTokenError:
            # Any remaining PyJWT validation failure (audience, issuer, ...).
            msg = '认证信息错误,请求Token不合法'
            raise AuthenticationError(msg)

        user = self.authenticate_credentials(payload)
        return user, jwt_value

    def authenticate_credentials(self, payload):
        """
        Return the active ``Accounts`` row named by the token payload.

        The username is extracted with the configured
        `JWT_PAYLOAD_GET_USERNAME_HANDLER` and looked up by ``username``.

        Raises:
            AuthenticationError: the payload carries no username, no account
                has that username, or the account is not active.
        """
        username = jwt_get_username_from_payload(payload)
        if not username:
            msg = '认证信息不合法.'
            raise AuthenticationError(msg)
        try:
            user = Accounts.objects.get(username=username)
        except Accounts.DoesNotExist:
            msg = '签名不合法.'
            raise AuthenticationError(msg)
        if not user.is_active:
            msg = '用户身份未激活.'
            raise AuthenticationError(msg)

        return user


class AllowAnyApiView(GenericAPIView):
    """
    Public API view: no authentication classes and no permission checks.

    ``self.user`` starts as ``None`` and is set to ``request.user`` once the
    DRF ``initial`` hook has run. With no authentication classes configured,
    that is presumably DRF's unauthenticated-user placeholder (confirm
    against the UNAUTHENTICATED_USER setting).
    """
    authentication_classes = ()
    permission_classes = (AllowAny,)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.user = None

    def initial(self, request, *args, **kwargs):
        # Let DRF run its own pre-handler checks before exposing the user.
        super().initial(request, *args, **kwargs)
        self.user = request.user


class BaseApiView(GenericAPIView):
    """
    Base view for authenticated console APIs.

    Requires a valid JWT (``JSONWebTokenAuthentication``) and, in ``initial``,
    resolves the caller's user record, role, enterprise memberships and the
    current enterprise. A manager account may act on behalf of a customer by
    sending the ``s_cst_id`` cookie; the view then switches into that
    customer's context. Permission codes configured in ``perms_conf`` (or
    passed as ``__message`` in the URL kwargs) are checked before the handler
    runs.
    """
    permission_classes = (IsAuthenticated,)
    authentication_classes = (JSONWebTokenAuthentication,)
    pagination_class = CustomPageNumberPagination

    def __init__(self, *args, **kwargs):
        super(BaseApiView, self).__init__(*args, **kwargs)
        self.user = None
        self.account = None
        self.is_manager = False
        self.is_customer = False
        self.customer = None
        self.enterprise = None
        self.enterprises = None
        self.enterprise_id = None
        self.user_enterprise = None
        self.enterprise_ids = None
        self.service_ent_id = None
        self.enterprise_entids = None
        self.perms = None
        self.s_cst_id = None
        self.role = None
        # Set by ``initial`` via ``get_operation``; initialised here so the
        # attribute exists even if ``initial`` fails part-way.
        self.operation = None

    def initial(self, request, *args, **kwargs):
        """
        Populate the per-request user / enterprise / customer context.

        DRF's own authentication, permission and throttling checks run first;
        then the attributes initialised in ``__init__`` are filled in and the
        configured permission codes are enforced.

        Raises:
            ServiceHandleException: the account has no user record, the user
                belongs to no enterprise, or the ``s_cst_id`` cookie names a
                customer that does not exist or lies outside the manager's
                enterprises.
            NoPermissionError: the requested ``enterprise_id`` names an
                enterprise the user does not belong to, or a permission code
                required by ``check_perms`` is missing.
        """
        super(BaseApiView, self).initial(request, *args, **kwargs)
        self.account = request.user
        self.user = Users.objects.filter(account_id=self.account.id).first()
        if self.user is None:
            # Without this check the next line failed with AttributeError.
            raise ServiceHandleException(msg="not found user",
                                         msg_show="用户信息不存在",
                                         code=INIT_ERROR)
        user_role = UserRoles.objects.filter(user=self.user).first()
        self.role = user_role.role if user_role else None
        user_enterprises = UserEnterprise.objects.filter(user_id=self.user.id)
        self.enterprise_entids = [
            str(i)
            for i in user_enterprises.values_list("enterprise_id", flat=True)
        ]
        enterprises = Enterprise.objects.filter(
            enterprise_id__in=self.enterprise_entids).order_by("id")
        self.enterprises = enterprises
        self.enterprise_ids = [
            str(i) for i in enterprises.values_list("id", flat=True)
        ]
        self.s_cst_id = self.request.COOKIES.get('s_cst_id')
        # A subclass may pre-set enterprise_id; otherwise take it from cookie.
        if not self.enterprise_id:
            self.enterprise_id = self.request.COOKIES.get('enterprise_id')
        if self.account.account_type == "manager":
            self.is_manager = True
            if self.s_cst_id:
                # Manager acting on behalf of a customer: switch the whole
                # context over to that customer's enterprise.
                customer = Customer.objects.filter(id=self.s_cst_id).first()
                if not customer:
                    raise ServiceHandleException(msg_show="指定的客户不存在",
                                                 code=104014,
                                                 msg="not found customer")
                self.service_ent_id = customer.service_enterprise.id
                if str(customer.service_enterprise.id) not in self.enterprise_ids:
                    raise ServiceHandleException(msg_show="没有该客户平台的权限",
                                                 code=104008,
                                                 msg="not found customer")
                self.customer = customer
                self.enterprise = customer.enterprise
                self.user_enterprise = UserEnterprise.objects.filter(
                    enterprise_id=self.enterprise.enterprise_id)
                self.enterprise_entids = [str(self.enterprise.enterprise_id)]
                self.enterprise_id = self.enterprise.enterprise_id
                self.enterprise_ids = [str(self.enterprise.id)]
                self.user.nickname = customer.name
                self.user.account.account_type = "customer"
                self.account.account_type = "customer"
                self.is_customer = True
                self.is_manager = False

        elif self.account.account_type == "customer":
            self.is_customer = True
            customer_ids = list(
                CustomerAdmin.objects.filter(
                    user_id=self.user.id).values_list("customer_id", flat=True))
            if customer_ids:
                self.customer = Customer.objects.filter(
                    id__in=customer_ids).first()
                # Guard against a CustomerAdmin row pointing at a customer
                # that no longer exists (previously an AttributeError).
                if self.customer is not None:
                    self.service_ent_id = self.customer.service_enterprise.id
        if not self.enterprise_id:
            # No enterprise requested: default to the user's first membership.
            self.user_enterprise = user_enterprises.first()
            if self.user_enterprise is None:
                raise ServiceHandleException(msg="no found enterprise",
                                             msg_show="该用户未加入到企业",
                                             code=INIT_ERROR)
            self.enterprise = self.user_enterprise.enterprise
            self.enterprise_id = self.enterprise.enterprise_id
        elif not self.s_cst_id:
            # The requested enterprise (cookie-controlled) must be one the
            # user actually belongs to.
            self.user_enterprise = user_enterprises.filter(
                enterprise_id=self.enterprise_id).first()
            if self.user_enterprise is None:
                raise NoPermissionError(msg_show="没有该企业的权限")
            self.enterprise = self.user_enterprise.enterprise

        # Imported here rather than at module level (kept as in the original;
        # presumably to avoid an import cycle).
        from common.accounts.services import user_service
        self.perms = user_service.get_user_perms(self.user.id)
        self.check_perms(request, *args, **kwargs)
        self.operation = self.get_operation(request, *args, **kwargs)

    def create(self, request, *args, **kwargs):
        """Validate ``request.data``, save it and return it with HTTP 201."""
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        data = general_message(code=INFO_CODE,
                               msg="success",
                               msg_show="创建成功",
                               bean=serializer.data)
        return Response(data, status=status.HTTP_201_CREATED)

    def get_list(self, request, *args, **kwargs):
        """
        List the filtered queryset, newest id first.

        The result is paginated when a ``page`` query parameter is present;
        otherwise every row is returned in a ``general_message`` envelope.
        """
        queryset = self.filter_queryset(self.get_queryset().order_by(
            "-id", "-created_time"))
        if request.GET.get("page"):
            page = self.paginate_queryset(queryset)
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(instance=queryset, many=True)
        data = general_message(msg='success',
                               msg_show='获取列表成功',
                               list=serializer.data)
        return Response(data, status=status.HTTP_200_OK)

    def get_all(self, request, *args, **kwargs):
        """Return every row of the filtered queryset, never paginated."""
        queryset = self.filter_queryset(self.get_queryset().order_by(
            "-id", "-created_time"))
        serializer = self.get_serializer(instance=queryset, many=True)
        data = general_message(msg='success',
                               msg_show='获取列表成功',
                               list=serializer.data)
        return Response(data, status=status.HTTP_200_OK)

    def get_info(self, request, *args, **kwargs):
        """
        Return the single row whose ``id`` equals the ``id`` URL kwarg.

        When no row matches, ``obj`` is ``None`` and the serializer is built
        without an instance.
        """
        pk = kwargs.get("id")
        obj = self.get_queryset().filter(id=pk).first()
        serializer = self.get_serializer(instance=obj)
        return Response(
            general_message(msg="success",
                            msg_show="数据请求成功",
                            bean=serializer.data))

    def has_perms(self, code):
        """
        Return whether the current user holds permission ``code``.

        ``code`` may be a single int permission code, or a list / tuple /
        set of codes that must all be held. Any other type yields ``False``.
        """
        if isinstance(code, int):
            return code in self.perms
        if isinstance(code, (list, tuple, set, frozenset)):
            return set(code) <= set(self.perms)
        return False

    def check_perms(self, request, *args, **kwargs):
        """
        Enforce the permission codes configured for this view and method.

        The configuration comes from the ``__message`` URL kwarg or, failing
        that, from ``perms_conf.<ViewClassName>["__message"]``; it maps a
        lower-case HTTP method to a dict whose ``perms`` entry lists the
        required codes.

        Raises:
            NoPermissionError: a required code is not in ``self.perms``.
        """
        perms = []
        message = kwargs.get("__message")
        if not message:
            message = getattr(perms_conf, self.__class__.__name__,
                              {}).get("__message")
        if message:
            rq_method_dt = message.get(
                request.META.get("REQUEST_METHOD").lower())
            if rq_method_dt:
                perms = rq_method_dt.get("perms", [])
            if perms and not set(perms) <= set(self.perms):
                raise NoPermissionError

    def get_operation(self, request, *args, **kwargs):
        """
        Return a human-readable label for the operation being performed.

        Only the ``__message`` URL kwarg is consulted. Returns ``None`` when
        it is absent. Otherwise returns the method's configured ``operation``
        or the default label for GET/POST/PUT/DELETE; other methods
        (e.g. PATCH) without a configured label yield ``None`` instead of
        raising ``KeyError``.
        """
        operation_map = {
            "get": "查看",
            "post": "创建",
            "put": "修改",
            "delete": "删除",
        }
        message = kwargs.get("__message")
        if not message:
            return None
        rq_method = request.META.get("REQUEST_METHOD").lower()
        operation = operation_map.get(rq_method)
        rq_method_dt = message.get(rq_method)
        if rq_method_dt:
            operation = rq_method_dt.get("operation", operation)
        return operation


# 管理者基类
class ManagerApiView(BaseApiView):
    """
    Base view restricted to callers acting as a manager.

    After ``BaseApiView.initial`` has resolved the request context, rejects
    callers for whom ``is_manager`` is false (``BaseApiView`` clears it when
    a manager switches to a customer via ``s_cst_id``), then callers with no
    resolved enterprise.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def initial(self, request, *args, **kwargs):
        super().initial(request, *args, **kwargs)
        is_manager_with_enterprise = self.is_manager and self.enterprise
        if is_manager_with_enterprise:
            return
        # Same precedence as before: the manager check is reported first.
        if not self.is_manager:
            raise ServiceHandleException(msg="",
                                         msg_show="该用户未加入到管理者企业",
                                         code=INIT_ERROR)
        raise ServiceHandleException(msg="no found manager enterprise",
                                     msg_show="该用户未加入到企业",
                                     code=INIT_ERROR)


# 管理者用户基类
class ManagerUserApiView(ManagerApiView):
    """
    Manager view that operates on a target user of the current enterprise.

    The target id is ``self.tar_user_id`` if a subclass set it before
    ``initial`` runs, otherwise the ``user_id`` URL kwarg.
    """

    def __init__(self, *args, **kwargs):
        super(ManagerUserApiView, self).__init__(*args, **kwargs)
        self.tar_user_id = None
        self.tar_user = None

    def check_user_in_enterprise(self, user_id):
        """Return True if ``user_id`` is a member of ``self.enterprise_id``."""
        # ``exists()`` runs a lightweight existence query instead of loading
        # every matching row just to test the queryset's truthiness.
        return UserEnterprise.objects.filter(
            user_id=user_id, enterprise_id=self.enterprise_id).exists()

    def initial(self, request, *args, **kwargs):
        """
        Resolve and authorise the target user.

        Raises:
            NoPermissionError: the target user is not in this enterprise.
            ServiceHandleException: no target id was given, or no ``Users``
                row exists for it.
        """
        super(ManagerUserApiView, self).initial(request, *args, **kwargs)
        if not self.tar_user_id:
            self.tar_user_id = kwargs.get("user_id")
        if self.tar_user_id:
            if not self.check_user_in_enterprise(self.tar_user_id):
                raise NoPermissionError(msg_show="没有操作对象的权限")
            self.tar_user = Users.objects.filter(id=self.tar_user_id).first()
        if not self.tar_user:
            raise ServiceHandleException(msg="no found user",
                                         msg_show="未找到操作对象",
                                         code=PARAMETER_ERROR)


class OpenApiBaseView(GenericAPIView):
    """
    Base view for the open API, authenticated by app credentials.

    DRF authentication is disabled; instead ``initial`` checks the
    ``app_id`` / ``app_token`` cookies against ``OpenApiToken`` and stores
    the token's comma-separated ``app_range`` in ``self.applications``.
    """
    permission_classes = (AllowAny,)
    authentication_classes = ()

    def __init__(self, *args, **kwargs):
        super(OpenApiBaseView, self).__init__(*args, **kwargs)
        self.applications = None

    def initial(self, request, *args, **kwargs):
        """
        Validate the caller's app credentials.

        Raises:
            OpenApiServiceHandleException: with ``OPEN_API_APP_NOTFOUND``
                when the app id is missing or unknown, ``AUTHENTICATION_ERROR``
                when the token is missing or wrong, ``AUTHENTICATION_EXPIRE``
                when the token's ``expired`` time has passed.
        """
        import hmac

        super(OpenApiBaseView, self).initial(request, *args, **kwargs)
        app_token = self.request.COOKIES.get('app_token')
        app_id = self.request.COOKIES.get('app_id')
        # Never query with app_id=None: Django turns that into ``IS NULL`` and
        # could match a token row that has no app id at all.
        open_api_token = None
        if app_id:
            open_api_token = OpenApiToken.objects.filter(app_id=app_id).first()
        if not open_api_token:
            raise OpenApiServiceHandleException(msg="no found app",
                                                msg_show="缺少app id",
                                                code=OPEN_API_APP_NOTFOUND)
        expected_token = open_api_token.token
        # Constant-time comparison so response timing does not reveal how
        # much of the token matched; a missing/empty token never matches.
        token_ok = bool(app_token) and bool(expected_token) and \
            hmac.compare_digest(app_token.encode("utf-8"),
                                str(expected_token).encode("utf-8"))
        if not token_ok:
            raise OpenApiServiceHandleException(msg="app token error",
                                                msg_show="app token 错误",
                                                code=AUTHENTICATION_ERROR)
        expired = open_api_token.expired
        if expired:
            # Compare like with like: an aware ``expired`` (USE_TZ on) against
            # naive ``now()`` would raise TypeError.
            if expired.utcoffset() is not None:
                now = datetime.datetime.now(datetime.timezone.utc)
            else:
                now = datetime.datetime.now()
            if now > expired:
                raise OpenApiServiceHandleException(msg="app token expire",
                                                    msg_show="app token 已过期",
                                                    code=AUTHENTICATION_EXPIRE)
        app_range = open_api_token.app_range
        # An unset/empty range means no applications rather than a crash.
        self.applications = app_range.split(",") if app_range else []
