"""Workaround: silence the spurious "RuntimeError: Event loop is closed" raised by asyncio proactor transports."""
from asyncio.proactor_events import _ProactorBasePipeTransport
from functools import wraps

from mpesasync.contracts import MpesaResponse


def silence_event_loop_closed(func):
    """
    Wrap a method so that ``RuntimeError('Event loop is closed')`` is ignored.

    The wrapped callable is invoked with the same ``self``/args/kwargs. If it
    raises a ``RuntimeError`` whose message is exactly ``'Event loop is
    closed'``, the wrapper returns ``None`` instead; any other exception
    propagates unchanged.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            outcome = func(self, *args, **kwargs)
        except RuntimeError as err:
            # Only the one known-noisy message is suppressed.
            if str(err) == 'Event loop is closed':
                return None
            raise
        return outcome

    return wrapper


# Monkey-patch asyncio's proactor pipe transport finalizer at import time.
# The finalizer can raise "RuntimeError: Event loop is closed" when a transport
# is garbage-collected after its loop has shut down; wrapping __del__ drops only
# that specific error. NOTE: this changes the class globally for the whole
# process, not just for this package.
_ProactorBasePipeTransport.__del__ = silence_event_loop_closed(_ProactorBasePipeTransport.__del__)
"""End of the "Event loop is closed" workaround."""

from typing import Optional

from enum import Enum
import re
import base64
import os

from cryptography import x509
from cryptography.hazmat.primitives.asymmetric import padding

from mpesasync.types import *
from mpesasync.restlib import *


# region Types

class MpesaEnvironment(Enum):
    """Which Safaricom Daraja deployment a client targets."""

    sandbox = "sandbox"
    production = "production"

    @classmethod
    def get_env(cls, envname: str):
        """
        Resolve an environment name to a member.

        Only the exact string ``"production"`` selects ``production``; any
        other value falls back to ``sandbox``.
        """
        return cls.production if envname == "production" else cls.sandbox


class AuthorizationResponse(BaseModel):
    """Parsed body of the OAuth ``/oauth/v1/generate`` response used by ``Mpesa.authorize``."""
    expires_in: str  # NOTE(review): presumably token lifetime in seconds, sent as a string — confirm
    access_token: str  # token that Mpesa.authorize stores in Mpesa.AccessToken


# endregion

class Mpesa(BaseModel):
    """
    Common base for M-Pesa (Safaricom Daraja) API clients.

    Stores the target environment, the OAuth access token set by
    :meth:`authorize`, and the result / queue-timeout callback URLs.
    Subclasses are expected to provide :meth:`endpoint`; the base
    implementation raises ``NotImplementedError``.
    """

    Environment: MpesaEnvironment
    AccessToken: Optional[str]
    ResultURL: Optional[str]
    QueueTimeOutURL: Optional[str]

    @property
    def base_url(self) -> str:
        """Return the Daraja API host for the configured environment."""
        if self.Environment is MpesaEnvironment.sandbox:
            return "https://sandbox.safaricom.co.ke"
        # Daraja production APIs live on the ``api.`` subdomain; the bare
        # ``safaricom.co.ke`` host is not the API gateway.
        return "https://api.safaricom.co.ke"

    def endpoint(self) -> str:
        """Return the API-specific endpoint; must be overridden by subclasses."""
        raise NotImplementedError

    async def authorize(self, consumer_key: str, consumer_secret: str) -> None:
        """
        Obtain a time-bound access token and store it in :attr:`AccessToken`.

        Issues ``GET {base_url}/oauth/v1/generate`` with the
        ``client_credentials`` grant and ``auth=(consumer_key, consumer_secret)``.

        :param consumer_key: Daraja app consumer key.
        :param consumer_secret: Daraja app consumer secret.
        :raises Exception: if the HTTP client reports an error.
        """
        from mpesasync.restlib import HttpClient
        params = [("grant_type", "client_credentials")]
        # NOTE(review): grant_type is sent both in the URL and via ``params``;
        # kept as-is because HttpGet's handling of ``params`` isn't visible here.
        resp = await HttpClient.HttpGet(
            url="%s/oauth/v1/generate?grant_type=client_credentials" % self.base_url,
            params=params,
            auth=(consumer_key, consumer_secret),
        )
        if resp.error is not None:
            raise Exception(f"Unable to authenticate to mpesa {resp.error}")
        resp.data = AuthorizationResponse.parse_obj(resp.data)
        self.AccessToken = resp.data.access_token

    @staticmethod
    def get_security_credential(mpesa_environment: MpesaEnvironment, initiator_password: str) -> str:
        """
        Encrypt the API initiator password into an M-Pesa security credential.

        M-Pesa Core authenticates a transaction by decrypting the security
        credential, so it must be produced with the public key (X509
        certificate) of the environment being called:

        1. Write the unencrypted password into a byte array.
        2. Encrypt it with the environment's M-Pesa public key using RSA with
           PKCS #1 v1.5 padding (not OAEP).
        3. Base64-encode the ciphertext; the resulting string is the
           security credential.

        :param mpesa_environment: selects the sandbox or production certificate.
        :param initiator_password: plaintext initiator password.
        :returns: base64-encoded ciphertext as ``str``.
        """
        certificates_dir = os.path.join(os.path.dirname(__file__), "certificates")
        # Compare by identity: ``mpesa_environment.sandbox`` would just fetch
        # the (always truthy) enum member and pick the wrong certificate.
        if mpesa_environment is MpesaEnvironment.sandbox:
            certificate_name = "SandboxCertificate.cer"
        else:
            certificate_name = "ProductionCertificate.cer"
        certificate_path = os.path.join(certificates_dir, certificate_name)

        with open(certificate_path, 'rb') as fp:
            pem_data = fp.read()
        cert = x509.load_pem_x509_certificate(pem_data)
        ciphertext = cert.public_key().encrypt(initiator_password.encode(), padding=padding.PKCS1v15())
        return base64.b64encode(ciphertext).decode()

    @staticmethod
    def validate_phonenumber(phone_number: str) -> bool:
        """
        Check that *phone_number* is a Safaricom number in ``254XXXXXXXXX`` form.

        The number must carry the country code (254) without a plus sign,
        followed by ``7`` or ``1`` and eight more digits.

        :returns: ``True`` when valid.
        :raises ValueError: when the format is invalid.
        """
        # fullmatch anchors both ends strictly; ``^...$`` with findall would
        # also accept a trailing newline.
        if re.fullmatch(r"254[17][0-9]{8}", phone_number):
            return True
        raise ValueError(
            "Invalid phone number format. %s should be a Valid Safaricom "
            "Mobile Number that is M-Pesa registered in the format 2547XXXXXXXX" % phone_number)
