"""Entities for evaluation."""

import dataclasses
import functools
import hashlib
import json
from collections import abc
from copy import deepcopy
from typing import Any, Collection, Dict, List, Mapping, Optional, TypeAlias, Union

import mlflow.entities as mlflow_entities
import pandas as pd
from mlflow import evaluation as mlflow_eval

from databricks.rag_eval import constants, schemas
from databricks.rag_eval.config import (
    assessment_config,
    example_config,
)
from databricks.rag_eval.utils import (
    collection_utils,
    enum_utils,
    input_output_utils,
    serialization_utils,
)

ChunkInputData = Union[str, Dict[str, Any]]
RetrievalContextInputData = List[Optional[ChunkInputData]]

_CHUNK_INDEX_KEY = "chunk_index"


@dataclasses.dataclass
class Chunk:
    """A single piece of retrieval context: an optional doc URI plus optional text content."""

    doc_uri: Optional[str] = None
    content: Optional[str] = None

    @classmethod
    def from_input_data(cls, input_data: Optional[ChunkInputData]) -> Optional["Chunk"]:
        """
        Build a Chunk from one entry of a user-supplied retrieval context.

        Accepted forms:
          - a string, taken as the doc URI (content stays None);
          - a dict following schemas.CHUNK_SCHEMA, from which the doc URI and the
            chunk content are read (absent keys become None).

        :param input_data: The raw chunk entry; None/NaN is allowed.
        :return: The constructed Chunk, or None when the entry is None/NaN.
        """
        if input_output_utils.is_none_or_nan(input_data):
            return None
        if not isinstance(input_data, str):
            # Any non-string entry is read as a CHUNK_SCHEMA mapping.
            return cls(
                doc_uri=input_data.get(schemas.DOC_URI_COL),
                content=input_data.get(schemas.CHUNK_CONTENT_COL),
            )
        return cls(doc_uri=input_data)


class RetrievalContext(List[Optional[Chunk]]):
    """An ordered list of retrieved chunks; an entry is None when its input chunk was None/NaN."""

    def __init__(self, chunks: Collection[Optional[Chunk]]):
        super().__init__(chunks)

    def concat_chunk_content(
        self, delimiter: str = constants.DEFAULT_CONTEXT_CONCATENATION_DELIMITER
    ) -> Optional[str]:
        """
        Join the non-empty chunk contents into a single string.

        None chunks and chunks whose content is None or empty are skipped.

        :param delimiter: Separator placed between consecutive contents.
        :return: The joined string, or None if no chunk has non-empty content.
        """
        non_empty_contents = [
            chunk.content for chunk in self if chunk is not None and chunk.content
        ]
        return delimiter.join(non_empty_contents) if non_empty_contents else None

    def get_doc_uris(self) -> List[Optional[str]]:
        """
        Return the doc URI of every non-None chunk, in order.

        None chunks are skipped, so indices in the result do not necessarily line up
        with chunk positions. A chunk without a doc URI contributes None.
        """
        return [chunk.doc_uri for chunk in self if chunk is not None]

    def to_output_dict(self) -> List[Optional[Dict[str, Optional[str]]]]:
        """
        Serialize to a list of dicts following schemas.CHUNK_SCHEMA.

        Chunk positions are preserved: a None chunk becomes a None entry, and a
        missing doc URI or content becomes a None value inside its dict.
        """
        return [
            (
                {
                    schemas.DOC_URI_COL: chunk.doc_uri,
                    schemas.CHUNK_CONTENT_COL: chunk.content,
                }
                if chunk is not None
                else None
            )
            for chunk in self
        ]

    @classmethod
    def from_input_data(
        cls, input_data: Optional[RetrievalContextInputData]
    ) -> Optional["RetrievalContext"]:
        """
        Construct a RetrievalContext from the input.

        Input can be:
        - A list of doc URIs
        - A list of dictionaries with the schema defined in schemas.CHUNK_SCHEMA

        Each entry is converted with `Chunk.from_input_data`, so None/NaN entries
        become None chunks.

        :return: The RetrievalContext, or None if the input itself is None/NaN.
        """
        if input_output_utils.is_none_or_nan(input_data):
            return None
        return cls([Chunk.from_input_data(chunk_data) for chunk_data in input_data])


@dataclasses.dataclass
class ToolCallInvocation:
    """A single tool call made by an agent: its name, arguments, and optional id/result."""

    tool_name: str
    tool_call_args: Dict[str, Any]
    tool_call_id: Optional[str] = None
    tool_call_result: Optional[Dict[str, Any]] = None

    # Only available from the trace
    raw_span: Optional[mlflow_entities.Span] = None
    available_tools: Optional[List[Dict[str, Any]]] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Return a shallow dict of all fields, keyed by field name in declaration order.

        Values are neither copied nor converted (e.g. `raw_span` stays a Span object).
        """
        return {spec.name: getattr(self, spec.name) for spec in dataclasses.fields(self)}

    @classmethod
    def _from_dict(cls, data: Dict[str, Any]) -> "ToolCallInvocation":
        """Build one invocation from a dict; only "tool_name" is required."""
        optional_keys = (
            "tool_call_id",
            "tool_call_result",
            "raw_span",
            "available_tools",
        )
        return cls(
            tool_name=data["tool_name"],
            tool_call_args=data.get("tool_call_args", {}),
            **{key: data.get(key) for key in optional_keys},
        )

    @classmethod
    def from_dict(
        cls, tool_calls: Optional[List[Dict[str, Any]] | Dict[str, Any]]
    ) -> Optional["ToolCallInvocation" | List["ToolCallInvocation"]]:
        """
        Build invocation(s) from a single tool-call dict or a list of them.

        :param tool_calls: A tool-call dict, a list of tool-call dicts, or None.
        :return: None for None, one invocation for a dict, a list of invocations for a list.
        :raises ValueError: If `tool_calls` is neither None, a dict, nor a list.
        """
        if tool_calls is None:
            return None
        if isinstance(tool_calls, list):
            return [cls._from_dict(entry) for entry in tool_calls]
        if isinstance(tool_calls, dict):
            return cls._from_dict(tool_calls)
        raise ValueError(
            f"Expected `tool_calls` to be a `dict` or `List[dict]`, but got: {type(tool_calls)}"
        )


class CategoricalRating(enum_utils.StrEnum):
    """A categorical rating for an assessment."""

    YES = "yes"
    NO = "no"
    UNKNOWN = "unknown"

    @classmethod
    def _missing_(cls, value: str):
        value = value.lower()
        for member in cls:
            if member == value:
                return member
        return cls.UNKNOWN

    @classmethod
    def from_example_rating(
        cls, rating: example_config.ExampleRating
    ) -> "CategoricalRating":
        """Convert an ExampleRating to a CategoricalRating."""
        match rating:
            case example_config.ExampleRating.YES:
                return cls.YES
            case example_config.ExampleRating.NO:
                return cls.NO
            case _:
                return cls.UNKNOWN


@dataclasses.dataclass
class Rating:
    double_value: Optional[float]
    rationale: Optional[str]
    categorical_value: Optional[CategoricalRating]
    error_message: Optional[str]
    error_code: Optional[str]

    @classmethod
    def value(
        cls,
        *,
        rationale: Optional[str] = None,
        double_value: Optional[float] = None,
        categorical_value: Optional[CategoricalRating | str] = None,
    ) -> "Rating":
        """Build a normal Rating with a categorical value, a double value, and a rationale."""
        if categorical_value is not None and not isinstance(
            categorical_value, CategoricalRating
        ):
            categorical_value = CategoricalRating(categorical_value)
        return cls(
            double_value=double_value,
            rationale=rationale,
            categorical_value=categorical_value,
            error_message=None,
            error_code=None,
        )

    @classmethod
    def error(
        cls, error_message: str, error_code: Optional[str | int] = None
    ) -> "Rating":
        """Build an error Rating with an error message and an optional error code."""
        if isinstance(error_code, int):
            error_code = str(error_code)
        return cls(
            double_value=None,
            rationale=None,
            categorical_value=None,
            error_message=error_message,
            error_code=error_code or "UNKNOWN",
        )

    @classmethod
    def flip(cls, rating: "Rating") -> "Rating":
        """Built a Rating with the inverse categorical and float values of the input Rating."""
        if rating.double_value is not None and (
            rating.double_value < 1.0 or rating.double_value > 5.0
        ):
            raise ValueError(
                f"Cannot flip the rating of double value: {rating.double_value}."
            )

        match rating.categorical_value:
            case CategoricalRating.YES:
                flipped_categorical_value = CategoricalRating.NO
                flipped_double_value = 1.0
            case CategoricalRating.NO:
                flipped_categorical_value = CategoricalRating.YES
                flipped_double_value = 5.0
            case CategoricalRating.UNKNOWN:
                flipped_categorical_value = CategoricalRating.UNKNOWN
                flipped_double_value = None
            case None:
                flipped_categorical_value = None
                flipped_double_value = None
            case _:
                raise ValueError(
                    f"Cannot flip the rating of categorical value: {rating.categorical_value}"
                )

        return cls(
            double_value=flipped_double_value,
            rationale=rating.rationale,
            categorical_value=flipped_categorical_value,
            error_message=rating.error_message,
            error_code=rating.error_code,
        )


PositionalRating: TypeAlias = Mapping[int, Rating]
"""
Ratings keyed by chunk position.

Each key is the position of a chunk within a retrieval context; the value is the
Rating assigned to that chunk.
"""


@functools.total_ordering
@dataclasses.dataclass
class EvalItem:
    """
    Represents a row in the evaluation dataset. It contains information needed to evaluate a question.

    Equality and ordering compare only `question_id`.
    """

    question_id: str
    """Unique identifier for the eval item."""

    raw_request: Any
    """Raw input to the agent when `evaluate` is called. Comes from "request" or "inputs" columns. """

    raw_response: Any
    """Raw output from an agent. Comes from "response" or "outputs" columns."""

    has_inputs_outputs: bool = False
    """Whether the eval item used the new inputs/outputs columns, or the old request/response columns."""

    question: Optional[str] = None
    """String representation of the model input that is used for evaluation."""

    answer: Optional[str] = None
    """String representation of the model output that is used for evaluation."""

    retrieval_context: Optional[RetrievalContext] = None
    """Retrieval context that is used for evaluation."""

    ground_truth_answer: Optional[str] = None
    """String representation of the ground truth answer."""

    ground_truth_retrieval_context: Optional[RetrievalContext] = None
    """Ground truth retrieval context."""

    grading_notes: Optional[str] = None
    """String representation of the grading notes."""

    expected_facts: Optional[List[str]] = None
    """List of expected facts to help evaluate the answer."""

    guidelines: Optional[List[str]] = None
    """[INTERNAL ONLY] List of guidelines the response must adhere to, passed to the judge service.
    Only set when the guidelines were given as a list (not a mapping)."""

    named_guidelines: Optional[Dict[str, List[str]]] = None
    """Mapping of name to guidelines the response must adhere to.
    A plain list of guidelines is stored under the default guideline-adherence assessment name."""

    custom_expected: Optional[Dict[str, Any]] = None
    """Custom expected data to help evaluate the answer."""

    custom_inputs: Optional[Dict[str, Any]] = None
    """Custom inputs to the model, read from the custom-inputs entry of a dict raw request."""

    custom_outputs: Optional[Dict[str, Any]] = None
    """Custom outputs from the model, read from the custom-outputs entry of a dict raw response."""

    trace: Optional[mlflow_entities.Trace] = None
    """Trace of the model invocation."""

    tool_calls: Optional[List[ToolCallInvocation]] = None
    """List of tool call invocations from an agent."""

    managed_evals_eval_id: Optional[str] = None
    """Unique identifier for the managed-evals eval item."""

    managed_evals_dataset_id: Optional[str] = None
    """Unique identifier for the managed-evals dataset."""

    model_error_message: Optional[str] = None
    """Error message if the model invocation fails."""

    source_id: Optional[str] = None
    """
    The source for this eval row. If source_type is "HUMAN", then user email.
    If source_type is "SYNTHETIC_FROM_DOC", then the doc URI.
    """

    source_type: Optional[str] = None
    """Source of the eval item. e.g. HUMAN, SYNTHETIC_FROM_DOC, PRODUCTION_LOG..."""

    tags: Optional[Dict[str, str]] = None
    """Tags associated with the eval item."""

    @property
    def concatenated_retrieval_context(self) -> Optional[str]:
        """Get the concatenated content of the retrieval context.
        Return None if there is no non-empty retrieval context content."""
        return (
            self.retrieval_context.concat_chunk_content()
            if self.retrieval_context
            else None
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EvalItem":
        """
        Create an EvalItem from a row of MLflow EvaluationDataset data.

        The request is read from the request column, falling back to the inputs
        column (a JSON string there is decoded). The response likewise falls back from
        the response column to the outputs column. When no custom-expected value is
        given, built-in expectations (expected response, expected facts, guidelines,
        expected retrieved context) missing from their top-level columns are taken
        from the expectations column, and its leftover entries become `custom_expected`.

        :param data: A single dataset row keyed by the column names in `schemas`.
        :return: The constructed EvalItem.
        :raises ValueError: If the request or response cannot be parsed.
        """

        def to_optional_list(value: Any) -> Optional[List[Any]]:
            # Normalize a list-like value (e.g. list or numpy array) to a plain list.
            # None/NaN maps to None; a bare string is kept whole as one element
            # instead of being split into characters by `list()`.
            if input_output_utils.is_none_or_nan(value):
                return None
            if isinstance(value, str):
                return [value]
            return list(value)

        retrieved_context = RetrievalContext.from_input_data(
            data.get(schemas.RETRIEVED_CONTEXT_COL)
        )

        expected_retrieved_context = RetrievalContext.from_input_data(
            data.get(schemas.EXPECTED_RETRIEVED_CONTEXT_COL)
        )

        has_inputs_outputs = False
        # Set the question/raw_request
        try:
            # Get the raw request from "request" or "inputs".
            request_obj = data.get(schemas.REQUEST_COL)
            if not request_obj:
                request_obj = data.get(schemas.INPUTS_COL)
                has_inputs_outputs = True
                if isinstance(request_obj, str):
                    # Deserialize the "inputs" JSON string into dict[str, Any].
                    request_obj = json.loads(request_obj)
            raw_request = input_output_utils.parse_variant_data(request_obj)
        except Exception as e:
            raise ValueError(
                f"The request object must be JSON serializable: {type(request_obj)}"
            ) from e
        try:
            question = input_output_utils.request_to_string(raw_request)
        except ValueError:
            # The request could not be converted to a string; keep only raw_request.
            question = None

        # Set the question id; fall back to a content hash of the raw request.
        question_id = data.get(schemas.REQUEST_ID_COL)
        if input_output_utils.is_none_or_nan(question_id):
            question_id = hashlib.sha256(str(raw_request).encode()).hexdigest()

        # Set the answer/raw_response
        try:
            # Get the raw response from "response" or "outputs".
            response_obj = data.get(schemas.RESPONSE_COL)
            if not response_obj:
                response_obj = data.get(schemas.OUTPUTS_COL)
                if isinstance(response_obj, str):
                    # Deserialize the JSON string into dict[str, Any].
                    response_obj = json.loads(response_obj)
            raw_response = input_output_utils.parse_variant_data(response_obj)
        except Exception as e:
            raise ValueError(
                f"The response object must be JSON serializable: {type(response_obj)}"
            ) from e
        try:
            answer = input_output_utils.response_to_string(raw_response)
        except ValueError:
            answer = None

        ground_truth_answer = input_output_utils.response_to_string(
            data.get(schemas.EXPECTED_RESPONSE_COL)
        )

        grading_notes = data.get(schemas.GRADING_NOTES_COL)
        grading_notes = (
            grading_notes
            if not input_output_utils.is_none_or_nan(grading_notes)
            else None
        )

        expected_facts = to_optional_list(data.get(schemas.EXPECTED_FACTS_COL))

        named_guidelines = data.get(schemas.GUIDELINES_COL)

        # Extract the relevant expected data from "custom_expected" or "expectations".
        custom_expected = data.get(schemas.CUSTOM_EXPECTED_COL)
        if not custom_expected and schemas.EXPECTATIONS_COL in data:
            # Copy so the pops below do not mutate the caller's data.
            expectations = deepcopy(data.get(schemas.EXPECTATIONS_COL))
            if input_output_utils.is_none_or_nan(expectations):
                expectations = {}
            if isinstance(expectations, str):
                # Deserialize the JSON string into dict[str, Any].
                expectations = json.loads(expectations)
            # Extract the built-in expectations, normalizing each one exactly like
            # its top-level column counterpart.
            if (
                ground_truth_answer is None
                and schemas.EXPECTED_RESPONSE_COL in expectations
            ):
                ground_truth_answer = input_output_utils.response_to_string(
                    expectations.pop(schemas.EXPECTED_RESPONSE_COL)
                )
            if expected_facts is None and schemas.EXPECTED_FACTS_COL in expectations:
                expected_facts = to_optional_list(
                    expectations.pop(schemas.EXPECTED_FACTS_COL)
                )
            if named_guidelines is None and schemas.GUIDELINES_COL in expectations:
                named_guidelines = expectations.pop(schemas.GUIDELINES_COL)
            if (
                expected_retrieved_context is None
                and schemas.EXPECTED_RETRIEVED_CONTEXT_COL in expectations
            ):
                expected_retrieved_context = RetrievalContext.from_input_data(
                    expectations.pop(schemas.EXPECTED_RETRIEVED_CONTEXT_COL)
                )
            # If the dict is empty, set to None to avoid creating an empty output column.
            custom_expected = expectations or None

        guidelines = None  # These are only used to pass to the judge
        if input_output_utils.is_none_or_nan(named_guidelines):
            named_guidelines = None
        elif isinstance(named_guidelines, abc.Iterable) and not isinstance(
            named_guidelines, Mapping
        ):
            # When a list-like (e.g., list or numpy array) is passed, we can use these
            # guidelines for the judge service. We cannot use a mapping. A bare string
            # is a single guideline, not a sequence of characters.
            guidelines = (
                [named_guidelines]
                if isinstance(named_guidelines, str)
                else list(named_guidelines)
            )
            # Convert the list of guidelines to a default mapping (separate list copy).
            named_guidelines = {
                assessment_config.GUIDELINE_ADHERENCE.assessment_name: list(guidelines)
            }

        custom_inputs = None
        if isinstance(raw_request, dict):
            custom_inputs = raw_request.get(schemas.CUSTOM_INPUTS_COL)
        if isinstance(custom_inputs, str):
            custom_inputs = json.loads(custom_inputs)

        custom_outputs = None
        if isinstance(raw_response, dict):
            custom_outputs = raw_response.get(schemas.CUSTOM_OUTPUTS_COL)
        if isinstance(custom_outputs, str):
            custom_outputs = json.loads(custom_outputs)

        trace = data.get(schemas.TRACE_COL)
        if input_output_utils.is_none_or_nan(trace):
            trace = None
        else:
            trace = serialization_utils.deserialize_trace(trace)

        tool_calls = data.get(schemas.TOOL_CALLS_COL)
        if isinstance(tool_calls, list):
            tool_calls = [
                ToolCallInvocation.from_dict(tool_call) for tool_call in tool_calls
            ]
        elif isinstance(tool_calls, dict):
            tool_calls = [ToolCallInvocation.from_dict(tool_calls)]
        elif input_output_utils.is_none_or_nan(tool_calls):
            # A missing value (e.g. NaN from a pandas row) would otherwise be stored
            # and later iterated by `as_dict`.
            tool_calls = None

        source_id = data.get(schemas.SOURCE_ID_COL)
        source_type = data.get(schemas.SOURCE_TYPE_COL)

        managed_evals_eval_id = data.get(schemas.MANAGED_EVALS_EVAL_ID_COL)
        managed_evals_dataset_id = data.get(schemas.MANAGED_EVALS_DATASET_ID_COL)
        tags = data.get(schemas.TAGS_COL)

        return cls(
            question_id=question_id,
            question=question,
            raw_request=raw_request,
            has_inputs_outputs=has_inputs_outputs,
            answer=answer,
            raw_response=raw_response,
            retrieval_context=retrieved_context,
            ground_truth_answer=ground_truth_answer,
            ground_truth_retrieval_context=expected_retrieved_context,
            grading_notes=grading_notes,
            expected_facts=expected_facts,
            guidelines=guidelines,
            named_guidelines=named_guidelines,
            custom_expected=custom_expected,
            custom_inputs=custom_inputs,
            custom_outputs=custom_outputs,
            trace=trace,
            tool_calls=tool_calls,
            source_id=source_id,
            source_type=source_type,
            managed_evals_eval_id=managed_evals_eval_id,
            managed_evals_dataset_id=managed_evals_dataset_id,
            tags=tags,
        )

    def as_dict(
        self, *, use_chat_completion_request_format: bool = False
    ) -> Dict[str, Any]:
        """
        Get as a dictionary. Keys are defined in schemas. Exclude None values.

        :param use_chat_completion_request_format: Whether to use the chat completion request format for the request.
            When set, the request is rebuilt from `question` instead of using `raw_request`.
        :return: The eval item as a column-name -> value dict, with None values dropped.
        """
        request = self.raw_request or self.question
        if use_chat_completion_request_format:
            request = input_output_utils.to_chat_completion_request(self.question)
        response = self.raw_response or self.answer

        # When returning the guidelines, ensure they are returned in the format they are given.
        # In other words, revert the default mapping we create when a list of guidelines is provided.
        is_named_guidelines = not (
            self.named_guidelines is not None
            and len(self.named_guidelines) == 1
            and assessment_config.GUIDELINE_ADHERENCE.assessment_name
            in self.named_guidelines
        )
        guidelines = (
            self.guidelines
            if not is_named_guidelines and self.guidelines is not None
            else self.named_guidelines
        )

        inputs = {
            schemas.REQUEST_ID_COL: self.question_id,
            # input
            schemas.REQUEST_COL: request,
            schemas.CUSTOM_INPUTS_COL: self.custom_inputs,
            # output
            schemas.RESPONSE_COL: response,
            schemas.RETRIEVED_CONTEXT_COL: (
                self.retrieval_context.to_output_dict()
                if self.retrieval_context
                else None
            ),
            schemas.CUSTOM_OUTPUTS_COL: self.custom_outputs,
            schemas.TRACE_COL: serialization_utils.serialize_trace(self.trace),
            schemas.TOOL_CALLS_COL: (
                [ToolCallInvocation.to_dict(tool_call) for tool_call in self.tool_calls]
                if self.tool_calls is not None
                else None
            ),
            schemas.MODEL_ERROR_MESSAGE_COL: self.model_error_message,
            # expected
            schemas.EXPECTED_RETRIEVED_CONTEXT_COL: (
                self.ground_truth_retrieval_context.to_output_dict()
                if self.ground_truth_retrieval_context
                else None
            ),
            schemas.EXPECTED_RESPONSE_COL: self.ground_truth_answer,
            schemas.GRADING_NOTES_COL: self.grading_notes,
            schemas.EXPECTED_FACTS_COL: self.expected_facts,
            schemas.GUIDELINES_COL: guidelines,
            schemas.CUSTOM_EXPECTED_COL: self.custom_expected,
            # source related
            schemas.SOURCE_TYPE_COL: self.source_type,
            schemas.SOURCE_ID_COL: self.source_id,
            schemas.MANAGED_EVALS_EVAL_ID_COL: self.managed_evals_eval_id,
            schemas.MANAGED_EVALS_DATASET_ID_COL: self.managed_evals_dataset_id,
            schemas.TAGS_COL: self.tags,
        }
        return collection_utils.drop_none_values(inputs)

    def __eq__(self, other):
        # Duck-typed: any object exposing `question_id` is comparable.
        if not hasattr(other, "question_id"):
            return NotImplemented
        return self.question_id == other.question_id

    def __lt__(self, other):
        # Combined with functools.total_ordering, this defines all orderings by question_id.
        if not hasattr(other, "question_id"):
            return NotImplemented
        return self.question_id < other.question_id


@dataclasses.dataclass
class AssessmentSource:
    """Identifies the producer of an assessment by a plain string id."""

    source_id: str

    @classmethod
    def builtin(cls) -> "AssessmentSource":
        """Return the source whose id is "databricks" (the built-in source)."""
        return cls("databricks")

    @classmethod
    def custom(cls) -> "AssessmentSource":
        """Return the source whose id is "custom"."""
        return cls("custom")


@dataclasses.dataclass(frozen=True, eq=True)
class AssessmentResult:
    """
    Base class for the immutable result of running one assessment.

    Instances (including subclass instances) order by `assessment_name` alone.
    """

    assessment_name: str
    assessment_type: assessment_config.AssessmentType
    assessment_source: AssessmentSource

    def __lt__(self, other):
        # Ordering is only defined against other AssessmentResult instances.
        if isinstance(other, AssessmentResult):
            return self.assessment_name < other.assessment_name
        return NotImplemented


@dataclasses.dataclass(frozen=True, eq=True)
class PerRequestAssessmentResult(AssessmentResult):
    """The result of an assessment that produces a single rating for the whole request."""

    rating: Rating
    assessment_type: assessment_config.AssessmentType

    def to_mlflow_assessment(
        self, assessment_name: Optional[str] = None
    ) -> mlflow_eval.Assessment:
        """
        Convert this result into one MLflow Assessment attributed to an AI judge.

        The assessment value is the rating's categorical value if truthy, otherwise
        its double value.

        :param assessment_name: Optional name override; when falsy, `self.assessment_name` is used.
        :return: MLflow Assessment object
        """
        rating = self.rating
        name = assessment_name or self.assessment_name
        value = rating.categorical_value or rating.double_value
        return mlflow_eval.Assessment(
            name=name,
            source=_convert_to_ai_judge_assessment_source(self.assessment_source),
            value=value,
            rationale=rating.rationale,
            error_code=rating.error_code,
            error_message=rating.error_message,
        )


def _convert_to_ai_judge_assessment_source(
    assessment_source: AssessmentSource,
) -> mlflow_eval.AssessmentSource:
    """
    Map an AssessmentSource onto an MLflow AssessmentSource.

    The source id is carried over unchanged; the source type is always AI_JUDGE.
    """
    judge_type = mlflow_eval.AssessmentSourceType.AI_JUDGE
    return mlflow_eval.AssessmentSource(
        source_type=judge_type,
        source_id=assessment_source.source_id,
    )


@dataclasses.dataclass(frozen=True, eq=True)
class PerChunkAssessmentResult(AssessmentResult):
    """
    The result of an assessment that rates each chunk of a retrieval context.

    `assessment_type` is fixed to RETRIEVAL and is not accepted by the constructor.
    """

    positional_rating: PositionalRating
    assessment_type: assessment_config.AssessmentType = dataclasses.field(
        init=False, default=assessment_config.AssessmentType.RETRIEVAL
    )

    def to_mlflow_assessment(
        self, assessment_name: Optional[str] = None
    ) -> List[mlflow_eval.Assessment]:
        """
        Convert this result into one MLflow Assessment per rated chunk.

        Each assessment stores its chunk position in metadata, and its value is the
        rating's categorical value if truthy, otherwise its double value.

        :param assessment_name: Optional name override; when falsy, `self.assessment_name` is used.
        :return: a list of MLflow Assessment objects, in `positional_rating` iteration order
        """
        name = assessment_name or self.assessment_name
        assessments = []
        for position, rating in self.positional_rating.items():
            assessments.append(
                mlflow_eval.Assessment(
                    name=name,
                    # A fresh source object per assessment, as each call builds one.
                    source=_convert_to_ai_judge_assessment_source(
                        self.assessment_source
                    ),
                    value=rating.categorical_value or rating.double_value,
                    rationale=rating.rationale,
                    error_code=rating.error_code,
                    error_message=rating.error_message,
                    metadata={_CHUNK_INDEX_KEY: position},
                )
            )
        return assessments


@dataclasses.dataclass(frozen=True, eq=True, order=True)
class MetricResult:
    """The immutable result of one metric; ordered by (metric_name, metric_value)."""

    metric_name: str
    metric_value: Any

    @property
    def metric_full_name(self) -> str:
        """
        Return the display name of the metric.

        When the value is an MLflow Assessment with a non-empty name, this is
        "<metric_name>/<assessment name>"; otherwise it is just `metric_name`.
        """
        value = self.metric_value
        if not isinstance(value, mlflow_eval.Assessment) or not value.name:
            return self.metric_name
        return f"{self.metric_name}/{value.name}"

    def to_mlflow_assessment(self) -> Optional[mlflow_eval.Assessment]:
        """
        Return a copy of the metric value renamed to `metric_full_name`, if it is an Assessment.

        :return: The renamed Assessment, or None when the value is not an Assessment.
        """
        if not isinstance(self.metric_value, mlflow_eval.Assessment):
            return None

        assessment_dict = self.metric_value.to_dictionary()
        assessment_dict["name"] = self.metric_full_name
        # `from_dictionary` fails without a source, so fill in a placeholder AI-judge source.
        if assessment_dict.get("source") is None:
            assessment_dict["source"] = {
                "source_type": mlflow_eval.AssessmentSourceType.AI_JUDGE,
                "source_id": None,
            }
        return mlflow_eval.Assessment.from_dictionary(assessment_dict)


@dataclasses.dataclass
class EvalResult:
    """
    Holds the result of the evaluation for an eval item.

    Besides storing the raw results, this class flattens them into column-keyed
    dictionaries (and a ``pd.Series``). Column names come from ``schemas``.
    """

    eval_item: EvalItem
    # Assessment results for the eval item; a falsy value is replaced by [] in __post_init__.
    assessment_results: List[AssessmentResult]

    # Overall assessment of the eval item.
    overall_assessment: Optional[Rating]

    # A collection of MetricResult; a falsy value is replaced by [] in __post_init__.
    metric_results: List[MetricResult]

    def __post_init__(self):
        # Normalize falsy inputs (e.g. None) to fresh empty lists.
        if not self.assessment_results:
            self.assessment_results = []
        if not self.metric_results:
            self.metric_results = []

    def __eq__(self, other):
        """Compare field by field, ignoring the order of assessment and metric results."""
        if not isinstance(other, EvalResult):
            # Defer to the other operand; Python falls back to identity (False) if it also declines.
            return NotImplemented
        # noinspection PyTypeChecker
        return (
            self.eval_item == other.eval_item
            and sorted(self.assessment_results) == sorted(other.assessment_results)
            and self.overall_assessment == other.overall_assessment
            and sorted(self.metric_results) == sorted(other.metric_results)
        )

    def get_metrics_dict(self) -> Dict[str, Any]:
        """
        Get the metrics as a dictionary keyed by the metric full name.

        Metrics whose value is an mlflow ``Assessment`` are excluded; they are emitted by
        ``get_assessment_results_dict`` instead. Entries with a ``None`` value are dropped.
        """
        metrics: Dict[str, Any] = {
            metric.metric_full_name: metric.metric_value
            for metric in self.metric_results
            if not isinstance(metric.metric_value, mlflow_eval.Assessment)
        }
        return collection_utils.drop_none_values(metrics)

    def get_assessment_results_dict(self) -> Dict[str, schemas.ASSESSMENT_RESULT_TYPE]:
        """
        Get the assessment results as a dictionary. Keys are defined in schemas.

        - Per-request assessments of type RETRIEVAL_LIST are written to the retrieval
          columns with ``is_per_chunk=False``.
        - Other per-request assessments are written to the response columns.
        - Per-chunk assessments are written to the retrieval columns as lists ordered by
          chunk position.
        - Metrics whose value is an mlflow ``Assessment`` are written as
          ``<metric full name>/{value,rationale,error_message,error_code}``.

        Only fields with a non-None value produce a column. For per-chunk results, a column
        appears when any position is non-None. Insertion order is preserved, which sets the
        column order in ``to_pd_series``.
        """
        assessments: Dict[str, schemas.ASSESSMENT_RESULT_TYPE] = {}
        for assessment in self.assessment_results:
            assessment_type = self._resolve_assessment_type(assessment)
            if isinstance(assessment, PerRequestAssessmentResult):
                self._add_per_request_columns(assessments, assessment, assessment_type)
            elif isinstance(assessment, PerChunkAssessmentResult):
                self._add_per_chunk_columns(assessments, assessment)
        for metric in self.metric_results:
            if isinstance(metric.metric_value, mlflow_eval.Assessment):
                self._add_metric_assessment_columns(assessments, metric)
        return assessments

    @staticmethod
    def _resolve_assessment_type(
        assessment: AssessmentResult,
    ) -> assessment_config.AssessmentType:
        """Return the built-in assessment type for the name, else the result's own type."""
        # TODO(ML-45046): remove assessment type lookup in harness, rely on service
        try:
            builtin_assessment_config = assessment_config.get_builtin_assessment_config_with_service_assessment_name(
                assessment.assessment_name
            )
            return builtin_assessment_config.assessment_type
        except ValueError:
            # Not found among the built-in metrics: use the provided assessment type.
            return assessment.assessment_type

    @staticmethod
    def _add_per_request_columns(
        assessments: Dict[str, schemas.ASSESSMENT_RESULT_TYPE],
        assessment: "PerRequestAssessmentResult",
        assessment_type: assessment_config.AssessmentType,
    ) -> None:
        """Write the non-None rating, rationale and error message of a per-request result."""
        rating = assessment.rating
        if assessment_type == assessment_config.AssessmentType.RETRIEVAL_LIST:
            column_name_fns = tuple(
                functools.partial(fn, is_per_chunk=False)
                for fn in (
                    schemas.get_retrieval_llm_rating_col_name,
                    schemas.get_retrieval_llm_rationale_col_name,
                    schemas.get_retrieval_llm_error_message_col_name,
                )
            )
        else:
            column_name_fns = (
                schemas.get_response_llm_rating_col_name,
                schemas.get_response_llm_rationale_col_name,
                schemas.get_response_llm_error_message_col_name,
            )
        values = (rating.categorical_value, rating.rationale, rating.error_message)
        for value, column_name_fn in zip(values, column_name_fns):
            # Column names are computed only for values that are actually written.
            if value is not None:
                assessments[column_name_fn(assessment.assessment_name)] = value

    @staticmethod
    def _add_per_chunk_columns(
        assessments: Dict[str, schemas.ASSESSMENT_RESULT_TYPE],
        assessment: "PerChunkAssessmentResult",
    ) -> None:
        """Write position-ordered lists of ratings, rationales and error messages of a per-chunk result."""
        # Convert positional_rating to a list ordered by position. Missing positions get an
        # error rating; this should not happen in practice.
        ratings_ordered_by_position: List[Rating] = (
            collection_utils.position_map_to_list(
                assessment.positional_rating,
                default=Rating.error("Missing rating"),
            )
        )
        per_position_values = (
            [rating.categorical_value for rating in ratings_ordered_by_position],
            [rating.rationale for rating in ratings_ordered_by_position],
            [rating.error_message for rating in ratings_ordered_by_position],
        )
        column_name_fns = (
            schemas.get_retrieval_llm_rating_col_name,
            schemas.get_retrieval_llm_rationale_col_name,
            schemas.get_retrieval_llm_error_message_col_name,
        )
        for values, column_name_fn in zip(per_position_values, column_name_fns):
            # Emit the column only if at least one chunk position has a value.
            if any(value is not None for value in values):
                assessments[column_name_fn(assessment.assessment_name)] = values

    @staticmethod
    def _add_metric_assessment_columns(
        assessments: Dict[str, schemas.ASSESSMENT_RESULT_TYPE],
        metric: MetricResult,
    ) -> None:
        """Write the non-None fields of a metric whose value is an mlflow Assessment."""
        assessment = metric.metric_value
        full_name = metric.metric_full_name
        fields = (
            ("value", assessment.value),
            ("rationale", assessment.rationale),
            ("error_message", assessment.error_message),
            ("error_code", assessment.error_code),
        )
        for suffix, value in fields:
            if value is not None:
                assessments[f"{full_name}/{suffix}"] = value

    def get_overall_assessment_dict(self) -> Dict[str, schemas.ASSESSMENT_RESULT_TYPE]:
        """
        Get the overall assessment as a dictionary. Keys are defined in schemas.

        Returns an empty dict when there is no (truthy) overall assessment. Otherwise it
        includes the rating and/or rationale, whichever are not None.
        """
        result: Dict[str, schemas.ASSESSMENT_RESULT_TYPE] = {}
        overall = self.overall_assessment
        if not overall:
            return result
        if overall.categorical_value is not None:
            result[schemas.OVERALL_ASSESSMENT_RATING_COL] = overall.categorical_value
        if overall.rationale is not None:
            result[schemas.OVERALL_ASSESSMENT_RATIONALE_COL] = overall.rationale
        return result

    def to_pd_series(self) -> pd.Series:
        """
        Convert the EvalResult to a flattened pd.Series.

        Combines, in order: the eval item inputs, the overall assessment, the assessment
        results, and the metrics. On a key collision, the later dictionary wins.
        """
        inputs = self.eval_item.as_dict()
        assessments = self.get_assessment_results_dict()
        metrics = self.get_metrics_dict()
        overall_assessment = self.get_overall_assessment_dict()

        combined_data = {**inputs, **overall_assessment, **assessments, **metrics}
        return pd.Series(combined_data)
