from typing import Any, Dict, List, Optional, Union

from mlflow import evaluation as mlflow_eval

from databricks.rag_eval import context, schemas
from databricks.rag_eval.config import assessment_config
from databricks.rag_eval.evaluation import entities
from databricks.rag_eval.mlflow import mlflow_utils
from databricks.rag_eval.utils import input_output_utils


class CallableBuiltinJudge:
    """
    Callable object that can be used to evaluate inputs against
    the LLM judge service with the current assessment config.

    Args:
        config: The assessment config to use for the judge.
        assessment_name: Optional assessment override. If present, the result Assessment will have this name instead of the original built-in judge's name.
    """

    def __init__(
        self,
        config: assessment_config.BuiltinAssessmentConfig,
        assessment_name: Optional[str] = None,
    ):
        self.config = config
        self.assessment_name = assessment_name

    @context.eval_context
    def __call__(
        self,
        *,
        request: str | Dict[str, Any],
        response: Optional[str | Dict[str, Any]] = None,
        retrieved_context: Optional[List[Dict[str, Any]]] = None,
        expected_response: Optional[str] = None,
        expected_retrieved_context: Optional[List[Dict[str, Any]]] = None,
        expected_facts: Optional[List[str]] = None,
        guidelines: Optional[Union[List[str], Dict[str, List[str]]]] = None,
    ) -> Union[mlflow_eval.Assessment, List[mlflow_eval.Assessment]]:
        """
        Validate the inputs, run this judge's assessment and convert the result for MLflow.

        All arguments are keyword-only. Only ``request`` is required by the signature; which
        of the other arguments are needed depends on the configured judge.

        Returns:
            For the guideline adherence config, the result of ``_assess_guideline_adherence``.
            Otherwise, the value produced by
            ``mlflow_utils.assessment_result_to_mlflow_assessments`` for the first assessment
            result returned by the judge service.

        Raises:
            ValueError: If ``retrieved_context``, ``expected_response``,
                ``expected_retrieved_context``, ``expected_facts`` or ``guidelines`` is malformed.
        """
        # The conversion results are discarded: these calls are made only for their input
        # checks (presumably they raise on unsupported request/response shapes - confirm in
        # input_output_utils).
        input_output_utils.request_to_string(request)
        input_output_utils.response_to_string(response)
        _validate_retrieved_context("retrieved_context", retrieved_context)
        _validate_string_input("expected_response", expected_response)
        _validate_retrieved_context(
            "expected_retrieved_context", expected_retrieved_context
        )
        _validate_repeated_string_input("expected_facts", expected_facts)
        _validate_guidelines(guidelines)

        managed_rag_client = context.get_context().build_managed_rag_client()
        eval_item = entities.EvalItem.from_dict(
            {
                schemas.REQUEST_COL: request,
                schemas.RESPONSE_COL: response,
                schemas.RETRIEVED_CONTEXT_COL: retrieved_context,
                schemas.EXPECTED_RESPONSE_COL: expected_response,
                schemas.EXPECTED_RETRIEVED_CONTEXT_COL: expected_retrieved_context,
                schemas.EXPECTED_FACTS_COL: expected_facts,
                schemas.GUIDELINES_COL: guidelines,
            }
        )

        # Guideline adherence requires special processing due to named guidelines
        if self.config == assessment_config.GUIDELINE_ADHERENCE:
            return self._assess_guideline_adherence(managed_rag_client, eval_item)

        assessment_result = managed_rag_client.get_assessment(
            eval_item=eval_item,
            config=self.config,
        )[0]
        return mlflow_utils.assessment_result_to_mlflow_assessments(
            assessment_result, self.assessment_name
        )

    def _assess_guideline_adherence(
        self,
        managed_rag_client: Any,
        eval_item: entities.EvalItem,
    ) -> Union[mlflow_eval.Assessment, List[mlflow_eval.Assessment]]:
        """
        Run the guideline adherence judge once per guideline group.

        Each group in ``eval_item.named_guidelines`` is assessed with a separate request, in
        which the item's guidelines are replaced by that group's guidelines.

        Returns:
            - A single converted assessment when the only group is keyed by the built-in
              guideline adherence name (presumably the case when guidelines were given as a
              plain list). Its user-facing name is ``self.assessment_name`` or, if unset, the
              built-in name.
            - Otherwise, a list with one converted assessment per group. Each uses the
              user-facing name ``"<overall name>/<group name>"``. The list is empty when
              ``named_guidelines`` is ``None`` or empty.
        """
        builtin_assessment_name = assessment_config.GUIDELINE_ADHERENCE.assessment_name
        named_guidelines = eval_item.named_guidelines
        # A single group keyed by the built-in name is treated as "unnamed" guidelines.
        is_named_guidelines = not (
            named_guidelines is not None
            and len(named_guidelines) == 1
            and builtin_assessment_name in named_guidelines
        )
        # Loop-invariant: the override name (if any) prefixes every group's assessment.
        overall_assessment_name = self.assessment_name or builtin_assessment_name

        assessments = []
        for guidelines_name, grouped_guidelines in (named_guidelines or {}).items():
            # Replace the named guidelines for each eval with the respective group's guidelines
            guidelines_eval_item = eval_item.as_dict()
            guidelines_eval_item[schemas.GUIDELINES_COL] = grouped_guidelines

            # Use two-tiered name for named guidelines
            user_facing_assessment_name = (
                f"{overall_assessment_name}/{guidelines_name}"
                if is_named_guidelines
                else overall_assessment_name
            )

            assessment_result = managed_rag_client.get_assessment(
                eval_item=entities.EvalItem.from_dict(guidelines_eval_item),
                # Replace the user-facing name for each guidelines assessment
                config=assessment_config.BuiltinAssessmentConfig(
                    assessment_name=builtin_assessment_name,
                    user_facing_assessment_name=user_facing_assessment_name,
                    assessment_type=assessment_config.AssessmentType.ANSWER,
                ),
            )[0]
            assessments.append(
                mlflow_utils.assessment_result_to_mlflow_assessments(
                    assessment_result, user_facing_assessment_name
                )
            )
        # When not named there is exactly one group, so index 0 is always present.
        return assessments if is_named_guidelines else assessments[0]


def _validate_string_input(param_name: str, input_value: Any) -> None:
    if input_value and not isinstance(input_value, str):
        raise ValueError(f"{param_name} must be a string. Got: {type(input_value)}")


def _validate_retrieved_context(
    param_name: str, retrieved_context: Optional[List[Dict[str, Any]]]
) -> None:
    if retrieved_context:
        if not isinstance(retrieved_context, list):
            raise ValueError(
                f"{param_name} must be a list of dictionaries. Got: {type(retrieved_context)}"
            )
        for context_dict in retrieved_context:
            if not isinstance(context_dict, dict):
                raise ValueError(
                    f"{param_name} must be a list of dictionaries. Got list of: {type(context_dict)}"
                )
            if "content" not in context_dict:
                raise ValueError(
                    f"Each context in {param_name} must have a 'content' key. Got: {context_dict}"
                )
            if set(context_dict.keys()) - {"doc_uri", "content"}:
                raise ValueError(
                    f"Each context in {param_name} must have only 'doc_uri' and 'content' keys. Got: {context_dict}"
                )


def _validate_repeated_string_input(param_name: str, input_value: Any) -> None:
    if input_value is None:
        return
    elif not isinstance(input_value, list):
        raise ValueError(f"{param_name} must be a list. Got: {type(input_value)}")

    for idx, value in enumerate(input_value):
        if not isinstance(value, str):
            raise ValueError(
                f"{param_name} must be a list of strings. Got: {type(value)} at index: {idx}"
            )


def _validate_guidelines(guidelines: Optional[Dict[str, List[str]]]) -> None:
    if guidelines is None:
        return

    guidelines_is_valid_iterable = input_output_utils.is_valid_guidelines_iterable(
        guidelines
    )
    guidelines_is_valid_mapping = input_output_utils.is_valid_guidelines_mapping(
        guidelines
    )

    if not (guidelines_is_valid_iterable or guidelines_is_valid_mapping):
        raise ValueError(
            f"Invalid guidelines: {guidelines}. Guidelines must be a list of strings "
            f"or a mapping from a name of guidelines (string) to a list of strings."
        )
    elif guidelines_is_valid_iterable:
        input_output_utils.check_guidelines_iterable_exceeds_limit(guidelines)
    elif guidelines_is_valid_mapping:
        input_output_utils.check_guidelines_mapping_exceeds_limit(guidelines)


# Docstring template for callable builtin judges. The {judge_description} and {required_args}
# placeholders are presumably filled via str.format by code not visible here. If so, any
# other literal braces added to this text must be doubled.
CALLABLE_BUILTIN_JUDGE_DOCSTRING = """
        {judge_description}

        Args:
            request: Input to the application to evaluate, user’s question or query. For example, “What is RAG?”.
            response: Response generated by the application being evaluated.
            retrieved_context: Retrieval results generated by the retriever in the application being evaluated.
                It should be a list of dictionaries with the following keys:
                    - doc_uri (Optional): The doc_uri of the context.
                    - content: The content of the context.
            expected_response: Ground-truth (correct) answer for the input request.
            expected_retrieved_context: Array of objects containing the expected retrieved context for the request
                (if the application includes a retrieval step). It should be a list of dictionaries with the
                following keys:
                    - doc_uri (Optional): The doc_uri of the context.
                    - content: The content of the context.
            expected_facts: Array of strings containing facts expected in the correct response for the input request.
            guidelines: One of the following:
                - Array of strings containing the guidelines that the response should adhere to.
                - Mapping of string (named guidelines) to array of strings containing the guidelines
                  the response should adhere to.
        Required input arguments:
            {required_args}

        Returns:
            Assessment result(s) for the given input.
        """


# =================== Builtin Judges ===================
def correctness(
    request: str | Dict[str, Any],
    response: str | Dict[str, Any],
    expected_response: Optional[str] = None,
    expected_facts: Optional[List[str]] = None,
    assessment_name: Optional[str] = None,
) -> mlflow_eval.Assessment:
    """
    Judge whether the response agrees with the ground truth.

    The correctness LLM judge returns a binary evaluation, together with a written rationale, on
    whether the agent's response is factually accurate and semantically similar to the supplied
    expected response or expected facts.

    Args:
        request: Input to the application to evaluate, user’s question or query. For example, “What is RAG?”.
        response: Response generated by the application being evaluated.
        expected_response: Ground-truth (correct) answer for the input request.
        expected_facts: Array of strings containing facts expected in the correct response for the input request.
        assessment_name: Optional name override. When given, the returned Assessment uses it
            instead of "correctness".
    Required input arguments:
        request, response, oneof(expected_response, expected_facts)

    Returns:
        Correctness assessment result for the given input.
    """
    judge = CallableBuiltinJudge(
        config=assessment_config.CORRECTNESS,
        assessment_name=assessment_name,
    )
    return judge(
        request=request,
        response=response,
        expected_response=expected_response,
        expected_facts=expected_facts,
    )


def groundedness(
    request: str | Dict[str, Any],
    response: str | Dict[str, Any],
    retrieved_context: List[Dict[str, Any]],
    assessment_name: Optional[str] = None,
) -> mlflow_eval.Assessment:
    """
    Judge whether the response is supported by the retrieved context.

    The groundedness LLM judge returns a binary evaluation, together with a written rationale,
    on whether the generated response is factually consistent with the retrieved context.

    Args:
        request: Input to the application to evaluate, user’s question or query. For example, “What is RAG?”.
        response: Response generated by the application being evaluated.
        retrieved_context: Retrieval results produced by the application's retriever, given as
            a list of dictionaries with the keys:
                - doc_uri (Optional): The doc_uri of the context.
                - content: The content of the context.
        assessment_name: Optional name override. When given, the returned Assessment uses it
            instead of "groundedness".
    Required input arguments:
        request, response, retrieved_context

    Returns:
        Groundedness assessment result for the given input.
    """
    judge = CallableBuiltinJudge(
        config=assessment_config.GROUNDEDNESS,
        assessment_name=assessment_name,
    )
    return judge(
        request=request,
        response=response,
        retrieved_context=retrieved_context,
    )


def safety(
    request: str | Dict[str, Any],
    response: str | Dict[str, Any],
    assessment_name: Optional[str] = None,
) -> mlflow_eval.Assessment:
    """
    Judge whether the response contains harmful or toxic content.

    The safety LLM judge returns a binary rating, together with a written rationale, on whether
    the generated response has harmful or toxic content. It is backed by the built-in
    harmfulness assessment config.

    Args:
        request: Input to the application to evaluate, user’s question or query. For example, “What is RAG?”.
        response: Response generated by the application being evaluated.
        assessment_name: Optional name override. When given, the returned Assessment uses it
            instead of "safety".
    Required input arguments:
        request, response

    Returns:
        Safety assessment result for the given input.
    """
    judge = CallableBuiltinJudge(
        config=assessment_config.HARMFULNESS,
        assessment_name=assessment_name,
    )
    return judge(request=request, response=response)


def relevance_to_query(
    request: str | Dict[str, Any],
    response: str | Dict[str, Any],
    assessment_name: Optional[str] = None,
) -> mlflow_eval.Assessment:
    """
    Judge whether the response is relevant to the input request.

    Args:
        request: Input to the application to evaluate, user’s question or query. For example, “What is RAG?”.
        response: Response generated by the application being evaluated.
        assessment_name: Optional name override. When given, the returned Assessment uses it
            instead of "relevance_to_query".
    Required input arguments:
        request, response

    Returns:
        Relevance to query assessment result for the given input.
    """
    judge = CallableBuiltinJudge(
        config=assessment_config.RELEVANCE_TO_QUERY,
        assessment_name=assessment_name,
    )
    return judge(request=request, response=response)


def chunk_relevance(
    request: str | Dict[str, Any],
    retrieved_context: List[Dict[str, Any]],
    assessment_name: Optional[str] = None,
) -> List[mlflow_eval.Assessment]:
    """
    Judge, chunk by chunk, whether the retrieved context is relevant to the request.

    The chunk-relevance-precision LLM judge decides, for every chunk returned by the retriever,
    whether it is relevant to the input request. Precision is the number of relevant chunks
    divided by the total number of chunks returned. For example, if the retriever returns four
    chunks and the judge finds three of them relevant, llm_judged/chunk_relevance/precision is
    0.75.

    Args:
        request: Input to the application to evaluate, user’s question or query. For example, “What is RAG?”.
        retrieved_context: Retrieval results produced by the application's retriever, given as
            a list of dictionaries with the keys:
                - doc_uri (Optional): The doc_uri of the context.
                - content: The content of the context.
        assessment_name: Optional name override. When given, the returned Assessment uses it
            instead of "chunk_relevance".
    Required input arguments:
        request, retrieved_context

    Returns:
        Chunk relevance assessment result for each of the chunks in the given input.
    """
    judge = CallableBuiltinJudge(
        config=assessment_config.CHUNK_RELEVANCE,
        assessment_name=assessment_name,
    )
    return judge(request=request, retrieved_context=retrieved_context)


def context_sufficiency(
    request: str | Dict[str, Any],
    retrieved_context: List[Dict[str, Any]],
    expected_response: Optional[str] = None,
    expected_facts: Optional[List[str]] = None,
    assessment_name: Optional[str] = None,
) -> mlflow_eval.Assessment:
    """
    Judge whether the retrieved context is enough to produce the expected answer.

    The context_sufficiency LLM judge determines whether the retriever returned documents that
    are sufficient to produce the expected response or the expected facts.

    Args:
        request: Input to the application to evaluate, user’s question or query. For example, “What is RAG?”.
        retrieved_context: Retrieval results produced by the application's retriever, given as
            a list of dictionaries with the keys:
                - doc_uri (Optional): The doc_uri of the context.
                - content: The content of the context.
        expected_response: Ground-truth (correct) answer for the input request.
        expected_facts: Array of strings containing facts expected in the correct response for the input request.
        assessment_name: Optional name override. When given, the returned Assessment uses it
            instead of "context_sufficiency".
    Required input arguments:
        request, retrieved_context, oneof(expected_response, expected_facts)

    Returns:
        Context sufficiency assessment result for the given input.
    """
    judge = CallableBuiltinJudge(
        config=assessment_config.CONTEXT_SUFFICIENCY,
        assessment_name=assessment_name,
    )
    return judge(
        request=request,
        retrieved_context=retrieved_context,
        expected_response=expected_response,
        expected_facts=expected_facts,
    )


def guideline_adherence(
    request: str | Dict[str, Any],
    response: str | Dict[str, Any],
    guidelines: Union[List[str], Dict[str, List[str]]],
    assessment_name: Optional[str] = None,
) -> Union[mlflow_eval.Assessment, List[mlflow_eval.Assessment]]:
    """
    Judge whether the response to the request follows the given guidelines.

    Args:
        request: Input to the application to evaluate, user’s question or query. For example, “What is RAG?”.
        response: Response generated by the application being evaluated.
        guidelines: Either of:
            - Array of strings containing the guidelines that the response should adhere to.
            - Mapping of string (named guidelines) to array of strings containing the guidelines
              the response should adhere to.
        assessment_name: Optional name override. When given, the returned Assessment uses it
            instead of "guideline_adherence".
    Required input arguments:
        request, response, guidelines

    Returns:
        Guideline adherence assessment(s) for the given input. A list is returned when named
        guidelines are provided.
    """
    judge = CallableBuiltinJudge(
        config=assessment_config.GUIDELINE_ADHERENCE,
        assessment_name=assessment_name,
    )
    return judge(request=request, response=response, guidelines=guidelines)
