from typing import List, overload

from loguru import logger

from elluminate.beta.resources.base import BaseResource
from elluminate.beta.schemas import (
    BatchCreatePromptResponseRequest,
    BatchCreatePromptResponseStatus,
    CreatePromptResponseRequest,
    GenerationMetadata,
    LLMConfig,
    PromptResponse,
    PromptTemplate,
    TemplateVariables,
)
from elluminate.beta.schemas.template_variables_collection import TemplateVariablesCollection
from elluminate.utils import retry_request, run_async


class ResponsesResource(BaseResource):
    async def alist(
        self,
        prompt_template: PromptTemplate | None = None,
        template_variables: TemplateVariables | None = None,
    ) -> list[PromptResponse]:
        """Fetch prompt responses, optionally filtered (async counterpart of `list`).

        Args:
            prompt_template (PromptTemplate | None): When given, its id is sent as the
                `prompt_template_id` query parameter.
            template_variables (TemplateVariables | None): When given, its id is sent as the
                `template_variables_id` query parameter.

        Returns:
            list[PromptResponse]: One validated model per entry found under the `items` key of the response body.

        """
        # Only filters that were actually supplied end up in the query string.
        candidates = (
            ("prompt_template_id", prompt_template),
            ("template_variables_id", template_variables),
        )
        query = {key: obj.id for key, obj in candidates if obj}

        http_response = await self._aget("responses", params=query)
        # NOTE(review): only the `items` of this single response are read; if the endpoint
        # paginates, further pages are not fetched here - confirm.
        items = http_response.json()["items"]
        return [PromptResponse.model_validate(item) for item in items]

    def list(
        self,
        prompt_template: PromptTemplate | None = None,
        template_variables: TemplateVariables | None = None,
    ) -> list[PromptResponse]:
        """Return prompt responses, optionally narrowed by prompt template and/or template variables.

        Synchronous wrapper that runs `alist` through `run_async`.

        Args:
            prompt_template (PromptTemplate | None): The prompt template to get responses for.
            template_variables (TemplateVariables | None): The template variables to get responses for.

        Returns:
            list[PromptResponse]: The list of prompt responses.

        """
        sync_list = run_async(self.alist)
        return sync_list(prompt_template=prompt_template, template_variables=template_variables)

    @retry_request
    async def aadd(
        self,
        response: str,
        prompt_template: PromptTemplate,
        template_variables: TemplateVariables | None = None,
        metadata: LLMConfig | GenerationMetadata | None = None,
    ) -> PromptResponse:
        """Store an existing response text for a prompt template (async counterpart of `add`).

        Args:
            response (str): The response text to store.
            prompt_template (PromptTemplate): The prompt template the response belongs to.
            template_variables (TemplateVariables | None): Template variables to link to the response, if any.
            metadata (LLMConfig | GenerationMetadata | None): Optional metadata. A bare `LLMConfig` is wrapped
                into a `GenerationMetadata` before it is sent.

        Returns:
            PromptResponse: The model validated from the JSON body returned by the POST to `responses`.

        """
        # The whole request is built and sent while holding the resource semaphore.
        async with self._semaphore:
            generation_metadata = (
                GenerationMetadata(llm_model_config=metadata) if isinstance(metadata, LLMConfig) else metadata
            )

            request = CreatePromptResponseRequest(
                prompt_template_id=prompt_template.id,
                response=response,
                template_variables_id=template_variables.id if template_variables else None,
                metadata=generation_metadata,
            )
            payload = request.model_dump()

            server_response = await self._apost("responses", json=payload)
            return PromptResponse.model_validate(server_response.json())

    def add(
        self,
        response: str,
        prompt_template: PromptTemplate,
        template_variables: TemplateVariables | None = None,
        metadata: LLMConfig | GenerationMetadata | None = None,
    ) -> PromptResponse:
        """Add a response to a prompt template.

        Synchronous wrapper that runs `aadd` through `run_async`.

        Args:
            response (str): The response to add.
            prompt_template (PromptTemplate): The prompt template to add the response to.
            template_variables (TemplateVariables | None): The template variables to use for the response.
            metadata (LLMConfig | GenerationMetadata | None): Optional metadata to associate with the response.

        Returns:
            PromptResponse: The newly created prompt response object.

        """
        sync_add = run_async(self.aadd)
        return sync_add(
            response=response,
            prompt_template=prompt_template,
            template_variables=template_variables,
            metadata=metadata,
        )

    @retry_request
    async def agenerate(
        self,
        prompt_template: PromptTemplate,
        template_variables: TemplateVariables | None = None,
        llm_config: LLMConfig | None = None,
    ) -> PromptResponse:
        """Request a generated response for a prompt template (async counterpart of `generate`).

        Args:
            prompt_template (PromptTemplate): The prompt template to generate a response for.
            template_variables (TemplateVariables | None): Template variables to link to the response, if any.
            llm_config (LLMConfig | None): Optional LLM configuration; only its id is sent. A warning is
                logged when a config is passed whose id is None.

        Returns:
            PromptResponse: The model validated from the JSON body returned by the POST to `responses`.

        """
        # The whole request is built and sent while holding the resource semaphore.
        async with self._semaphore:
            if llm_config is not None:
                if llm_config.id is None:
                    logger.warning("The LLM config id is None. Default LLM config will be used.")

            # No `response` text is supplied here, in contrast to `aadd`.
            request = CreatePromptResponseRequest(
                prompt_template_id=prompt_template.id,
                template_variables_id=template_variables.id if template_variables else None,
                llm_config_id=llm_config.id if llm_config else None,
            )
            payload = request.model_dump()

            server_response = await self._apost("responses", json=payload)
            return PromptResponse.model_validate(server_response.json())

    def generate(
        self,
        prompt_template: PromptTemplate,
        template_variables: TemplateVariables | None = None,
        llm_config: LLMConfig | None = None,
    ) -> PromptResponse:
        """Generate a response for a prompt template using an LLM.

        This method sends the prompt to an LLM for generation. If no LLM config is provided,
        the project's default LLM config will be used. Synchronous wrapper that runs
        `agenerate` through `run_async`.

        Args:
            prompt_template (PromptTemplate): The prompt template to generate a response for.
            template_variables (TemplateVariables | None): The template variables to use for the response.
            llm_config (LLMConfig | None): Optional LLM configuration to use for generation.
                If not provided, the project's default config will be used.

        Returns:
            PromptResponse: The generated response object

        """
        # NOTE(review): an earlier docstring listed a ValueError for a missing template variables
        # source and mentioned a `template_variables_id` argument; this signature has no such
        # argument and no explicit check is visible here, so the entry was dropped - confirm.
        sync_generate = run_async(self.agenerate)
        return sync_generate(
            prompt_template=prompt_template,
            template_variables=template_variables,
            llm_config=llm_config,
        )

    @retry_request
    async def aadd_many(
        self,
        responses: List[str],
        prompt_template: PromptTemplate,
        template_variables: List[TemplateVariables],
        metadata: List[LLMConfig | GenerationMetadata | None] | None = None,
        timeout: float | None = None,
    ) -> List[PromptResponse]:
        """Store several existing responses for one prompt template in a single batch (async `add_many`).

        Args:
            responses (list[str]): Response texts to store.
            prompt_template (PromptTemplate): The prompt template all responses belong to.
            template_variables (list[TemplateVariables]): Template variables, one per response.
            metadata (list[LLMConfig | GenerationMetadata | None] | None): Optional metadata, one entry per
                response. When omitted, every response gets `None`. `LLMConfig` entries are wrapped into
                `GenerationMetadata`.
            timeout (float | None): Forwarded unchanged to `_abatch_create`.

        Returns:
            list[PromptResponse]: Whatever `_abatch_create` returns for the `responses/batches` request.

        Raises:
            ValueError: If `responses`, `template_variables` and the metadata list differ in length.

        """
        # Validation and submission both happen while holding the resource semaphore.
        async with self._semaphore:
            n_responses = len(responses)
            n_template_variables = len(template_variables)
            per_response_metadata = [None] * n_responses if metadata is None else metadata
            n_metadata = len(per_response_metadata)

            if n_template_variables != n_responses or n_responses != n_metadata:
                raise ValueError(
                    f"All input lists must have the same length. Got {n_template_variables} for template_variables, "
                    f"{n_responses} for responses, and {n_metadata} for metadata."
                )

            # Lazily wrap bare LLMConfig entries; being a generator, each entry is normalised
            # right before the request that uses it is built.
            normalised_metadata = (
                GenerationMetadata(llm_model_config=entry) if isinstance(entry, LLMConfig) else entry
                for entry in per_response_metadata
            )
            requests = [
                CreatePromptResponseRequest(
                    prompt_template_id=prompt_template.id,
                    response=text,
                    template_variables_id=variables.id,
                    metadata=entry,
                )
                for text, variables, entry in zip(responses, template_variables, normalised_metadata)
            ]

            batch_request = BatchCreatePromptResponseRequest(prompt_response_ins=requests)
            return await self._abatch_create(
                path="responses/batches",
                batch_request=batch_request,
                batch_response_type=BatchCreatePromptResponseStatus,
                timeout=timeout,
            )

    def add_many(
        self,
        responses: List[str],
        prompt_template: PromptTemplate,
        template_variables: List[TemplateVariables],
        metadata: List[LLMConfig | GenerationMetadata | None] | None = None,
        timeout: float | None = None,
    ) -> List[PromptResponse]:
        """Add multiple responses to a prompt template in bulk.

        Use this method when you have a list of responses to add, instead of adding them one by one with the
        add() method. Synchronous wrapper that runs `aadd_many` through `run_async`.

        Args:
            responses (list[str]): List of responses to add.
            prompt_template (PromptTemplate): The prompt template to add responses to.
            template_variables (list[TemplateVariables]): List of template variables for each response.
            metadata (list[LLMConfig | GenerationMetadata | None] | None): Optional list of metadata for each response.
            timeout (float | None): Timeout in seconds for API requests. Defaults to no timeout.

        Returns:
            list[PromptResponse]: List of newly created prompt response objects.

        """
        sync_add_many = run_async(self.aadd_many)
        return sync_add_many(
            responses=responses,
            prompt_template=prompt_template,
            template_variables=template_variables,
            metadata=metadata,
            timeout=timeout,
        )

    # This helper is necessary because agenerate_many is declared with overloads and generate_many needs to call it.
    # A type checker would complain if generate_many referenced the overloaded async method directly, because no
    # overload accepts both `template_variables` and `collection` as optional keyword arguments.
    async def _agenerate_many_impl(
        self,
        prompt_template: PromptTemplate,
        *,
        template_variables: List[TemplateVariables] | None = None,
        collection: TemplateVariablesCollection | None = None,
        llm_config: LLMConfig | None = None,
        timeout: float | None = None,
    ) -> List[PromptResponse]:
        """Build and submit one batch generation request for a prompt template.

        Exactly one of `template_variables` and `collection` must be supplied. When a collection is
        given, its template variables are fetched through `self._client.template_variables.alist`.

        Args:
            prompt_template (PromptTemplate): The prompt template to generate responses for.
            template_variables (list[TemplateVariables] | None): Template variables, one generation each.
            collection (TemplateVariablesCollection | None): Collection whose template variables are used.
            llm_config (LLMConfig | None): Optional LLM configuration; only its id is sent, for every request.
            timeout (float | None): Forwarded unchanged to `_abatch_create`.

        Returns:
            list[PromptResponse]: Whatever `_abatch_create` returns for the `responses/batches` request.

        Raises:
            ValueError: If neither or both of `template_variables` and `collection` are provided, or if
                the collection yields no template variables.

        """
        # Explicit raises instead of `assert`: assertions are removed when Python runs with -O,
        # which would let invalid arguments slip through to the API call.
        if template_variables and collection:
            raise ValueError("Cannot provide both template_variables and collection.")
        if not template_variables and not collection:
            raise ValueError("Either template_variables or collection must be provided.")

        if collection is not None:
            template_variables = await self._client.template_variables.alist(collection=collection)

        # Only reachable with an empty list when the collection lookup returned nothing.
        if not template_variables:
            raise ValueError("The provided collection does not contain any template variables.")

        # The same LLM config applies to every request, so resolve its id once.
        llm_config_id = llm_config.id if llm_config else None
        prompt_response_ins = [
            CreatePromptResponseRequest(
                prompt_template_id=prompt_template.id,
                template_variables_id=variables.id,
                llm_config_id=llm_config_id,
            )
            for variables in template_variables
        ]

        batch_request = BatchCreatePromptResponseRequest(
            prompt_response_ins=prompt_response_ins,
        )
        return await self._abatch_create(
            path="responses/batches",
            batch_request=batch_request,
            batch_response_type=BatchCreatePromptResponseStatus,
            timeout=timeout,
        )

    # Overload 1: explicit list of template variables.
    @overload
    async def agenerate_many(
        self,
        prompt_template: PromptTemplate,
        *,
        template_variables: List[TemplateVariables],
        llm_config: LLMConfig | None = None,
        timeout: float | None = None,
    ) -> List[PromptResponse]: ...

    # Overload 2: template variables taken from a collection.
    @overload
    async def agenerate_many(
        self,
        prompt_template: PromptTemplate,
        *,
        collection: TemplateVariablesCollection,
        llm_config: LLMConfig | None = None,
        timeout: float | None = None,
    ) -> List[PromptResponse]: ...

    @retry_request
    async def agenerate_many(
        self,
        prompt_template: PromptTemplate,
        *,
        template_variables: List[TemplateVariables] | None = None,
        collection: TemplateVariablesCollection | None = None,
        llm_config: LLMConfig | None = None,
        timeout: float | None = None,
    ) -> List[PromptResponse]:
        """Generate multiple responses for a prompt template (async counterpart of `generate_many`).

        All arguments are forwarded unchanged to `_agenerate_many_impl`.

        Args:
            prompt_template (PromptTemplate): The prompt template to use for generation.
            template_variables (list[TemplateVariables] | None): List of template variables for each response.
            collection (TemplateVariablesCollection | None): The collection to use for the template variables.
            llm_config (LLMConfig | None): Optional LLMConfig to use for generation.
            timeout (float | None): Timeout value passed through to the implementation.

        Returns:
            list[PromptResponse]: The result of `_agenerate_many_impl`.

        """
        generated = await self._agenerate_many_impl(
            prompt_template,
            template_variables=template_variables,
            collection=collection,
            llm_config=llm_config,
            timeout=timeout,
        )
        return generated

    # Overload 1: explicit list of template variables.
    @overload
    def generate_many(
        self,
        prompt_template: PromptTemplate,
        *,
        template_variables: List[TemplateVariables],
        llm_config: LLMConfig | None = None,
        timeout: float | None = None,
    ) -> List[PromptResponse]: ...

    # Overload 2: template variables taken from a collection.
    @overload
    def generate_many(
        self,
        prompt_template: PromptTemplate,
        *,
        collection: TemplateVariablesCollection,
        llm_config: LLMConfig | None = None,
        timeout: float | None = None,
    ) -> List[PromptResponse]: ...

    def generate_many(
        self,
        prompt_template: PromptTemplate,
        *,
        template_variables: List[TemplateVariables] | None = None,
        collection: TemplateVariablesCollection | None = None,
        llm_config: LLMConfig | None = None,
        timeout: float | None = None,
    ) -> List[PromptResponse]:
        """Generate multiple responses for a prompt template.

        Use this method when you have a list of responses to generate, instead of generating them one by one
        with the generate() method.

        Either `template_variables` or `collection` can be provided:
        - If `template_variables` is given, it will use the provided list of template variables for each response.
        - If `collection` is given, it will use the template variables from the specified collection.

        Args:
            prompt_template (PromptTemplate): The prompt template to use for generation.
            template_variables (list[TemplateVariables] | None): List of template variables for each response.
            collection (TemplateVariablesCollection | None): The collection to use for the template variables.
            llm_config (LLMConfig | None): Optional LLMConfig to use for generation.
            timeout (float | None): Timeout in seconds for API requests. Defaults to no timeout.

        Returns:
            list[PromptResponse]: List of newly created prompt response objects.

        """
        # NOTE(review): this runs `_agenerate_many_impl` directly, so the `@retry_request` applied to
        # `agenerate_many` is not in effect on the synchronous path - confirm this is intended.
        sync_generate_many = run_async(self._agenerate_many_impl)
        return sync_generate_many(
            prompt_template=prompt_template,
            template_variables=template_variables,
            collection=collection,
            llm_config=llm_config,
            timeout=timeout,
        )

    async def adelete(self, prompt_response: PromptResponse) -> None:
        """Delete a prompt response (async counterpart of `delete`).

        Args:
            prompt_response (PromptResponse): The prompt response to delete; its id forms the request path.

        """
        endpoint = f"responses/{prompt_response.id}"
        await self._adelete(endpoint)

    def delete(self, prompt_response: PromptResponse) -> None:
        """Delete a prompt response.

        Synchronous wrapper that runs `adelete` through `run_async`.

        Args:
            prompt_response (PromptResponse): The prompt response to delete.

        """
        sync_delete = run_async(self.adelete)
        return sync_delete(prompt_response)
