import inspect
import logging
import time

from .backend import ShieldAccessRequest
from .model import ConversationType
from .exception import AccessControlException
from .interceptor import intercept_methods, wrap_method, MethodIOCallback

# Module-level logger.
_logger = logging.getLogger(__name__)
# Custom numeric log level below logging.DEBUG (10); used by the _logger.log(TRACE_LEVEL, ...) calls in this module.
TRACE_LEVEL: int = 5

# Not used by this module; kept so interception can be wired up quickly for ad-hoc testing.
def _intercept_langchain(paig_plugin):
    """
    Directly intercept OpenAI LLM generation and Chain calls (manual/testing helper).

    Args:
        paig_plugin: The plugin handed through to ``intercept_methods``.
    """
    # Imported lazily so that langchain is only a runtime dependency of this helper.
    from langchain.llms.openai import BaseOpenAI
    from langchain.chains.base import Chain

    interception_plan = [
        (BaseOpenAI, ['_generate', '_agenerate'], BaseLLMGenerateCallback),
        (Chain, ['__call__'], ChainCallCallback),
    ]
    for target_cls, method_names, callback_cls in interception_plan:
        intercept_methods(paig_plugin, target_cls, method_names, callback_cls)


class LangChainLLMInterceptorSetup:
    """
    Discover and wrap the LangChain LLM methods that should be routed through PAIG Shield.

    Usage is two-phase:
      1. ``find_all_methods_to_intercept()`` walks every LLM class returned by
         ``langchain.llms.get_type_to_cls_dict()`` and records candidate methods in
         ``list_of_methods_to_intercept``.
      2. ``setup_interceptors(paig_plugin)`` wraps each recorded method with the matching callback.
    """

    def __init__(self, **kwargs):
        """
        Configure which classes and methods are considered for interception.

        Args:
            **kwargs: Optional overrides:
                filter_out_classes: class names (``__name__``) that are never intercepted.
                filter_in_classes: if non-empty, only classes whose ``__name__`` is listed are
                    intercepted.
                methods_to_intercept: method names that are candidates for interception.
        """
        self.filter_out_classes = kwargs.get(
            "filter_out_classes",
            ["LLM", "BaseLLM", "BaseLanguageModel", "Serializable", "RunnableSerializable",
             "BaseModel",
             "Representation", "ABC", "object",
             "Generic"])
        self.filter_in_classes = kwargs.get("filter_in_classes", [])
        self.methods_to_intercept = kwargs.get(
            "methods_to_intercept", ["_generate", "_agenerate", "_stream", "_astream", "_call"])
        # Entries are (module_name, class_name, method_name, cls, method) tuples, in discovery order.
        self.list_of_methods_to_intercept = []

    def intercept_methods_for_class(self, cls):
        """
        Record the methods of ``cls`` and its base classes that should be intercepted.

        Every non-dunder callable member of ``cls`` is attributed to the class that defines it, using
        the first component of its ``__qualname__``. The MRO of ``cls`` is then walked. For each
        class that passes the filter-out / filter-in rules, its own methods named in
        ``methods_to_intercept`` are appended to ``list_of_methods_to_intercept``, with no duplicates.

        Args:
            cls: The LLM class to inspect.
        """
        cls_to_methods = {}  # key: defining class name, value: list of (method_name, method)
        for method_name, method in inspect.getmembers(cls, callable):
            if method_name.startswith("__"):
                continue
            qual_name = getattr(method, "__qualname__", None)
            if qual_name is None:
                # Callable attributes such as functools.partial objects have no __qualname__, so
                # they cannot be attributed to a class; skip them instead of raising AttributeError.
                _logger.log(TRACE_LEVEL, "skipping %s: callable without __qualname__", method_name)
                continue
            cls_name_of_method = qual_name.split(".")[0]
            cls_to_methods.setdefault(cls_name_of_method, []).append((method_name, method))
            # Lazy %-args: this runs for every member of every class, and TRACE is normally disabled.
            _logger.log(TRACE_LEVEL, "method_name: %s, qual_name: %s, cls_name_of_method: %s",
                        method_name, qual_name, cls_name_of_method)

        # Walk up the class hierarchy and collect the methods each unfiltered class defines itself.
        for cls_in_hierarchy in inspect.getmro(cls):
            base_name = cls_in_hierarchy.__name__
            if base_name in self.filter_out_classes:
                # skip because we want to filter this class out
                continue
            if self.filter_in_classes and base_name not in self.filter_in_classes:
                # skip because we have a filter in place and this class is not in the filter
                continue
            _logger.log(TRACE_LEVEL, "base-class__name__: %s, base-class.__module__: %s of %s",
                        base_name, cls_in_hierarchy.__module__, cls.__name__)
            if base_name not in cls_to_methods:
                _logger.log(TRACE_LEVEL, "No methods found for this base-class: %s", base_name)
                continue
            for m_name, method in cls_to_methods[base_name]:
                if m_name not in self.methods_to_intercept:
                    continue
                to_intercept = (cls_in_hierarchy.__module__, base_name, m_name, cls_in_hierarchy, method)
                if to_intercept not in self.list_of_methods_to_intercept:
                    self.list_of_methods_to_intercept.append(to_intercept)
                    _logger.log(TRACE_LEVEL, "will intercept %s for class %s", to_intercept, cls.__name__)

    def find_all_methods_to_intercept(self):
        """
        Populate ``list_of_methods_to_intercept`` from every LLM type registered in ``langchain.llms``.
        """
        # import is hidden here so these dependencies will be loaded only when required
        import langchain.llms

        # perf_counter is the monotonic clock intended for measuring elapsed time.
        start_time = time.perf_counter()
        for type_str, get_class_method in langchain.llms.get_type_to_cls_dict().items():
            _logger.log(TRACE_LEVEL, "Type: %s", type_str)
            # The dict values are callables that return the LLM class.
            self.intercept_methods_for_class(get_class_method())
        elapsed = time.perf_counter() - start_time
        _logger.debug("Time taken to intercept all methods: %s seconds", elapsed)
        _logger.log(TRACE_LEVEL, "list_of_methods_to_intercept: %s", self.list_of_methods_to_intercept)
        _logger.debug("total number of methods that will be intercepted is %s",
                      len(self.list_of_methods_to_intercept))

    def setup_interceptors(self, paig_plugin):
        """
        Wrap every recorded method with the callback class matching its name.

        ``_generate`` / ``_agenerate`` are wrapped with BaseLLMGenerateCallback and ``_call`` with
        LLMCallCallback. Other names (including ``_stream`` / ``_astream``) are not wrapped yet.

        Args:
            paig_plugin: The plugin passed to ``wrap_method``.

        Returns:
            int: The number of methods actually wrapped.
        """
        count = 0
        for _module_name, _class_name, method_name, cls, method in self.list_of_methods_to_intercept:
            if method_name in ("_generate", "_agenerate"):
                callback_cls = BaseLLMGenerateCallback
            elif method_name == "_call":
                callback_cls = LLMCallCallback
            else:
                # TODO - need to add additional interceptors for other methods ("_stream", "_astream")
                continue
            wrap_method(paig_plugin, cls, method_name, method, callback_cls)
            count += 1
        return count


class BaseLLMGenerateCallback(MethodIOCallback):
    """
    Callback that applies PAIG Shield access control around an LLM's ``_generate`` / ``_agenerate``.

    Before the wrapped method runs, the prompts are sent to the shield as a PROMPT request. Afterwards,
    the texts of each generation are sent as a REPLY request. If the shield denies a request,
    AccessControlException is raised. Otherwise the texts returned by the shield replace the originals.
    """

    def __init__(self, paig_plugin, cls, method):
        """
        Initialize the BaseLLMGenerateCallback instance.

        Args:
            paig_plugin: The base plugin for the callback.
            cls: The class to which the callback is applied.
            method: The method to which the callback is applied.
        """
        super().__init__(paig_plugin, cls, method)

    def init(self):
        """Intentionally a no-op: this callback keeps no per-thread state."""
        pass

    def check_access(self, access_result):
        """
        Raise if the shield denied the request.

        Args:
            access_result: The result returned by the shield client's ``is_access_allowed``.

        Raises:
            AccessControlException: when ``access_result.get_is_allowed()`` is falsy. The exception
                message is the text of the last response message.
        """
        if not access_result.get_is_allowed():
            raise AccessControlException(access_result.get_last_response_message().get_response_text())

    def _build_access_request(self, request_text, conversation_type):
        """
        Build a ShieldAccessRequest for the current application and user.

        A conversation thread id and a request id are newly generated for every call.
        NOTE(review): because of this, the PROMPT and REPLY requests of one generation get different
        thread ids. Confirm this is intended.

        Args:
            request_text: The list of texts to check.
            conversation_type: ConversationType.PROMPT or ConversationType.REPLY.

        Returns:
            ShieldAccessRequest: The request to send to the shield client.
        """
        return ShieldAccessRequest(
            application_key=self.paig_plugin.get_application_key(),
            client_application_key=self.paig_plugin.get_client_application_key(),
            conversation_thread_id=self.paig_plugin.generate_conversation_thread_id(),
            request_id=self.paig_plugin.generate_request_id(),
            user_name=self.paig_plugin.get_current_user(),
            request_text=request_text,
            conversation_type=conversation_type
        )

    def process_inputs(self, *args, **kwargs):
        """
        Check the prompts with the shield and substitute the texts it returns.

        Args:
            *args: Positional arguments passed to the method. ``args[0]`` is sent as the request text,
                presumably the list of prompts passed to ``_generate`` (TODO confirm).
            **kwargs: Keyword arguments passed to the method; returned unchanged.

        Returns:
            tuple: ``(updated_args, kwargs)``. ``updated_args[0]`` is the list of response texts, one
            per shield response message.

        Raises:
            AccessControlException: if the shield denies the prompt.
        """
        access_request = self._build_access_request(args[0], ConversationType.PROMPT)
        access_result = self.paig_plugin.get_shield_client().is_access_allowed(access_request)

        # Throw exception if access is denied
        self.check_access(access_result)

        updated_prompts = [message.get_response_text() for message in access_result.get_response_messages()]
        return (updated_prompts,) + args[1:], kwargs

    def process_output(self, output):
        """
        Check every generation's texts with the shield and substitute the texts it returns.

        Args:
            output: The wrapped method's result. ``output.generations`` is iterated as a list of
                generation lists whose items expose ``.text`` and ``.copy()``. Presumably this is a
                LangChain LLMResult (TODO confirm).

        Returns:
            Any: ``output``. Its ``generations`` are replaced by copies carrying the shield's texts, or
            left untouched when empty.

        Raises:
            AccessControlException: if the shield denies any generation.
            IndexError: if the shield returns more response messages than the generation has items.
        """
        if not output.generations:
            return output

        updated_generations = []
        for generation in output.generations:
            request_text = [s_generation.text for s_generation in generation]
            access_request = self._build_access_request(request_text, ConversationType.REPLY)
            access_result = self.paig_plugin.get_shield_client().is_access_allowed(access_request)

            # Throw exception if access is denied
            self.check_access(access_result)

            # Pair each response message with the generation item at the same position.
            updated_s_generations = []
            for index, message in enumerate(access_result.get_response_messages()):
                updated_s_generation = generation[index].copy()
                updated_s_generation.text = message.get_response_text()
                updated_s_generations.append(updated_s_generation)

            updated_generations.append(updated_s_generations)

        output.generations = updated_generations
        return output


class ChainCallCallback(MethodIOCallback):
    """
    Callback that applies PAIG Shield access control to ``Chain.__call__``.

    Chains may call other chains, so a per-call-depth counter on ``paig_plugin.thread_local`` ensures
    that only the outermost call sends ``inputs["question"]`` to the shield. Nested calls pass through
    unchecked. Outputs are not checked.
    """

    def __init__(self, paig_plugin, cls, method):
        """
        Initialize the ChainCallCallback instance.

        Args:
            paig_plugin: The base plugin for the callback.
            cls: The class to which the callback is applied.
            method: The method to which the callback is applied.
        """
        super().__init__(paig_plugin, cls, method)

    def init(self):
        """Reset the call-depth counter and conversation thread id on ``paig_plugin.thread_local``."""
        self.paig_plugin.thread_local.chain_recursive_call_count = 0
        self.paig_plugin.thread_local.conversation_thread_id = ""

    def check_access(self, access_result):
        """
        Raise if the shield denied the request.

        Args:
            access_result: The result returned by the shield client's ``is_access_allowed``.

        Raises:
            AccessControlException: when ``access_result.get_is_allowed()`` is falsy. The exception
                message is the text of the last response message.
        """
        if not access_result.get_is_allowed():
            raise AccessControlException(access_result.get_last_response_message().get_response_text())

    def process_inputs(self, *args, **kwargs):
        """
        On the outermost chain call, check ``args[0]["question"]`` with the shield.

        Args:
            *args: Positional arguments of ``Chain.__call__``. ``args[0]`` must be a mapping with a
                ``"question"`` key.
            **kwargs: Keyword arguments; returned unchanged.

        Returns:
            tuple: ``(args, kwargs)``. On the outermost call, ``args[0]`` is replaced by a shallow copy
            whose ``"question"`` is the shield's last response text.

        Raises:
            AccessControlException: if the shield denies the question.
        """
        thread_local = self.paig_plugin.thread_local
        # Read with a default. If thread_local is a threading.local (presumably; TODO confirm),
        # attributes set by init() exist only in the thread that ran init(), and other threads
        # would otherwise raise AttributeError here.
        if getattr(thread_local, "chain_recursive_call_count", 0) == 0:
            thread_local.conversation_thread_id = self.paig_plugin.generate_conversation_thread_id()

            inputs = args[0]
            access_request = ShieldAccessRequest(
                application_key=self.paig_plugin.get_application_key(),
                client_application_key=self.paig_plugin.get_client_application_key(),
                conversation_thread_id=thread_local.conversation_thread_id,
                request_id=self.paig_plugin.generate_request_id(),
                user_name=self.paig_plugin.get_current_user(),
                request_text=inputs["question"],
                conversation_type=ConversationType.PROMPT
            )

            access_result = self.paig_plugin.get_shield_client().is_access_allowed(access_request)

            # Throw exception if access is denied
            self.check_access(access_result)

            # Pass the shield's text on in a shallow copy so the caller's own dict is not mutated.
            updated_inputs = dict(inputs)
            updated_inputs["question"] = access_result.get_last_response_message().get_response_text()
            args = (updated_inputs,) + args[1:]

        # NOTE(review): this counter is only decremented in process_output. If the wrapped __call__
        # raises and process_output is not invoked (depends on wrap_method, not visible here), the
        # counter stays > 0. Later top-level calls on this thread would then skip the shield check.
        # Confirm how wrap_method handles exceptions.
        thread_local.chain_recursive_call_count = getattr(thread_local, "chain_recursive_call_count", 0) + 1

        return args, kwargs

    def process_output(self, output):
        """
        Leave one chain-call level and reset the thread id once the outermost call returns.

        Args:
            output: The output from the method callback.

        Returns:
            Any: ``output``, unchanged.
        """
        thread_local = self.paig_plugin.thread_local
        # Clamp at zero. A negative depth would make the next top-level call skip the shield check.
        call_depth = max(getattr(thread_local, "chain_recursive_call_count", 0) - 1, 0)
        thread_local.chain_recursive_call_count = call_depth

        if call_depth == 0:
            thread_local.conversation_thread_id = ""

        return output


class LLMCallCallback(MethodIOCallback):
    """
    Callback that intercepts LLM._call()

    The prompt string is checked with PAIG Shield as a PROMPT request before the call. The returned
    string is checked as a REPLY request afterwards. In both cases the first response message's text
    from the shield is substituted.
    """

    def __init__(self, paig_plugin, cls, method):
        """
        Initialize the LLMCallCallback instance.

        Args:
            paig_plugin: The base plugin for the callback.
            cls: The class to which the callback is applied.
            method: The method to which the callback is applied.
        """
        super().__init__(paig_plugin, cls, method)

    def init(self):
        """Intentionally a no-op: this callback keeps no per-thread state."""
        pass

    def check_access(self, access_result):
        """
        Raise if the shield denied the request.

        Args:
            access_result: The result returned by the shield client's ``is_access_allowed``.

        Raises:
            AccessControlException: when ``access_result.get_is_allowed()`` is falsy. The exception
                message is the text of the last response message.
        """
        if not access_result.get_is_allowed():
            raise AccessControlException(access_result.get_last_response_message().get_response_text())

    def _check_single_text(self, text, conversation_type):
        """
        Send one text to the shield and return the first response message's text.

        A conversation thread id and a request id are newly generated for every call.
        NOTE(review): because of this, the prompt and reply of one ``_call`` get different thread ids.
        Confirm this is intended.

        Args:
            text: The text to check.
            conversation_type: ConversationType.PROMPT or ConversationType.REPLY.

        Returns:
            The shield's text for ``text``.

        Raises:
            AccessControlException: if the shield denies the text.
        """
        access_request = ShieldAccessRequest(
            application_key=self.paig_plugin.get_application_key(),
            client_application_key=self.paig_plugin.get_client_application_key(),
            conversation_thread_id=self.paig_plugin.generate_conversation_thread_id(),
            request_id=self.paig_plugin.generate_request_id(),
            user_name=self.paig_plugin.get_current_user(),
            request_text=[text],
            conversation_type=conversation_type
        )

        access_result = self.paig_plugin.get_shield_client().is_access_allowed(access_request)

        # Throw exception if access is denied
        self.check_access(access_result)

        # Use the same accessor as the other callbacks in this module for response messages.
        return access_result.get_response_messages()[0].get_response_text()

    def process_inputs(self, *args, **kwargs):
        """
        Check the prompt with the shield and substitute the text it returns.

        Args:
            *args: Positional arguments passed to the method. ``args[0]`` is the prompt.
            **kwargs: Keyword arguments passed to the method; returned unchanged.

        Returns:
            tuple: ``(updated_args, kwargs)`` with ``updated_args[0]`` replaced by the shield's text.

        Raises:
            AccessControlException: if the shield denies the prompt.
        """
        # Lazy %-args avoid formatting the arguments when DEBUG is disabled.
        # NOTE(review): this logs the raw prompt before any shield processing.
        _logger.debug("LLMCallCallback.process_inputs with args=%s, kwargs=%s", args, kwargs)
        updated_prompt = self._check_single_text(args[0], ConversationType.PROMPT)
        return (updated_prompt,) + args[1:], kwargs

    def process_output(self, output):
        """
        Check the method's output with the shield and return the text the shield sends back.

        Args:
            output: The output from the method callback.

        Returns:
            Any: The shield's text for ``output``.

        Raises:
            AccessControlException: if the shield denies the reply.
        """
        return self._check_single_text(output, ConversationType.REPLY)
