from typing import Mapping

import tensorflow as tf
from pyjackson.errors import SerializationError

from algolink.core.analyzer.base import CanIsAMustHookMixin
from algolink.core.analyzer.dataset import DatasetAnalyzer, DatasetHook
from algolink.core.objects.dataset_type import DatasetType, DictDatasetType
from algolink.core.objects.requirements import InstallableRequirement, Requirements


class FeedDictDatasetType(DictDatasetType):
    """
    :class:`~algolink.core.objects.DatasetType` implementation for tensorflow feed dict argument.

    Feed dict keys may be ``tf.Tensor`` objects or plain strings. Tensor keys are replaced by
    their ``name`` (see :meth:`get_key`), so the underlying dict type always works with string keys.
    """

    @classmethod
    def from_feed_dict(cls, feed_dict):
        """
        Factory method to create :class:`FeedDictDatasetType` from feed dict

        :param feed_dict: feed dict mapping ``tf.Tensor`` or ``str`` keys to values
        :return: instance of the class this method is called on, with one analyzed type per key
        :raises ValueError: if a key is neither a ``tf.Tensor`` nor a ``str``
        """
        types = {}
        for k, v in feed_dict.items():
            types[cls.get_key(k)] = DatasetAnalyzer.analyze(v)
        # use ``cls`` so subclasses calling this factory get an instance of themselves
        return cls(types)

    def serialize(self, instance: dict):
        """
        Serialize a feed dict after normalizing its keys to strings.

        :param instance: feed dict with ``tf.Tensor`` or ``str`` keys
        :return: result of the parent :class:`DictDatasetType` serialization of the key-normalized dict
        :raises SerializationError: if a key has an unsupported type; ``_check_type`` is also given
            ``SerializationError`` for the ``dict`` type check (presumably raised when instance is not a dict)
        """
        self._check_type(instance, dict, SerializationError)
        try:
            items = {self.get_key(k): v for k, v in instance.items()}
        except ValueError as e:
            raise SerializationError(e) from e
        # ``super()`` already binds ``self``; passing it again would push ``items`` into an extra argument
        return super().serialize(items)

    @staticmethod
    def get_key(k):
        """
        Normalize a feed dict key to a string.

        :param k: ``tf.Tensor`` (its ``name`` is used) or ``str`` (returned unchanged)
        :return: string key
        :raises ValueError: for any other key type
        """
        if isinstance(k, tf.Tensor):
            return k.name
        elif isinstance(k, str):
            return k
        else:
            raise ValueError(f'Unknown key type {type(k).__name__} for key {k} in feed_dict')

    @DictDatasetType.requirements.getter
    def requirements(self) -> Requirements:
        """
        Requirements of this dataset type: the tensorflow package plus whatever the parent
        :class:`DictDatasetType` ``requirements`` property reports.
        """
        # property access through super() already invokes the parent getter and yields its value,
        # so there is no ``fget`` to call on the result
        return Requirements([InstallableRequirement.from_module(tf)]) + super().requirements


class FeedDictHook(CanIsAMustHookMixin, DatasetHook):
    """
    DatasetHook that recognizes tensorflow feed dicts and converts them to :class:`FeedDictDatasetType`
    """

    def must_process(self, obj) -> bool:
        """
        :param obj: object to inspect
        :return: `True` when obj is a mapping with at least one tf.Tensor key, `False` otherwise
        """
        if not isinstance(obj, Mapping):
            return False
        for key in obj.keys():
            if isinstance(key, tf.Tensor):
                return True
        return False

    def process(self, obj, **kwargs) -> DatasetType:
        """
        :param obj: feed dict to convert
        :param kwargs: accepted for interface compatibility; not used here
        :return: :class:`FeedDictDatasetType` built from obj
        """
        dataset_type = FeedDictDatasetType.from_feed_dict(obj)
        return dataset_type
