import json
import logging
import warnings
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypedDict,
    Union,
)

import pandas as pd
import requests
from urllib3.util import retry

from databricks import version
from databricks.agents.utils.mlflow_utils import get_workspace_url
from databricks.rag_eval import entities, env_vars, rest_entities, schemas, session
from databricks.rag_eval.clients import databricks_api_client
from databricks.rag_eval.clients.managedevals import dataset_utils
from databricks.rag_eval.utils import (
    NO_CHANGE,
    collection_utils,
    input_output_utils,
    request_utils,
)

# Request headers understood by the managed-evals service.
SESSION_ID_HEADER = "managed-evals-session-id"
CLIENT_VERSION_HEADER = "managed-evals-client-version"
SYNTHETIC_GENERATION_NUM_DOCS_HEADER = "managed-evals-synthetic-generation-num-docs"
SYNTHETIC_GENERATION_NUM_EVALS_HEADER = "managed-evals-synthetic-generation-num-evals"

# When True, request sessions carry the notebook's Spark cluster ID as the
# `compute_cluster_id` query parameter.
USE_NOTEBOOK_CLUSTER_ID = False

# Byte budget per request for batch endpoints. 1 MiB is 1_048_576 bytes; the
# remaining ~48 kB is headroom for HTTP headers and JSON framing.
_BATCH_SIZE_LIMIT = 1_000_000

# Row budget per request for batch endpoints. The service caps a single request
# at 2K updated nodes, and one row can map to more than one node.
_BATCH_QUANTITY_LIMIT = 100

# Page size used by default for paginated requests.
_DEFAULT_PAGE_SIZE = 500


class TagType(TypedDict):
    """A tag record: its display name and its ID."""

    tag_name: str
    tag_id: str


_logger = logging.getLogger(__name__)


def get_default_retry_config():
    """
    Build the retry policy used for synthetic-generation requests.

    Retries HTTP 429/500/502/503/504 responses for both GET and POST. The retry
    count, backoff factor and backoff jitter are read from the
    `AGENT_EVAL_GENERATE_EVALS_*` settings in `env_vars`.
    """
    # urllib3 does not retry POST unless it is listed explicitly.
    retryable_methods = frozenset(["GET", "POST"])
    max_retries = env_vars.AGENT_EVAL_GENERATE_EVALS_MAX_RETRIES.get()
    backoff = env_vars.AGENT_EVAL_GENERATE_EVALS_BACKOFF_FACTOR.get()
    jitter = env_vars.AGENT_EVAL_GENERATE_EVALS_BACKOFF_JITTER.get()
    return retry.Retry(
        total=max_retries,
        backoff_factor=backoff,
        status_forcelist=[429, 500, 502, 503, 504],
        backoff_jitter=jitter,
        allowed_methods=retryable_methods,
    )


def get_batch_edit_retry_config():
    """
    Build the retry policy used by the batch create/delete endpoints.

    Only HTTP 429 is retried: adding many evals in quick succession can trip the
    service's rate limiting. POST is listed explicitly because urllib3 does not
    retry it by default.

    urllib3 sleeps ``backoff_factor * 2 ** (n - 1)`` seconds before the n-th
    consecutive retry and does not sleep before the first one. With
    ``backoff_factor=10`` and ``total=3`` the waits are therefore 0, 20 and 40
    seconds. A ``Retry-After`` header on the 429 response takes precedence.
    """
    return retry.Retry(
        total=3,
        backoff_factor=10,
        status_forcelist=[429],
        allowed_methods=frozenset(["POST"]),
    )


def _raise_for_status(resp: requests.Response) -> None:
    """
    Raise ``requests.HTTPError`` unless ``resp`` has a 2xx status code.

    The error message includes the status code, reason phrase and raw response
    body, so the service's own error text reaches the caller.

    Args:
        resp: The response to check.

    Raises:
        requests.HTTPError: If the status code is outside the 2xx range.
    """
    status = resp.status_code
    # Every 2xx is a success (e.g. 201 Created, 204 No Content), not just 200.
    if 200 <= status < 300:
        return
    if 400 <= status < 500:
        category = "Client Error"
    elif 500 <= status < 600:
        category = "Server Error"
    else:
        # Any other code (1xx, an unfollowed 3xx, ...) still gets a descriptive message.
        category = "Unexpected Status"
    http_error_msg = f"{status} {category}: {resp.reason}\n{resp.text}. "
    raise requests.HTTPError(http_error_msg, response=resp)


def _get_default_headers() -> Dict[str, str]:
    """
    Return the headers sent with ordinary managed-evals requests.

    That is the client-version header, passed through
    `request_utils.add_traffic_id_header`.
    """
    base_headers: Dict[str, str] = {CLIENT_VERSION_HEADER: version.VERSION}
    return request_utils.add_traffic_id_header(base_headers)


def _get_synthesis_headers() -> Dict[str, str]:
    """
    Return the headers sent with synthetic-generation requests.

    Besides the client version, these carry the current eval session's ID and
    its synthetic-generation doc/eval counts. The result is passed through
    `request_utils.add_traffic_id_header`. If there is no current session, an
    empty dict is returned.
    """
    current = session.current_session()
    if current is None:
        return {}

    num_docs = str(current.synthetic_generation_num_docs)
    num_evals = str(current.synthetic_generation_num_evals)
    synthesis_headers = {
        CLIENT_VERSION_HEADER: version.VERSION,
        SESSION_ID_HEADER: current.session_id,
        SYNTHETIC_GENERATION_NUM_DOCS_HEADER: num_docs,
        SYNTHETIC_GENERATION_NUM_EVALS_HEADER: num_evals,
    }
    return request_utils.add_traffic_id_header(synthesis_headers)


class ManagedEvalsClient(databricks_api_client.DatabricksAPIClient):
    """
    Client to interact with the managed-evals service.
    """

    def __init__(self):
        """Initialize the base API client with version "2.0"."""
        api_version = "2.0"
        super().__init__(version=api_version)

    # override from DatabricksAPIClient
    def get_default_request_session(self, *args, **kwargs):
        """
        Return the base class's request session, optionally tagged with the cluster ID.

        When the module flag ``USE_NOTEBOOK_CLUSTER_ID`` is set, the session's
        ``params`` are replaced with ``{"compute_cluster_id": <id>}``. The ID is
        read from the active Spark session's
        ``spark.databricks.clusterUsageTags.clusterId`` conf.
        """
        request_session = super().get_default_request_session(*args, **kwargs)
        if not USE_NOTEBOOK_CLUSTER_ID:
            return request_session

        # Imported lazily so pyspark is only required when the flag is enabled.
        from pyspark.sql import SparkSession

        spark_conf = SparkSession.builder.getOrCreate().conf
        request_session.params = {
            "compute_cluster_id": spark_conf.get(
                "spark.databricks.clusterUsageTags.clusterId"
            )
        }
        return request_session

    def gracefully_batch_post(
        self,
        url: str,
        all_items: Sequence[Any],
        request_body_create: Callable[[Iterable[Any]], Any],
        response_body_read: Callable[[Any], Iterable[Any]],
    ):
        """
        POST ``all_items`` to ``url`` in size-limited batches and stop at the first failure.

        Items are split with ``collection_utils.safe_batch``, so each batch
        respects ``_BATCH_SIZE_LIMIT`` bytes and ``_BATCH_QUANTITY_LIMIT`` items.
        Requests use ``get_batch_edit_retry_config()``, which retries HTTP 429.

        Args:
            url: Endpoint each batch is POSTed to.
            all_items: Every item to send.
            request_body_create: Builds the JSON request body for one batch.
            response_body_read: Extracts the result values from one successful response.

        Returns:
            The values read from every successful response, in order. If a batch
            fails, the error is logged and the values collected so far are
            returned instead of raising. Callers must take care not to re-create
            those items.
        """
        return_values = []
        with self.get_default_request_session(
            headers=_get_default_headers(),
            retry_config=get_batch_edit_retry_config(),
        ) as request_session:
            for batch in collection_utils.safe_batch(
                all_items,
                batch_byte_limit=_BATCH_SIZE_LIMIT,
                batch_quantity_limit=_BATCH_QUANTITY_LIMIT,
            ):
                request_body = request_body_create(batch)
                try:
                    response = request_session.post(url=url, json=request_body)
                    _raise_for_status(response)
                except (requests.HTTPError, requests.exceptions.RetryError) as e:
                    # RetryError: urllib3 exhausted its 429 retries and requests raised
                    # instead of returning the final response. Partial results must still
                    # be returned in that case.
                    _logger.error(
                        f"Created {len(return_values)}/{len(all_items)} items before encountering an error.\n"
                        f"Returning successfully created items; please take care to avoid double-creating objects.\n{e}"
                    )
                    return return_values
                return_values.extend(response_body_read(response))
        return return_values

    def generate_questions(
        self,
        *,
        doc: entities.Document,
        num_questions: int,
        agent_description: Optional[str],
        question_guidelines: Optional[str],
    ) -> List[entities.SyntheticQuestion]:
        """
        Ask the service to synthesize questions about ``doc``.

        Args:
            doc: Source document. Its content is sent to the service, and its URI
                is attached to every returned question.
            num_questions: Number of questions to request.
            agent_description: Optional agent description forwarded to the service.
            question_guidelines: Optional question guidelines forwarded to the service.

        Returns:
            One ``SyntheticQuestion`` per entry of the response's ``questions_with_context``.

        Raises:
            requests.HTTPError: If the service returns an error status.
            ValueError: If the response lacks ``questions_with_context`` or contains ``error``.
        """
        payload = {
            "doc_content": doc.content,
            "num_questions": num_questions,
            "agent_description": agent_description,
            "question_guidelines": question_guidelines,
        }
        with self.get_default_request_session(
            get_default_retry_config(),
            headers=_get_synthesis_headers(),
        ) as http:
            resp = http.post(
                url=self.get_method_url("/managed-evals/generate-questions"),
                json=payload,
            )

        _raise_for_status(resp)

        body = resp.json()
        is_well_formed = "questions_with_context" in body and "error" not in body
        if not is_well_formed:
            raise ValueError(f"Invalid response: {body}")

        questions = []
        for item in body["questions_with_context"]:
            questions.append(
                entities.SyntheticQuestion(
                    question=item["question"],
                    source_doc_uri=doc.doc_uri,
                    source_context=item["context"],
                )
            )
        return questions

    def generate_answer(
        self,
        *,
        question: entities.SyntheticQuestion,
        answer_types: Collection[entities.SyntheticAnswerType],
    ) -> entities.SyntheticAnswer:
        """
        Ask the service to synthesize answer material for ``question``.

        Args:
            question: The question. Its text and source context are sent.
            answer_types: Which answer kinds to request. Each is sent as ``str(answer_type)``.

        Returns:
            A ``SyntheticAnswer``. Any field missing from the response is None.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        payload = {
            "question": question.question,
            "context": question.source_context,
            "answer_types": list(map(str, answer_types)),
        }

        with self.get_default_request_session(
            get_default_retry_config(),
            headers=_get_synthesis_headers(),
        ) as http:
            resp = http.post(
                url=self.get_method_url("/managed-evals/generate-answer"),
                json=payload,
            )

        _raise_for_status(resp)

        body = resp.json()
        ground_truth = body.get("synthetic_ground_truth")
        grading_notes = body.get("synthetic_grading_notes")
        minimal_facts = body.get("synthetic_minimal_facts")
        return entities.SyntheticAnswer(
            question=question,
            synthetic_ground_truth=ground_truth,
            synthetic_grading_notes=grading_notes,
            synthetic_minimal_facts=minimal_facts,
        )

    def create_managed_evals_instance(
        self,
        *,
        instance_id: str,
        agent_name: Optional[str] = None,
        agent_serving_endpoint: Optional[str] = None,
        experiment_ids: Optional[Iterable[str]] = None,
    ) -> entities.EvalsInstance:
        """
        Create a Managed Evals instance under ``instance_id``.

        Args:
            instance_id: Managed Evals instance ID.
            agent_name: (optional) The name of the agent.
            agent_serving_endpoint: (optional) The name of the model serving endpoint that serves the agent.
            experiment_ids: (optional) Experiment IDs to associate. Defaults to an empty list.

        Returns:
            The created EvalsInstance, parsed from the service response.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        if experiment_ids is None:
            experiment_ids = []
        new_instance = entities.EvalsInstance(
            agent_name=agent_name,
            agent_serving_endpoint=agent_serving_endpoint,
            experiment_ids=experiment_ids,
        )
        payload = {"instance": new_instance.to_json()}
        with self.get_default_request_session(headers=_get_default_headers()) as http:
            response = http.post(
                url=self.get_method_url(f"/managed-evals/instances/{instance_id}"),
                json=payload,
            )
        _raise_for_status(response)
        return entities.EvalsInstance.from_json(response.json())

    def delete_managed_evals_instance(self, instance_id: str) -> None:
        """
        Delete the Managed Evals instance ``instance_id``.

        Args:
            instance_id: Managed Evals instance ID.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        with self.get_default_request_session(headers=_get_default_headers()) as http:
            instance_url = self.get_method_url(f"/managed-evals/instances/{instance_id}")
            response = http.delete(url=instance_url)
        _raise_for_status(response)

    def update_managed_evals_instance(
        self,
        *,
        instance_id: str,
        agent_name: Optional[str] = NO_CHANGE,
        agent_serving_endpoint: Optional[str] = NO_CHANGE,
        experiment_ids: List[str] = NO_CHANGE,
    ) -> entities.EvalsInstance:
        """
        Patch the Managed Evals instance ``instance_id``.

        Arguments left at ``NO_CHANGE`` are meant to stay untouched. The update
        mask sent with the request comes from ``EvalsInstance.get_update_mask()``.

        Args:
            instance_id: Managed Evals instance ID.
            agent_name: (optional) The name of the agent.
            agent_serving_endpoint: (optional) The name of the model serving endpoint that serves the agent.
            experiment_ids: (optional) The experiment IDs to associate with the instance.

        Returns:
            The updated EvalsInstance, parsed from the service response.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        patch = entities.EvalsInstance(
            agent_name=agent_name,
            agent_serving_endpoint=agent_serving_endpoint,
            experiment_ids=experiment_ids,
        )
        payload = {
            "instance": patch.to_json(),
            "update_mask": patch.get_update_mask(),
        }

        with self.get_default_request_session(headers=_get_default_headers()) as http:
            instance_url = self.get_method_url(f"/managed-evals/instances/{instance_id}")
            response = http.patch(url=instance_url, json=payload)
        _raise_for_status(response)
        return entities.EvalsInstance.from_json(response.json())

    def get_managed_evals_instance(self, instance_id: str) -> entities.EvalsInstance:
        """
        Fetch the configuration of the Managed Evals instance ``instance_id``.

        Args:
            instance_id: Managed Evals instance ID.

        Returns:
            The EvalsInstance parsed from the instance's ``/configuration`` resource.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        config_path = f"/managed-evals/instances/{instance_id}/configuration"
        with self.get_default_request_session(headers=_get_default_headers()) as http:
            response = http.get(url=self.get_method_url(config_path))
        _raise_for_status(response)
        return entities.EvalsInstance.from_json(response.json())

    def sync_evals_to_uc(self, instance_id: str):
        """
        Trigger a sync of the instance's evals to a user-visible UC table.

        Args:
            instance_id: Managed Evals instance ID.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        sync_path = f"/managed-evals/instances/{instance_id}/evals:sync"
        with self.get_default_request_session(headers=_get_default_headers()) as http:
            response = http.post(url=self.get_method_url(sync_path))
        _raise_for_status(response)

    def add_evals(
        self,
        *,
        instance_id: str,
        evals: List[Dict],
    ) -> List[str]:
        """
        Add evals to a Managed Evals instance through the batchCreate endpoint.

        Each eval dict is read with the column names from ``schemas`` (request,
        expected response, expected facts, ...) plus optional ``tag_ids`` and
        ``review_status`` keys.

        Args:
            instance_id: Managed Evals instance ID.
            evals: The evals to add.

        Returns:
            The eval IDs of the created evals. Batches are sent through
            ``gracefully_batch_post``. If a batch fails, only the IDs created
            before the failure are returned.
        """

        def to_request_item(e: Dict) -> Dict[str, Any]:
            # A key that is present but set to None must not be iterated.
            # `is None` (rather than `or []`) keeps array-like values such as numpy
            # arrays working.
            facts = e.get(schemas.EXPECTED_FACTS_COL)
            if facts is None:
                facts = []
            return {
                "request_id": e.get(schemas.REQUEST_ID_COL),
                "source_type": e.get(schemas.SOURCE_TYPE_COL),
                "source_id": e.get(schemas.SOURCE_ID_COL),
                "json_serialized_request": json.dumps(
                    input_output_utils.to_chat_completion_request(
                        e.get(schemas.REQUEST_COL)
                    ),
                ),
                "expected_response": e.get(schemas.EXPECTED_RESPONSE_COL),
                "expected_facts": [{"fact": fact} for fact in facts],
                "expected_retrieved_context": e.get(
                    schemas.EXPECTED_RETRIEVED_CONTEXT_COL
                ),
                "tag_ids": e.get("tag_ids", []),
                "review_status": e.get("review_status"),
            }

        request_items = [to_request_item(e) for e in evals]
        return self.gracefully_batch_post(
            url=self.get_method_url(
                f"/managed-evals/instances/{instance_id}/evals:batchCreate"
            ),
            all_items=request_items,
            request_body_create=lambda batch: {"evals": batch},
            response_body_read=lambda response: response.json().get("eval_ids", []),
        )

    def delete_evals(
        self,
        instance_id: str,
        *,
        eval_ids: List[str],
    ):
        """
        Delete evals from a Managed Evals instance, one request per eval.

        Args:
            instance_id: Managed Evals instance ID.
            eval_ids: IDs of the evals to delete.

        Raises:
            requests.HTTPError: On the first failed deletion. Evals earlier in
                ``eval_ids`` have already been deleted at that point.
        """
        # There is no batch-delete endpoint yet, so evals are deleted one at a time.
        with self.get_default_request_session(headers=_get_default_headers()) as http:
            for eval_id in eval_ids:
                response = http.delete(
                    url=self.get_method_url(
                        f"/managed-evals/instances/{instance_id}/evals/{eval_id}"
                    ),
                )
                _raise_for_status(response)

    def list_tags(
        self,
        instance_id: str,
    ) -> List[TagType]:
        """
        List all tags of a Managed Evals instance, following pagination.

        At most 100 pages are fetched. If more pages remain after that, a warning
        is emitted and the tags gathered so far are returned.

        Args:
            instance_id: Managed Evals instance ID.

        Returns:
            A list of tags.

        Raises:
            requests.HTTPError: If any page request fails.
        """
        max_pages = 100  # Safety cap so a misbehaving paginator cannot loop forever.
        tags: List[TagType] = []
        tags_url = self.get_method_url(f"/managed-evals/instances/{instance_id}/tags")
        with self.get_default_request_session(headers=_get_default_headers()) as http:
            page_token: Optional[str] = None
            for _ in range(max_pages):
                # Pass the token through `params` so requests URL-encodes it. Page
                # tokens are opaque and may contain characters such as '+', '/' or '='.
                response = http.get(
                    url=tags_url,
                    params={"next_page_token": page_token} if page_token else None,
                )
                _raise_for_status(response)
                page = response.json()
                tags.extend(page.get("tags", []))
                page_token = page.get("next_page_token")
                if not page_token:
                    break
            else:
                warnings.warn(
                    f"Giving up fetching tags after {max_pages} pages of tags; potential internal error."
                )
        return tags

    def batch_create_tags(
        self,
        *,
        instance_id: str,
        tag_names: Collection[str],
    ) -> List[str]:
        """
        Create tags on a Managed Evals instance via the tags batchCreate endpoint.

        Args:
            instance_id: Managed Evals instance ID.
            tag_names: The tag names to create.

        Returns:
            The tag IDs of the created tags. Requests go through
            ``gracefully_batch_post``, so the list is partial if a batch fails.
        """
        endpoint = self.get_method_url(
            f"/managed-evals/instances/{instance_id}/tags:batchCreate"
        )
        return self.gracefully_batch_post(
            url=endpoint,
            all_items=[{"tag_name": name} for name in tag_names],
            request_body_create=lambda chunk: {"tags": chunk},
            response_body_read=lambda resp: resp.json().get("tag_ids", []),
        )

    def batch_create_eval_tags(
        self,
        instance_id: str,
        *,
        eval_tags: List[entities.EvalTag],
    ):
        """
        Attach tags to evals in batches.

        Args:
            instance_id: Managed Evals instance ID.
            eval_tags: One entry per (eval, tag) pair to create.

        Returns:
            The ``eval_tags`` entries read from each successful response. Requests
            go through ``gracefully_batch_post``, so the list is partial if a batch
            fails.
        """
        endpoint = self.get_method_url(
            f"/managed-evals/instances/{instance_id}/eval_tags:batchCreate"
        )
        return self.gracefully_batch_post(
            url=endpoint,
            all_items=[eval_tag.to_json() for eval_tag in eval_tags],
            request_body_create=lambda chunk: {"eval_tags": chunk},
            response_body_read=lambda resp: resp.json().get("eval_tags", []),
        )

    def batch_delete_eval_tags(
        self,
        instance_id: str,
        *,
        eval_tags: List[entities.EvalTag],
    ):
        """
        Remove tags from evals in batches.

        Args:
            instance_id: Managed Evals instance ID.
            eval_tags: One entry per (eval, tag) pair to delete.

        Returns:
            The ``eval_tags`` entries read from each successful response. Requests
            go through ``gracefully_batch_post``, so the list is partial if a batch
            fails.
        """
        endpoint = self.get_method_url(
            f"/managed-evals/instances/{instance_id}/eval_tags:batchDelete"
        )
        return self.gracefully_batch_post(
            url=endpoint,
            all_items=[eval_tag.to_json() for eval_tag in eval_tags],
            request_body_create=lambda chunk: {"eval_tags": chunk},
            response_body_read=lambda resp: resp.json().get("eval_tags", []),
        )

    def update_eval_permissions(
        self,
        instance_id: str,
        *,
        add_emails: Optional[List[str]] = None,
        remove_emails: Optional[List[str]] = None,
    ):
        """Grant or revoke WRITE permission on an eval instance for the given users.

        Args:
            instance_id: Managed Evals instance ID.
            add_emails: Emails to grant WRITE permission to. Omitted from the request if empty.
            remove_emails: Emails to revoke WRITE permission from. Omitted from the request if empty.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """

        def write_grants(emails: List[str]) -> List[Dict[str, Any]]:
            return [{"user_email": email, "permissions": ["WRITE"]} for email in emails]

        permission_change: Dict[str, Any] = {}
        if add_emails:
            permission_change["add"] = write_grants(add_emails)
        if remove_emails:
            permission_change["remove"] = write_grants(remove_emails)
        request_body = {"permission_change": permission_change}

        permissions_path = f"/managed-evals/instances/{instance_id}/permissions"
        with self.get_default_request_session(headers=_get_default_headers()) as http:
            response = http.post(
                url=self.get_method_url(permissions_path),
                json=request_body,
            )
        _raise_for_status(response)

    def create_monitor(
        self,
        *,
        endpoint_name: str,
        monitoring_config: entities.MonitoringConfig,
        experiment_id: str,
        workspace_path: str | None,
        monitoring_table: str | None = None,
    ) -> entities.Monitor:
        """
        Create a monitor and print a summary with a link to its monitoring page.

        Args:
            endpoint_name: Name used in the monitor's URL path and in the printed link.
            monitoring_config: Pause state, schedule, metrics, guidelines and sampling rate.
            experiment_id: Experiment the monitor is attached to.
            workspace_path: Workspace path stored on the monitor (may be None).
            monitoring_table: If set, the monitor is created with
                ``is_agent_external=True`` and this table as its evaluated-traces
                table, and an ``ExternalMonitor`` is returned.

        Returns:
            The created ``Monitor``, or an ``ExternalMonitor`` when ``monitoring_table`` is set.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        cfg = monitoring_config

        pause_status: Optional[str] = None
        if cfg.paused is not None:
            status_enum = (
                entities.SchedulePauseStatus.PAUSED
                if cfg.paused
                else entities.SchedulePauseStatus.UNPAUSED
            )
            pause_status = status_enum.value

        schedule_config = rest_entities.ScheduleConfig(pause_status=pause_status)
        periodic = cfg.periodic
        if periodic and periodic.interval and periodic.unit:
            schedule_config.periodic_schedule = rest_entities.PeriodicSchedule(
                frequency_interval=periodic.interval,
                frequency_unit=periodic.unit,
            )

        # The REST API takes guidelines as a list of named entries.
        named_guidelines = None
        if cfg.global_guidelines:
            named_guidelines = rest_entities.NamedGuidelines(
                entries=[
                    rest_entities.NamedGuidelineEntry(key=key, guidelines=guidelines)
                    for key, guidelines in cfg.global_guidelines.items()
                ]
            )

        requested_metrics = cfg.metrics
        evaluation_config = rest_entities.EvaluationConfig(
            metrics=[
                rest_entities.AssessmentConfig(name=metric)
                for metric in requested_metrics or []
            ],
            # Only an explicitly empty list (not None) sets no_metrics.
            no_metrics=requested_metrics is not None and len(requested_metrics) == 0,
            named_guidelines=named_guidelines,
        )

        monitor_rest = rest_entities.Monitor(
            experiment_id=experiment_id,
            workspace_path=workspace_path,
            evaluation_config=evaluation_config,
            sampling=rest_entities.SamplingConfig(sampling_rate=cfg.sample),
            schedule=schedule_config,
            is_agent_external=bool(monitoring_table),
            evaluated_traces_table=monitoring_table,
        )
        payload = {"monitor_config": monitor_rest.to_dict()}

        with self.get_default_request_session(headers=_get_default_headers()) as http:
            response = http.post(
                url=self.get_method_url(f"/managed-evals/monitors/{endpoint_name}"),
                json=payload,
            )
        _raise_for_status(response)
        monitor_info_rest = rest_entities.MonitorInfo.from_dict(
            response.json(), infer_missing=True
        )

        # An external-agent monitor (backed by a table) is reported by experiment, not endpoint.
        if not monitoring_table:
            monitor = monitor_info_rest.to_monitor()
            created_msg = f'Created monitor for endpoint "{endpoint_name}".'
        else:
            monitor = monitor_info_rest.to_external_monitor()
            created_msg = f'Created monitor for experiment "{monitor.experiment_id}".'

        workspace_url = get_workspace_url()
        monitoring_page_url = (
            f"{workspace_url}/ml/experiments/{experiment_id}"
            f"/evaluation-monitoring?endpointName={endpoint_name}"
        )

        print(created_msg)
        print(f"\nView monitoring page: {monitoring_page_url}")

        computed_metrics = monitor.monitoring_config.metrics
        if not computed_metrics:
            print(
                "\nNo computed metrics specified. To override the computed metrics, include `metrics` in the monitoring_config."
            )
        else:
            print("\nComputed metrics:")
            for metric in computed_metrics:
                print(f"• {metric}")

        return monitor

    def update_monitor(
        self,
        *,
        endpoint_name: str,
        monitoring_config: entities.MonitoringConfig,
    ) -> entities.Monitor:
        """
        Update an existing monitor and print a link to its monitoring page.

        Only the settings present on ``monitoring_config`` are sent. The request's
        update mask comes from ``rest_entities.Monitor.get_update_mask()``.

        Args:
            endpoint_name: Name used in the monitor's URL path and in the printed link.
            monitoring_config: Settings to change.
                - ``metrics`` / ``global_guidelines`` set to None are left out. An
                  empty ``metrics`` list sets ``no_metrics``.
                - ``paused`` is applied even when ``periodic`` is not given.

        Returns:
            The updated ``Monitor``, or an ``ExternalMonitor`` if the service reports
            the monitor as external.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        cfg = monitoring_config

        # Send the enum's string value, as create_monitor does and as the annotation says.
        pause_status: Optional[str] = None
        if cfg.paused is not None:
            pause_status = (
                entities.SchedulePauseStatus.PAUSED
                if cfg.paused
                else entities.SchedulePauseStatus.UNPAUSED
            ).value

        # Unless there are updates to the evaluation config, do not pass an EvaluationConfig to the service.
        evaluation_config: Optional[rest_entities.EvaluationConfig] = None

        if cfg.metrics is not None:
            evaluation_config = rest_entities.EvaluationConfig(
                metrics=[
                    rest_entities.AssessmentConfig(name=metric)
                    for metric in cfg.metrics
                ],
                no_metrics=len(cfg.metrics) == 0,
            )

        if cfg.global_guidelines is not None:
            named_guidelines = rest_entities.NamedGuidelines(
                entries=[
                    rest_entities.NamedGuidelineEntry(key=key, guidelines=guidelines)
                    for key, guidelines in cfg.global_guidelines.items()
                ]
            )
            if evaluation_config is None:
                evaluation_config = rest_entities.EvaluationConfig(
                    named_guidelines=named_guidelines
                )
            else:
                evaluation_config.named_guidelines = named_guidelines

        sampling_config: Optional[rest_entities.SamplingConfig] = None
        if cfg.sample:
            sampling_config = rest_entities.SamplingConfig(sampling_rate=cfg.sample)

        # A pause/unpause request needs a ScheduleConfig even without a periodic
        # schedule; otherwise `paused` would be silently dropped.
        schedule_config: Optional[rest_entities.ScheduleConfig] = None
        if cfg.periodic or pause_status is not None:
            schedule_config = rest_entities.ScheduleConfig(pause_status=pause_status)
            periodic = cfg.periodic
            if periodic and periodic.interval and periodic.unit:
                schedule_config.periodic_schedule = rest_entities.PeriodicSchedule(
                    frequency_interval=periodic.interval,
                    frequency_unit=periodic.unit,
                )

        monitor_rest = rest_entities.Monitor(
            evaluation_config=evaluation_config,
            sampling=sampling_config,
            schedule=schedule_config,
        )

        request_body = {
            "monitor": monitor_rest.to_dict(),
            "update_mask": monitor_rest.get_update_mask(),
        }
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http:
            response = http.patch(
                url=self.get_method_url(f"/managed-evals/monitors/{endpoint_name}"),
                json=request_body,
            )
        _raise_for_status(response)
        monitor_info_rest = rest_entities.MonitorInfo.from_dict(
            response.json(), infer_missing=True
        )

        if monitor_info_rest.is_external:
            monitor = monitor_info_rest.to_external_monitor()
        else:
            monitor = monitor_info_rest.to_monitor()

        monitoring_page_url = f"{get_workspace_url()}/ml/experiments/{monitor.experiment_id}/evaluation-monitoring?endpointName={endpoint_name}"
        user_message = f"""Updated monitor for endpoint "{endpoint_name}".

View monitoring page: {monitoring_page_url}"""
        print(user_message)

        return monitor

    def list_monitors(
        self, *, experiment_id: str
    ) -> list[entities.Monitor | entities.ExternalMonitor]:
        """List all monitors for a given experiment.

        Args:
            experiment_id (str): The ID of the experiment that the monitors are associated with.

        Returns:
            list[Monitor | ExternalMonitor]: The experiment's monitors, built eagerly so
                the result can be indexed, measured and iterated more than once. Entries
                the service marks as external are returned as ``ExternalMonitor``.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http:
            response = http.post(
                url=self.get_method_url("/managed-evals/monitors"),
                json={"experiment_id": experiment_id},
            )
        _raise_for_status(response)

        monitor_infos = [
            rest_entities.MonitorInfo.from_dict(raw_info, infer_missing=True)
            for raw_info in response.json().get("monitor_infos", [])
        ]
        # Return a real list (not a lazy `map`), as the annotation promises.
        return [
            info.to_external_monitor() if info.is_external else info.to_monitor()
            for info in monitor_infos
        ]

    def get_monitor(
        self,
        *,
        endpoint_name: str | None = None,
        monitoring_table: str | None = None,
    ) -> entities.Monitor | entities.ExternalMonitor:
        """Fetch a monitor either by serving-endpoint name or by monitoring-table name.

        Args:
            endpoint_name (str | None, optional): The name of the endpoint. Defaults to None.
            monitoring_table (str | None, optional): The fullname of the monitoring table. Defaults to None.

        Raises:
            ValueError: When both or neither of 'endpoint_name' and 'monitoring_table' are provided.
            requests.HTTPError: If the service returns an error status.

        Returns:
            Monitor | ExternalMonitor: An ExternalMonitor when the service marks the
                monitor as external, otherwise a Monitor.
        """
        # Both-None or both-set are equally invalid.
        if (endpoint_name is None) == (monitoring_table is None):
            raise ValueError(
                "Exactly one of 'endpoint_name' and 'monitoring_table' must be provided."
            )

        if monitoring_table is not None:
            urlpath = f"monitors/table_name/{monitoring_table}/"
        else:
            urlpath = f"monitors/{endpoint_name}"

        with self.get_default_request_session(headers=_get_default_headers()) as http:
            response = http.get(
                url=self.get_method_url(f"/managed-evals/{urlpath}"),
                json={},
            )
        _raise_for_status(response)

        info = rest_entities.MonitorInfo.from_dict(response.json(), infer_missing=True)
        return info.to_external_monitor() if info.is_external else info.to_monitor()

    def delete_monitor(self, endpoint_name: str) -> None:
        """
        Delete the monitor registered under ``endpoint_name``.

        Raises:
            requests.HTTPError: If the service returns an error status.
        """
        monitor_path = f"/managed-evals/monitors/{endpoint_name}"
        with self.get_default_request_session(headers=_get_default_headers()) as http:
            response = http.delete(url=self.get_method_url(monitor_path), json={})
        _raise_for_status(response)

    def monitoring_usage_events(
        self,
        *,
        endpoint_name: str,
        job_id: str,
        run_id: str,
        run_ended: bool,
        num_traces_evaluated: Optional[int],
        error_message: Optional[str] = None,
        additional_headers: Optional[Dict[str, str]] = None,
    ) -> None:
        """
        Report a job-start or job-completion usage event for a monitor.

        :param endpoint_name: Name of endpoint associated with monitor.
        :param job_id: ID of job.
        :param run_id: ID of job run.
        :param run_ended: True to send a completion event (success iff ``error_message`` is None);
            False to send a start event.
        :param num_traces_evaluated: Number of traces evaluated in the run; omitted from the request when None.
        :param error_message: Error message associated with failed run. May be empty.
        :param additional_headers: Extra headers merged over the defaults (they win on conflict). May be empty.
        """
        if run_ended:
            job_start = None
            job_completion = rest_entities.JobCompletionEvent(
                success=error_message is None,
                error_message=error_message,
            )
        else:
            job_start = {}
            job_completion = None

        event = rest_entities.MonitoringEvent(
            job_start=job_start,
            job_completion=job_completion,
        )
        request_body = {
            "job_id": job_id,
            "run_id": run_id,
            "events": [event.to_dict()],
        }
        if num_traces_evaluated is not None:
            traces_metric = rest_entities.MonitoringMetric(num_traces_evaluated)
            request_body["metrics"] = [traces_metric.to_dict()]

        merged_headers = {**_get_default_headers(), **(additional_headers or {})}
        usage_path = f"/managed-evals/monitors/{endpoint_name}/usage-logging"
        with self.get_default_request_session(headers=merged_headers) as http:
            response = http.post(
                url=self.get_method_url(usage_path),
                json=request_body,
            )
        _raise_for_status(response)

    ##### Review App REST APIs #####
    def create_review_app(self, review_app: entities.ReviewApp) -> entities.ReviewApp:
        """Create a review app through the managed-evals service.

        Args:
            review_app: Review app to create; converted to its REST form before sending.

        Returns:
            The review app parsed from the service response.
        """
        payload = rest_entities.ReviewApp.from_review_app(review_app).to_dict()
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.post(
                url=self.get_method_url("/managed-evals/review-apps"),
                json=payload,
            )
        _raise_for_status(response)
        created: rest_entities.ReviewApp = rest_entities.ReviewApp.from_dict(
            response.json()
        )
        return created.to_review_app()

    def _paginate_review_apps(
        self, filter: str, page_token: Optional[str]
    ) -> Tuple[list[entities.ReviewApp], Optional[str]]:
        """Fetch a single page of review apps matching ``filter``.

        Args:
            filter: Filter expression, inserted into the query string as-is; callers are
                expected to URL-encode it beforehand (``list_review_apps`` does).
            page_token: Opaque token from the previous page, or None/empty for the first page.

        Returns:
            A ``(review_apps, next_page_token)`` tuple. ``next_page_token`` is the response's
            ``next_page_token`` value, or None when the response has none.
        """
        url = self.get_method_url(f"/managed-evals/review-apps?filter={filter}")
        if page_token:
            # Page tokens are opaque server strings that may contain reserved characters
            # ('+', '/', '=', '&'). Percent-encode them so they round-trip intact instead of
            # being mangled by query-string parsing (e.g. '+' decoded as a space).
            url += f"&page_token={requests.utils.quote(page_token, safe='')}"

        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as session:
            response = session.get(url=url)
        _raise_for_status(response)
        payload = response.json()
        next_page_token = payload.get("next_page_token")
        review_apps: list[entities.ReviewApp] = [
            rest_entities.ReviewApp.from_dict(app_dict).to_review_app()
            for app_dict in payload.get("review_apps", [])
        ]
        return review_apps, next_page_token

    def list_review_apps(self, filter: str) -> list[entities.ReviewApp]:
        """Return every review app matching ``filter``, following pagination to the end.

        Args:
            filter: Raw filter expression; URL-encoded here before being put in the query string.
        """
        encoded_filter = requests.utils.quote(filter)
        page, token = self._paginate_review_apps(encoded_filter, None)
        collected: list[entities.ReviewApp] = list(page)
        # Keep requesting pages while the server hands back a non-empty continuation token.
        while token:
            page, token = self._paginate_review_apps(encoded_filter, token)
            collected.extend(page)
        return collected

    def update_review_app(
        self, review_app: entities.ReviewApp, update_mask: str
    ) -> entities.ReviewApp:
        """Patch an existing review app.

        Args:
            review_app: Review app carrying the new values; its ``review_app_id`` selects the target.
            update_mask: Sent verbatim as the ``update_mask`` query parameter.

        Returns:
            The review app parsed from the service response.
        """
        payload = rest_entities.ReviewApp.from_review_app(review_app).to_dict()
        endpoint = (
            f"/managed-evals/review-apps/{review_app.review_app_id}"
            f"?update_mask={update_mask}"
        )
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.patch(
                url=self.get_method_url(endpoint),
                json=payload,
            )
        _raise_for_status(response)
        updated: rest_entities.ReviewApp = rest_entities.ReviewApp.from_dict(
            response.json()
        )
        return updated.to_review_app()

    def create_labeling_session(
        self,
        review_app: entities.ReviewApp,
        name: str,
        *,
        # Must be workspace users for now due to ACL
        assigned_users: list[str],
        # agent names must already be added to the backend.
        agent: Optional[str],
        # the schema names, must be already added to backend.
        label_schemas: list[Union[str, entities.LabelSchema]],
    ) -> entities.LabelingSession:
        """Create a labeling session inside ``review_app``.

        Args:
            review_app: Review app that will own the session.
            name: Name of the new session.
            assigned_users: Users to assign; currently must be workspace users (ACL restriction).
            agent: Name of an agent already registered with the backend. When falsy, no
                agent reference is sent.
            label_schemas: Schemas to attach, each given either by name or as a
                ``LabelSchema`` (only its ``name`` is used). They must already exist in the backend.

        Returns:
            The created session, with its URL set to
            ``{review_app.url}/tasks/labeling/{labeling_session_id}``.

        Raises:
            ValueError: If an entry of ``label_schemas`` is neither ``str`` nor ``LabelSchema``.
        """
        schema_names: list[str] = []
        for schema in label_schemas:
            if isinstance(schema, entities.LabelSchema):
                schema_names.append(schema.name)
                continue
            if isinstance(schema, str):
                schema_names.append(schema)
                continue
            raise ValueError(
                f"Invalid type for label_schemas: {type(schema)}. Must be str or LabelSchema."
            )

        agent_ref = rest_entities.AgentRef(agent_name=agent) if agent else None
        schema_refs = [
            rest_entities.LabelingSchemaRef(name=schema_name)
            for schema_name in schema_names
        ]
        payload = rest_entities.LabelingSession(
            labeling_session_id=None,  # Unknown until the session is created.
            mlflow_run_id=None,  # Unknown until the session is created.
            name=name,
            assigned_users=assigned_users,
            agent=agent_ref,
            labeling_schemas=schema_refs,
        ).to_dict()

        sessions_url = self.get_method_url(
            f"/managed-evals/review-apps/{review_app.review_app_id}/labeling-sessions"
        )
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.post(url=sessions_url, json=payload)
        _raise_for_status(response)
        created: rest_entities.LabelingSession = rest_entities.LabelingSession.from_dict(
            response.json()
        )
        return created.to_labeling_session(
            review_app.review_app_id,
            review_app.experiment_id,
            f"{review_app.url}/tasks/labeling/{created.labeling_session_id}",
        )

    def _paginate_labeling_sessions(
        self, review_app: entities.ReviewApp, page_token: Optional[str]
    ) -> Tuple[list[entities.LabelingSession], Optional[str]]:
        """Fetch one page of labeling sessions belonging to ``review_app``.

        Requests ``_DEFAULT_PAGE_SIZE`` sessions per page.

        Args:
            review_app: Review app whose sessions are listed; its ID, experiment ID and URL
                are attached to each returned session.
            page_token: Opaque token from the previous page, or None/empty for the first page.

        Returns:
            A ``(sessions, next_page_token)`` tuple. ``next_page_token`` is None when the
            response carries no ``next_page_token``.
        """
        url = self.get_method_url(
            f"/managed-evals/review-apps/{review_app.review_app_id}/labeling-sessions?page_size={_DEFAULT_PAGE_SIZE}"
        )
        if page_token:
            # Page tokens are opaque server strings that may contain reserved characters
            # ('+', '/', '=', '&'). Percent-encode them so they round-trip intact instead of
            # being mangled by query-string parsing (e.g. '+' decoded as a space).
            url += f"&page_token={requests.utils.quote(page_token, safe='')}"

        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as session:
            response = session.get(url=url)
        _raise_for_status(response)
        payload = response.json()
        next_page_token = payload.get("next_page_token")
        sessions: list[entities.LabelingSession] = []
        for session_dict in payload.get("labeling_sessions", []):
            rest_session: rest_entities.LabelingSession = (
                rest_entities.LabelingSession.from_dict(session_dict)
            )
            # Each session's URL is nested under its review app's URL.
            session_url = (
                f"{review_app.url}/tasks/labeling/{rest_session.labeling_session_id}"
            )
            sessions.append(
                rest_session.to_labeling_session(
                    review_app.review_app_id, review_app.experiment_id, session_url
                )
            )
        return sessions, next_page_token

    def list_labeling_sessions(
        self, review_app: entities.ReviewApp
    ) -> list[entities.LabelingSession]:
        """Return every labeling session of ``review_app``, following pagination to the end."""
        page, token = self._paginate_labeling_sessions(review_app, None)
        collected: list[entities.LabelingSession] = list(page)
        # A falsy continuation token marks the last page.
        while token:
            page, token = self._paginate_labeling_sessions(review_app, token)
            collected.extend(page)
        return collected

    def _paginate_labeling_items(
        self, review_app_id: str, labeling_session_id: str, page_token: Optional[str]
    ) -> Tuple[list[rest_entities.Item], Optional[str]]:
        """Fetch one page of items from a labeling session.

        Requests ``_DEFAULT_PAGE_SIZE`` items per page.

        Args:
            review_app_id: ID of the review app owning the session.
            labeling_session_id: ID of the labeling session whose items are listed.
            page_token: Opaque token from the previous page, or None/empty for the first page.

        Returns:
            A ``(items, next_page_token)`` tuple. ``next_page_token`` is None when the
            response carries no ``next_page_token``.
        """
        url = self.get_method_url(
            f"/managed-evals/review-apps/{review_app_id}/labeling-sessions/{labeling_session_id}/items?page_size={_DEFAULT_PAGE_SIZE}"
        )
        if page_token:
            # Page tokens are opaque server strings that may contain reserved characters
            # ('+', '/', '=', '&'). Percent-encode them so they round-trip intact instead of
            # being mangled by query-string parsing (e.g. '+' decoded as a space).
            url += f"&page_token={requests.utils.quote(page_token, safe='')}"

        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as session:
            response = session.get(url=url)
        _raise_for_status(response)
        payload = response.json()
        next_page_token = payload.get("next_page_token")
        items = [
            rest_entities.Item.from_dict(item) for item in payload.get("items", [])
        ]
        return items, next_page_token

    def list_items_in_labeling_session(
        self, labeling_session: entities.LabelingSession
    ) -> list[rest_entities.Item]:
        """Return every item in ``labeling_session``, following pagination to the end."""
        app_id = labeling_session.review_app_id
        session_id = labeling_session.labeling_session_id
        page, token = self._paginate_labeling_items(app_id, session_id, None)
        collected: list[rest_entities.Item] = list(page)
        # A falsy continuation token marks the last page.
        while token:
            page, token = self._paginate_labeling_items(app_id, session_id, token)
            collected.extend(page)
        return collected

    def update_labeling_session(
        self,
        labeling_session: entities.LabelingSession,
        update_mask: str,
    ) -> entities.LabelingSession:
        """Patch an existing labeling session.

        Args:
            labeling_session: Session carrying the desired values. Its ``review_app_id`` and
                ``labeling_session_id`` identify the session to patch.
            update_mask: Sent verbatim as the ``update_mask`` query parameter.

        Returns:
            The session parsed from the service response. It is re-attached to the input
            session's review app ID, experiment ID and URL.
        """
        # Match create_labeling_session: send an AgentRef only when an agent is set,
        # instead of an AgentRef whose agent_name is None.
        agent_ref = (
            rest_entities.AgentRef(agent_name=labeling_session.agent)
            if labeling_session.agent
            else None
        )
        labeling_session_rest = rest_entities.LabelingSession(
            labeling_session_id=labeling_session.labeling_session_id,
            mlflow_run_id=labeling_session.mlflow_run_id,
            name=labeling_session.name,
            assigned_users=labeling_session.assigned_users,
            agent=agent_ref,
            labeling_schemas=[
                rest_entities.LabelingSchemaRef(name=schema_name)
                for schema_name in labeling_session.label_schemas
            ],
        )
        request_body = labeling_session_rest.to_dict()
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as session:
            response = session.patch(
                url=self.get_method_url(
                    f"/managed-evals/review-apps/{labeling_session.review_app_id}/labeling-sessions/{labeling_session.labeling_session_id}?update_mask={update_mask}"
                ),
                json=request_body,
            )
        _raise_for_status(response)
        updated: rest_entities.LabelingSession = (
            rest_entities.LabelingSession.from_dict(response.json())
        )
        return updated.to_labeling_session(
            labeling_session.review_app_id,
            labeling_session.experiment_id,
            labeling_session.url,
        )

    def delete_labeling_session(
        self, review_app_id: str, labeling_session_id: str
    ) -> None:
        """Delete a labeling session.

        Args:
            review_app_id: ID of the review app owning the session.
            labeling_session_id: ID of the session to delete.

        Errors reported by the service are surfaced through ``_raise_for_status``.
        """
        session_url = self.get_method_url(
            f"/managed-evals/review-apps/{review_app_id}/labeling-sessions/{labeling_session_id}"
        )
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.delete(url=session_url)
        _raise_for_status(response)

    def batch_create_items_in_labeling_session(
        self,
        labeling_session: entities.LabelingSession,
        trace_ids: Optional[list[str]] = None,
        dataset_id: Optional[str] = None,
        dataset_record_ids: Optional[list[str]] = None,
    ) -> None:
        """Add traces and/or dataset records to a labeling session in one request.

        Args:
            labeling_session: Target session. Its ``review_app_id`` and
                ``labeling_session_id`` select the endpoint.
            trace_ids: Trace IDs to add as items. None or empty adds no trace items.
            dataset_id: Dataset whose records are added. Must be given together with
                ``dataset_record_ids``.
            dataset_record_ids: IDs of the records within ``dataset_id`` to add.

        Raises:
            ValueError: Unless ``dataset_id`` and ``dataset_record_ids`` are both None or
                both non-empty.
        """
        # Use an explicit exception instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently skip this validation.
        dataset_args_absent = dataset_id is None and dataset_record_ids is None
        dataset_args_present = bool(dataset_id) and bool(dataset_record_ids)
        if not (dataset_args_absent or dataset_args_present):
            raise ValueError(
                "dataset_id and dataset_record_ids must be provided together "
                "(both non-empty, or both None)."
            )

        review_app_id = labeling_session.review_app_id
        labeling_session_id = labeling_session.labeling_session_id

        items: list[dict] = [
            {"source": {"trace_id": trace_id}} for trace_id in (trace_ids or [])
        ]
        if dataset_id:
            items.extend(
                {
                    "source": {
                        "dataset_record": {
                            "dataset_id": dataset_id,
                            "dataset_record_id": record_id,
                        }
                    }
                }
                for record_id in dataset_record_ids
            )

        request_body = {"items": items}
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as session:
            response = session.post(
                url=self.get_method_url(
                    f"/managed-evals/review-apps/{review_app_id}/labeling-sessions/{labeling_session_id}/items:batchCreate"
                ),
                json=request_body,
            )
        _raise_for_status(response)

    def create_dataset(
        self, uc_table_name: str, experiment_ids: list[str]
    ) -> entities.Dataset:
        """Create a dataset backed by a Unity Catalog table.

        Args:
            uc_table_name: UC table name; also used as the dataset's ``name``.
            experiment_ids: Experiment IDs to associate. When non-empty, they are sent
                comma-joined in the ``experiment_ids`` query parameter.

        Returns:
            The created dataset parsed from the response.
        """
        query = f"?experiment_ids={','.join(experiment_ids)}" if experiment_ids else ""
        datasets_url = self.get_method_url("/managed-evals/datasets") + query
        payload = rest_entities.Dataset(
            name=uc_table_name,
            source_type="databricks-uc-table",
        ).to_dict()
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.post(url=datasets_url, json=payload)
        _raise_for_status(response)
        # Datasets use the same shape for REST and public Python entities.
        return entities.Dataset.from_dict(response.json())

    def get_dataset(self, dataset_id: str) -> entities.Dataset:
        """Fetch a dataset by ID.

        Args:
            dataset_id: ID of the dataset to fetch.

        Returns:
            The dataset parsed directly from the response into the public entity.
        """
        dataset_url = self.get_method_url(f"/managed-evals/datasets/{dataset_id}")
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.get(url=dataset_url)
        _raise_for_status(response)
        return entities.Dataset.from_dict(response.json())

    def delete_dataset(self, dataset_id: str) -> None:
        """Delete a dataset by ID.

        Args:
            dataset_id: ID of the dataset to delete.

        Errors reported by the service are surfaced through ``_raise_for_status``.
        """
        dataset_url = self.get_method_url(f"/managed-evals/datasets/{dataset_id}")
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.delete(url=dataset_url)
        _raise_for_status(response)

    def update_dataset(
        self, dataset: entities.Dataset, update_mask: str
    ) -> entities.Dataset:
        """Patch an existing dataset.

        Args:
            dataset: Dataset carrying the new values; its ``dataset_id`` selects the target.
            update_mask: Sent verbatim as the ``update_mask`` query parameter.

        Returns:
            The updated dataset parsed from the response.
        """
        # Datasets use the same shape for REST and public Python entities, so no conversion is needed.
        payload = dataset.to_dict()
        endpoint = (
            f"/managed-evals/datasets/{dataset.dataset_id}"
            f"?update_mask={update_mask}"
        )
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.patch(
                url=self.get_method_url(endpoint),
                json=payload,
            )
        _raise_for_status(response)
        return entities.Dataset.from_dict(response.json())

    def _paginate_dataset_records(
        self, dataset_id: str, page_token: Optional[str]
    ) -> Tuple[list[rest_entities.DatasetRecord], Optional[str]]:
        """Fetch one page of records from a dataset.

        Requests ``_DEFAULT_PAGE_SIZE`` records per page.

        Args:
            dataset_id: ID of the dataset whose records are listed.
            page_token: Opaque token from the previous page, or None/empty for the first page.

        Returns:
            A ``(records, next_page_token)`` tuple. ``next_page_token`` is None when the
            response carries no ``next_page_token``.
        """
        url = self.get_method_url(
            f"/managed-evals/datasets/{dataset_id}/records?page_size={_DEFAULT_PAGE_SIZE}"
        )
        if page_token:
            # Page tokens are opaque server strings that may contain reserved characters
            # ('+', '/', '=', '&'). Percent-encode them so they round-trip intact instead of
            # being mangled by query-string parsing (e.g. '+' decoded as a space).
            url += f"&page_token={requests.utils.quote(page_token, safe='')}"

        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as session:
            response = session.get(url=url)
        _raise_for_status(response)
        payload = response.json()
        next_page_token = payload.get("next_page_token")
        records = [
            rest_entities.DatasetRecord.from_dict(record)
            for record in payload.get("dataset_records", [])
        ]
        return records, next_page_token

    def list_dataset_records(
        self, dataset_id: str
    ) -> list[rest_entities.DatasetRecord]:
        """Return every record of ``dataset_id``, following pagination to the end."""
        page, token = self._paginate_dataset_records(dataset_id, None)
        collected: list[rest_entities.DatasetRecord] = list(page)
        # A falsy continuation token marks the last page.
        while token:
            page, token = self._paginate_dataset_records(dataset_id, token)
            collected.extend(page)
        return collected

    def batch_create_dataset_records(
        self,
        uc_table_name: str,
        dataset_id: str,
        records: list[rest_entities.DatasetRecord],
    ) -> None:
        """Create ``records`` in a dataset, then re-sync the whole dataset to UC.

        Args:
            uc_table_name: Table name passed to ``dataset_utils.sync_dataset_to_uc`` after the write.
            dataset_id: Dataset receiving the records.
            records: Records to create; all are sent in a single ``records:batchCreate`` request.
        """
        create_requests = [
            {"dataset_id": dataset_id, "dataset_record": record.to_dict()}
            for record in records
        ]
        endpoint = f"/managed-evals/datasets/{dataset_id}/records:batchCreate"
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.post(
                url=self.get_method_url(endpoint),
                json={"requests": create_requests},
            )
        _raise_for_status(response)

        # Read back every record (all pages) and hand the resulting frame to the UC sync.
        row_dicts = [
            entities.DatasetRow.from_rest_dataset_record(record).to_dict()
            for record in self.list_dataset_records(dataset_id)
        ]
        dataset_utils.sync_dataset_to_uc(
            uc_table_name, pd.DataFrame.from_records(row_dicts)
        )

    def upsert_dataset_record_expectations(
        self,
        uc_table_name: str,
        dataset_id: str,
        dataset_record_id: str,
        expectations: dict[str, rest_entities.Expectation],
    ) -> None:
        """Upsert expectations on one dataset record, then re-sync the whole dataset to UC.

        Args:
            uc_table_name: Table name passed to ``dataset_utils.sync_dataset_to_uc`` after the write.
            dataset_id: Dataset containing the record.
            dataset_record_id: Record whose expectations are upserted.
            expectations: Mapping of expectation name to value, each serialized via ``to_dict``.
        """
        serialized = {name: exp.to_dict() for name, exp in expectations.items()}
        endpoint = (
            f"/managed-evals/datasets/{dataset_id}/records/{dataset_record_id}/expectations"
        )
        with self.get_default_request_session(
            headers=_get_default_headers()
        ) as http_session:
            response = http_session.post(
                url=self.get_method_url(endpoint),
                json={"expectations": serialized},
            )
        _raise_for_status(response)

        # Read back every record (all pages) and hand the resulting frame to the UC sync.
        row_dicts = [
            entities.DatasetRow.from_rest_dataset_record(record).to_dict()
            for record in self.list_dataset_records(dataset_id)
        ]
        dataset_utils.sync_dataset_to_uc(
            uc_table_name, pd.DataFrame.from_records(row_dicts)
        )
