import base64
import datetime
import functools
import gzip
import json
from io import BytesIO
from typing import Callable, Optional, Dict, Any, Union, List

import requests
from requests.models import Response

from .__version__ import __version__


def py3_b64encode(input_str: str) -> str:
    """ Base64-encode the UTF-8 bytes of ``input_str`` and return the result as text.

    :param input_str: the text to encode
    :return: the base64 representation, decoded back to a ``str``
    """
    raw = bytes(input_str, 'utf8')
    return str(base64.b64encode(raw), 'utf8')


def py3_gzip_encode(data: str) -> bytes:
    """ Gzip-compress ``data`` after encoding it as UTF-8.

    :param data: the text to compress (e.g. a serialized JSON body)
    :return: the complete gzip stream as bytes
    """
    # Encode before opening the gzip stream so an encoding error (e.g. lone surrogates) cannot
    # leave a half-written GzipFile behind; ``with`` guarantees the trailer is flushed and closed.
    payload = data.encode('utf8')
    out = BytesIO()
    with gzip.GzipFile(mode='wb', fileobj=out) as gz:
        gz.write(payload)
    return out.getvalue()


def py3_response_decode(data: bytes) -> str:
    """ Turn a raw UTF-8 response payload into a Python string.

    :param data: the undecoded response bytes
    :return: the decoded text
    """
    text = data.decode(encoding='utf8')
    return text


class ServiceUnavailable(Exception):
    """ Raised when the server answers with HTTP status 503 (Service Unavailable). """


def authenticate(func: Callable) -> Callable:
    """ Decorator making sure a valid access token is available before calling ``func``.

    The first positional argument of the wrapped call must be an object exposing
    ``should_fetch_token()``, ``is_access_token_valid()``, ``fetch_access_token()`` and an
    ``access_token`` attribute. When ``should_fetch_token()`` is True and the current token is
    no longer valid, a new token is fetched and stored on that object before ``func`` runs.

    :param func: the method to wrap
    :return: the wrapped method; positional and keyword arguments are forwarded unchanged
    """

    @functools.wraps(func)
    def authenticate_and_call(*args, **kwargs) -> Any:
        _self = args[0]
        # Only refresh when the instance manages its own token; otherwise call straight through.
        if _self.should_fetch_token() and not _self.is_access_token_valid():
            _self.access_token = _self.fetch_access_token()
        return func(*args, **kwargs)

    return authenticate_and_call


def backoff(func: Callable) -> Callable:
    """ Decorator retrying ``func`` with exponential backoff when it raises ``ServiceUnavailable``.

    The first positional argument of the wrapped call is the instance (``self``); its
    ``_backoff_config`` attribute drives the retry policy. When it is None (or no positional
    argument is given), ``func`` is called once with no retries. Otherwise the config must expose
    ``number_of_attempts``, ``initial_delay_in_seconds`` and ``on_retry`` (a tenacity ``before``
    callback, or None for no callback).

    :param func: the method to wrap
    :return: the wrapped method; positional and keyword arguments are forwarded unchanged
    :raises ImportError: if a backoff config is set but the optional ``tenacity`` package is missing
    """

    @functools.wraps(func)
    def _backoff(*args, **kwargs) -> Any:
        if not args or args[0]._backoff_config is None:
            return func(*args, **kwargs)

        config = args[0]._backoff_config
        # tenacity is optional and only imported when retries are configured. The ``try`` covers
        # the import alone so an ImportError raised by ``func`` is not misreported as a missing package.
        try:
            from tenacity import retry, wait_exponential, wait_random, stop_after_attempt, \
                retry_if_exception_type, before_nothing
        except ImportError as err:
            raise ImportError('The "tenacity" package is required to use this feature') from err

        max_attempts = config.number_of_attempts
        initial_delay = config.initial_delay_in_seconds
        on_retry = config.on_retry if config.on_retry is not None else before_nothing

        def call() -> Any:
            return func(*args, **kwargs)

        retrying = retry(stop=stop_after_attempt(max_attempts),
                         # Exponential delay plus a small random jitter between attempts.
                         wait=wait_exponential(multiplier=initial_delay) + wait_random(min=0.01, max=0.05),
                         retry=retry_if_exception_type(exception_types=ServiceUnavailable),
                         before=on_retry,
                         reraise=True)
        return retrying(call)()

    return _backoff


class APIManager(object):
    """ Sends authenticated HTTP requests to the smartobjects (mnubo) REST API.

    Authentication uses either OAuth client credentials (a token is fetched at construction time
    and refreshed by the ``authenticate`` decorator) or a caller-supplied static bearer token.
    """

    # Sentinel default for ``body`` in post/put. It tells "argument omitted" (send ``{}``) apart
    # from an explicit ``None`` (forwarded unchanged) without sharing a mutable ``{}`` default.
    _NO_BODY: Any = object()

    def __init__(self, client_id: Optional[str], client_secret: Optional[str], hostname: str,
                 compression_enabled: bool,
                 backoff_config: Optional[Any],
                 token_override: Optional[str]):
        """ Initializes the API Manager which is responsible for authenticating every request.

        :param client_id: the client id generated by mnubo (required unless token_override is given)
        :param client_secret: the client secret generated by mnubo (required unless token_override is given)
        :param hostname: the hostname to send the requests (sandbox or production)
        :param compression_enabled: if True, gzip-compress the JSON bodies of POST and PUT requests
        :param backoff_config: retry settings read by the ``backoff`` decorator (attributes
            ``number_of_attempts``, ``initial_delay_in_seconds``, ``on_retry``); None disables retries
        :param token_override: static bearer token; when given, client credentials are not required
            and no OAuth token is fetched
        :raises ValueError: if credentials are missing without a token override, or the host is unreachable
        """

        if not token_override:
            if not client_id:
                raise ValueError("client_id cannot be null or empty.")

            if not client_secret:
                raise ValueError("client_secret cannot be null or empty.")
            self.__client_id = client_id
            self.__client_secret = client_secret
            self.__static_token = None
        else:
            self.__client_id = None
            self.__client_secret = None
            self.__static_token = token_override

        # Fail fast on an unreachable host instead of on the first API call.
        try:
            requests.head(hostname)
        except requests.exceptions.ConnectionError as err:
            raise ValueError(f"Host at {hostname} is not reachable") from err
        self.__hostname = hostname

        self.__hybridb64 = py3_b64encode
        self.__gzip_encode = py3_gzip_encode
        self.__response_decode = py3_response_decode

        self.__session = requests.Session()
        self.compression_enabled = compression_enabled
        self._backoff_config = backoff_config

        if not token_override:
            self.access_token = self.fetch_access_token()

    def fetch_access_token(self) -> Dict[str, Any]:
        """ Requests the access token necessary to communicate with the smartobjects platform

        :return: dict with ``access_token`` (str), ``expires_in`` (timedelta) and ``requested_at``
            (naive local datetime taken just before the request)
        :raises requests.HTTPError: if the token endpoint answers with an error status
        """

        # Timestamp taken before the request so the computed expiry errs on the early side.
        requested_at = datetime.datetime.now()

        r = self.__session.post(self.get_auth_url(), headers=self.get_token_authorization_header())
        r.raise_for_status()
        json_response = r.json()

        return {
            'access_token': json_response['access_token'],
            'expires_in': datetime.timedelta(0, json_response['expires_in']),
            'requested_at': requested_at
        }

    def is_access_token_valid(self) -> bool:
        """ Validates if the fetched OAuth token is still valid

        :return: True if the token is still valid, False if it is expired
        """

        return (self.access_token['requested_at'] + self.access_token['expires_in']) > datetime.datetime.now()

    def should_fetch_token(self) -> bool:
        """ Check if the manager should fetch (and refresh) an OAuth token

        :return: True if no static token was supplied, False otherwise
        """
        return not self.__static_token

    def get_token_authorization_header(self) -> Dict[str, str]:
        """ Generates the HTTP Basic authorization header used while requesting an access token
        """

        encoded = self.__hybridb64(f"{self.__client_id}:{self.__client_secret}")
        return {'content-type': 'application/x-www-form-urlencoded', 'Authorization': f"Basic {encoded}"}

    def get_authorization_header(self) -> Dict[str, str]:
        """ Generates the headers used to access resources via smartobjects's API

        The bearer token is the static token when one was supplied, otherwise the fetched OAuth token.
        """
        if not self.__static_token:
            token = self.access_token['access_token']
        else:
            token = self.__static_token
        return {
            'content-type': 'application/json',
            'Authorization': 'Bearer ' + token,
            'X-MNUBO-SDK': 'python/' + __version__
        }

    def get_api_url(self) -> str:
        """ Generates the general API url
        """

        return self.__hostname + '/api/v3/'

    def get_auth_url(self) -> str:
        """ Generates the url to fetch the access token
        """

        return self.__hostname + '/oauth/token?grant_type=client_credentials&scope=ALL'

    def validate_response(self, response: Response):
        """ Raises a ValueError instead of a HTTPError in case of a 400 or 409

        This allows easier development and consistency with client-side checks.

        :raises ValueError: on HTTP 400 or 409 (message is the raw response content)
        :raises ServiceUnavailable: on HTTP 503, so the ``backoff`` decorator can retry
        :raises requests.HTTPError: on any other error status
        """
        if response.status_code in (400, 409):
            raise ValueError(response.content)
        if response.status_code == 503:
            raise ServiceUnavailable
        response.raise_for_status()

    def _send_json(self, send: Callable[..., Response], route: str, body: Any) -> Response:
        """ Send ``body`` as JSON to ``route`` with the given session method and validate the reply.

        Shared by post and put. With compression enabled the body is serialized and gzipped here;
        otherwise requests serializes it through its ``json=`` argument.
        """
        if body is APIManager._NO_BODY:
            body = {}

        url = self.get_api_url() + route
        headers = self.get_authorization_header()

        if self.compression_enabled:
            headers["content-encoding"] = "gzip"
            encoded = self.__gzip_encode(json.dumps(body))
            response = send(url, data=encoded, headers=headers)
        else:
            response = send(url, json=body, headers=headers)

        self.validate_response(response)

        return response

    @authenticate
    @backoff
    def get(self, route: str, params: Optional[Union[Dict[str, Any], str]] = None) -> Response:
        """ Build and send an authenticated GET request

        :param route: resource path (not including the API root)
        :param params: (optional) query-string parameters, as a dict or an already-encoded string
        """

        url = self.get_api_url() + route
        headers = self.get_authorization_header()

        response = self.__session.get(url, params=params, headers=headers)
        self.validate_response(response)

        return response

    @authenticate
    @backoff
    def post(self, route: str, body: Optional[Union[dict, Dict[str, Any], str, List[Any]]] = _NO_BODY) -> Response:
        """ Build and send an authenticated POST request

        :param route: resource path (not including the API root)
        :param body: JSON body to be included in the HTTP request (``{}`` when omitted)
        """

        return self._send_json(self.__session.post, route, body)

    @authenticate
    @backoff
    def put(self, route: str, body: Union[Dict[str, Any], List[Dict[str, str]]] = _NO_BODY) -> Response:
        """ Build and send an authenticated PUT request

        :param route: resource path (not including the API root)
        :param body: JSON body to be included in the HTTP request (``{}`` when omitted)
        """

        return self._send_json(self.__session.put, route, body)

    @authenticate
    @backoff
    def delete(self, route: str) -> Response:
        """ Build and send an authenticated DELETE request

        :param route: which resource to access via the REST API
        """

        url = self.get_api_url() + route
        headers = self.get_authorization_header()

        response = self.__session.delete(url, headers=headers)
        self.validate_response(response)

        return response