# Generated by the protocol buffer compiler.  DO NOT EDIT!
# sources: chat.proto
# plugin: python-betterproto
# This file has been @generated

from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    AsyncIterator,
    Dict,
    List,
    Optional,
)

import betterproto
import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf
import grpclib
from betterproto.grpc.grpclib_server import ServiceBase

from .. import (
    application as _application__,
    key as _key__,
    metric as _metric__,
    sentinel as _sentinel__,
)


if TYPE_CHECKING:
    import grpclib.server
    from betterproto.grpc.grpclib_client import MetadataLike
    from grpclib.metadata import Deadline


class EvaluationProvider(betterproto.Enum):
    """Enum ``chat.EvaluationProvider`` generated from chat.proto.

    Change chat.proto and regenerate instead of editing this file.
    """

    # Value 0 is the proto3 default: an unset field of this enum type reads as GROQ.
    GROQ = 0
    OPENAI = 1


class EvaluationContentType(betterproto.Enum):
    """Enum ``chat.EvaluationContentType``; the type of ``EvaluateRequest.evaluation_content_type``.

    The names suggest which ``EvaluateRequest`` payload field is meant
    (``message_content`` / ``text_content`` / ``partial``) — NOTE(review):
    confirm against the evaluator implementation.
    """

    # Value 0 is the proto3 default: an unset field reads as MESSAGE.
    MESSAGE = 0
    TEXT = 1
    PARTIAL = 2


@dataclass(eq=False, repr=False)
class ChatCompletionRequest(betterproto.Message):
    """Message ``chat.ChatCompletionRequest``: a chat-completion call plus metadata.

    ``params`` holds the model call itself. The remaining fields hold
    identifiers, credentials, and boolean feature switches. The integer passed
    to each ``*_field(...)`` is the protobuf field number (wire tag). It is
    defined in chat.proto and must never be renumbered or reused.

    ``eq=False, repr=False`` leaves equality and repr to ``betterproto.Message``.
    """

    params: "ChatCompletionParams" = betterproto.message_field(1)
    application_ref_name: str = betterproto.string_field(2)
    session_id: str = betterproto.string_field(3)
    reference_id: str = betterproto.string_field(4)
    action_type: str = betterproto.string_field(5)
    # proto3 ``optional`` (presence tracked via group "_company_id"): None means unset, distinct from 0.
    company_id: Optional[int] = betterproto.int64_field(
        6, optional=True, group="_company_id"
    )
    apply_corrections: bool = betterproto.bool_field(7)
    id: int = betterproto.int64_field(8)
    # proto3 ``optional``; the timestamp unit/epoch is not stated here — NOTE(review): confirm.
    date_created: Optional[int] = betterproto.int64_field(
        9, optional=True, group="_date_created"
    )
    # proto3 ``optional``: None means unset, distinct from "".
    request_id: Optional[str] = betterproto.string_field(
        10, optional=True, group="_request_id"
    )
    evaluation_enabled: bool = betterproto.bool_field(11)
    # Provider API keys; see ChatCompletionAuth.
    auth_info: "ChatCompletionAuth" = betterproto.message_field(12)
    return_evaluation: bool = betterproto.bool_field(13)
    error: str = betterproto.string_field(14)
    # KeyMap message from the sibling ``key`` proto package.
    auth_keys: "_key__.KeyMap" = betterproto.message_field(15)
    inference_location: str = betterproto.string_field(16)
    context_retrieval_enabled: bool = betterproto.bool_field(17)
    context_query: str = betterproto.string_field(18)
    context: str = betterproto.string_field(19)
    # Presumably asks the server to echo this request back in
    # ChatCompletionResponse.chat_completion_request — verify.
    return_request: bool = betterproto.bool_field(20)
    fallback_model: str = betterproto.string_field(21)
    user_id: str = betterproto.string_field(22)
    assistant: bool = betterproto.bool_field(23)
    client_params: "ClientParams" = betterproto.message_field(24)
    # map<string, string>; an unset map decodes as an empty dict.
    metadata: Dict[str, str] = betterproto.map_field(
        25, betterproto.TYPE_STRING, betterproto.TYPE_STRING
    )


@dataclass(eq=False, repr=False)
class ClientParams(betterproto.Message):
    """Message ``chat.ClientParams``: HTTP client overrides (base URL, headers, query).

    The field names suggest these configure the outbound provider client —
    NOTE(review): confirm where they are consumed.
    """

    base_url: str = betterproto.string_field(1)
    # map<string, string> of header name -> value.
    default_headers: Dict[str, str] = betterproto.map_field(
        2, betterproto.TYPE_STRING, betterproto.TYPE_STRING
    )
    # map<string, string> of query parameter name -> value.
    default_query: Dict[str, str] = betterproto.map_field(
        3, betterproto.TYPE_STRING, betterproto.TYPE_STRING
    )


@dataclass(eq=False, repr=False)
class ChatCompletionAuth(betterproto.Message):
    """Message ``chat.ChatCompletionAuth``: per-provider API keys.

    NOTE(review): these are secrets. ``betterproto.Message`` repr and
    ``to_dict``/``to_json`` presumably include set fields, so avoid logging
    this message or any message that contains it.
    """

    openai_api_key: str = betterproto.string_field(1)
    groq_api_key: str = betterproto.string_field(2)
    override_api_key: str = betterproto.string_field(3)
    anthropic_api_key: str = betterproto.string_field(4)


@dataclass(eq=False, repr=False)
class ChatCompletionParams(betterproto.Message):
    """Message ``chat.ChatCompletionParams``: the model-call parameters.

    The field set appears to mirror the OpenAI Chat Completions request body —
    NOTE(review): confirm the mapping in the consuming service.

    NOTE(review): none of these scalars are proto3 ``optional``, so "unset"
    cannot be distinguished from the zero value (e.g. ``temperature=0.0``,
    ``seed=0``, ``max_tokens=0``). Consumers must treat zero as "use default".
    Add ``optional`` in chat.proto if that distinction matters.
    """

    messages: List["ChatMessage"] = betterproto.message_field(1)
    model: str = betterproto.string_field(2)
    max_tokens: int = betterproto.int32_field(3)
    response_format: "ResponseFormat" = betterproto.message_field(4)
    temperature: float = betterproto.double_field(5)
    top_p: float = betterproto.double_field(6)
    # A single string, not a list of stop sequences.
    stop: str = betterproto.string_field(7)
    logprobs: bool = betterproto.bool_field(8)
    top_logprobs: int = betterproto.int32_field(9)
    n: int = betterproto.int32_field(10)
    seed: int = betterproto.int32_field(11)
    stream: bool = betterproto.bool_field(12)
    stream_options: "StreamOptions" = betterproto.message_field(13)
    # map<string, double>; keys are strings on the wire.
    logit_bias: Dict[str, float] = betterproto.map_field(
        14, betterproto.TYPE_STRING, betterproto.TYPE_DOUBLE
    )
    presence_penalty: float = betterproto.double_field(15)
    frequency_penalty: float = betterproto.double_field(16)
    user: str = betterproto.string_field(17)
    tools: List["Tool"] = betterproto.message_field(18)
    tool_choice: str = betterproto.string_field(19)
    parallel_tool_calls: bool = betterproto.bool_field(20)


@dataclass(eq=False, repr=False)
class ResponseFormat(betterproto.Message):
    """Message ``chat.ResponseFormat``: requested output format, with an optional JSON schema."""

    type: str = betterproto.string_field(1)
    # proto3 ``optional`` message: None means no schema was supplied.
    json_schema: Optional["JsonSchema"] = betterproto.message_field(
        2, optional=True, group="_json_schema"
    )


@dataclass(eq=False, repr=False)
class StreamOptions(betterproto.Message):
    """Message ``chat.StreamOptions``: options for streamed completions."""

    include_usage: bool = betterproto.bool_field(1)


@dataclass(eq=False, repr=False)
class Tool(betterproto.Message):
    """Message ``chat.Tool``: a tool definition offered to the model (``type`` + ``function``)."""

    type: str = betterproto.string_field(1)
    function: "Function" = betterproto.message_field(2)


@dataclass(eq=False, repr=False)
class Function(betterproto.Message):
    """Message ``chat.Function``: name, description and parameter schema of a tool function."""

    name: str = betterproto.string_field(1)
    description: str = betterproto.string_field(2)
    # map<string, string>: values are flat strings, so any nested JSON-schema
    # structure must be string-encoded by the producer — NOTE(review): confirm
    # the encoding convention (presumably JSON text per top-level key).
    parameters: Dict[str, str] = betterproto.map_field(
        3, betterproto.TYPE_STRING, betterproto.TYPE_STRING
    )
    strict: bool = betterproto.bool_field(4)


@dataclass(eq=False, repr=False)
class JsonSchema(betterproto.Message):
    """Message ``chat.JsonSchema``: a named schema for structured output."""

    name: str = betterproto.string_field(1)
    strict: bool = betterproto.bool_field(2)
    # map<string, string>: nested schema values must be string-encoded —
    # NOTE(review): confirm the encoding convention with the producer.
    schema: Dict[str, str] = betterproto.map_field(
        3, betterproto.TYPE_STRING, betterproto.TYPE_STRING
    )


@dataclass(eq=False, repr=False)
class ChatCompletionResponse(betterproto.Message):
    """Message ``chat.ChatCompletionResponse``: a non-streamed completion result.

    Carries the model choices and usage plus optional evaluation results,
    correction/timing information, and an optional echo of the request.
    """

    id: str = betterproto.string_field(1)
    response_id: int = betterproto.int64_field(2)
    object: str = betterproto.string_field(3)
    # Integer timestamp; unit/epoch not stated here — NOTE(review): confirm.
    created: int = betterproto.int64_field(4)
    model: str = betterproto.string_field(5)
    system_fingerprint: str = betterproto.string_field(6)
    choices: List["ChatCompletionChoice"] = betterproto.message_field(7)
    usage: "ChatCompletionUsage" = betterproto.message_field(8)
    request_id: str = betterproto.string_field(9)
    # proto3 ``optional``: None when no evaluation result is attached.
    evaluate_response: Optional["EvaluateResponse"] = betterproto.message_field(
        10, optional=True, group="_evaluate_response"
    )
    correction_applied: bool = betterproto.bool_field(11)
    # int64 durations/timestamps; units not stated here — NOTE(review): confirm.
    first_token_time: int = betterproto.int64_field(12)
    response_time: int = betterproto.int64_field(13)
    # proto3 ``optional`` echo of the originating request (None when absent).
    chat_completion_request: Optional["ChatCompletionRequest"] = (
        betterproto.message_field(14, optional=True, group="_chat_completion_request")
    )
    service_tier: Optional[str] = betterproto.string_field(
        15, optional=True, group="_service_tier"
    )
    # 32-bit ``float`` on the wire, unlike the ``double`` fields elsewhere.
    input_safety_score: float = betterproto.float_field(16)
    fallback_reason: str = betterproto.string_field(17)


@dataclass(eq=False, repr=False)
class ChatCompletionChoice(betterproto.Message):
    """Message ``chat.ChatCompletionChoice``: one choice within a ChatCompletionResponse."""

    index: int = betterproto.int32_field(1)
    message: "ChatCompletionMessage" = betterproto.message_field(2)
    logprobs: "LogProbs" = betterproto.message_field(3)
    finish_reason: str = betterproto.string_field(4)
    is_correction: bool = betterproto.bool_field(5)


@dataclass(eq=False, repr=False)
class ChatMessage(betterproto.Message):
    """Message ``chat.ChatMessage``: one input message in a conversation.

    ``content``, ``name`` and ``tool_call_id`` are proto3 ``optional``. Each
    reads as None when unset, which distinguishes "absent" from "".
    """

    role: str = betterproto.string_field(1)
    content: Optional[str] = betterproto.string_field(
        2, optional=True, group="_content"
    )
    name: Optional[str] = betterproto.string_field(3, optional=True, group="_name")
    tool_calls: List["ToolCall"] = betterproto.message_field(4)
    tool_call_id: Optional[str] = betterproto.string_field(
        5, optional=True, group="_tool_call_id"
    )


@dataclass(eq=False, repr=False)
class FunctionCall(betterproto.Message):
    """Message ``chat.FunctionCall``: a function invocation with its arguments as a string.

    Field 1 is ``arguments`` and field 2 is ``name``. This is the reverse of
    ``ToolCallFunction``; keep that in mind when copying between the two.
    """

    arguments: str = betterproto.string_field(1)
    name: str = betterproto.string_field(2)


@dataclass(eq=False, repr=False)
class ChatCompletionMessage(betterproto.Message):
    """Message ``chat.ChatCompletionMessage``: the message returned in a completion choice."""

    role: str = betterproto.string_field(1)
    # Plain (non-optional) string: empty and unset are indistinguishable here,
    # unlike ChatMessage.content.
    content: str = betterproto.string_field(2)
    function_call: Optional["FunctionCall"] = betterproto.message_field(
        3, optional=True, group="_function_call"
    )
    tool_calls: List["ToolCall"] = betterproto.message_field(4)
    refusal: str = betterproto.string_field(5)


@dataclass(eq=False, repr=False)
class ChatCompletionUsage(betterproto.Message):
    """Message ``chat.ChatCompletionUsage``: token counts and cost for one call."""

    prompt_tokens: int = betterproto.int32_field(1)
    completion_tokens: int = betterproto.int32_field(2)
    total_tokens: int = betterproto.int32_field(3)
    # Integer cost; the unit/scale is not stated here — NOTE(review): confirm in chat.proto.
    cost: int = betterproto.int64_field(4)
    model: str = betterproto.string_field(5)


@dataclass(eq=False, repr=False)
class ChatCompletionChunk(betterproto.Message):
    """Message ``chat.ChatCompletionChunk``: one streamed completion fragment.

    This is the item type of the ``StreamCorrection`` server-streaming RPC.
    It parallels ChatCompletionResponse, but its field numbers differ from
    that message's.
    """

    id: str = betterproto.string_field(1)
    choices: List["ChunkChoice"] = betterproto.message_field(2)
    created: int = betterproto.int64_field(3)
    model: str = betterproto.string_field(4)
    object: str = betterproto.string_field(5)
    system_fingerprint: str = betterproto.string_field(6)
    usage: "ChatCompletionUsage" = betterproto.message_field(7)
    # proto3 ``optional``: None when no evaluation result accompanies this chunk.
    evaluate_response: Optional["EvaluateResponse"] = betterproto.message_field(
        8, optional=True, group="_evaluate_response"
    )
    correction_applied: bool = betterproto.bool_field(9)
    service_tier: Optional[str] = betterproto.string_field(
        10, optional=True, group="_service_tier"
    )
    # 32-bit ``float`` on the wire.
    input_safety_score: float = betterproto.float_field(11)
    request_id: str = betterproto.string_field(12)
    fallback_reason: str = betterproto.string_field(13)


@dataclass(eq=False, repr=False)
class ChunkChoice(betterproto.Message):
    """Message ``chat.ChunkChoice``: one choice within a streamed ChatCompletionChunk."""

    delta: "ChoiceDelta" = betterproto.message_field(1)
    finish_reason: str = betterproto.string_field(2)
    index: int = betterproto.int32_field(3)
    logprobs: "LogProbs" = betterproto.message_field(4)
    is_correction: bool = betterproto.bool_field(5)


@dataclass(eq=False, repr=False)
class ChoiceDelta(betterproto.Message):
    """Message ``chat.ChoiceDelta``: the incremental message content of a streamed choice."""

    content: str = betterproto.string_field(1)
    role: str = betterproto.string_field(2)
    tool_calls: List["ToolCall"] = betterproto.message_field(3)
    function_call: "FunctionCall" = betterproto.message_field(4)
    # Field number 5 is unused here, presumably removed/reserved in chat.proto;
    # never reassign it.
    refusal: str = betterproto.string_field(6)


@dataclass(eq=False, repr=False)
class ToolCall(betterproto.Message):
    """Message ``chat.ToolCall``: a tool invocation requested by the model."""

    id: str = betterproto.string_field(1)
    type: str = betterproto.string_field(2)
    function: "ToolCallFunction" = betterproto.message_field(3)
    # proto3 ``optional``: None means unset, so index 0 is distinguishable from absent.
    index: Optional[int] = betterproto.int64_field(4, optional=True, group="_index")


@dataclass(eq=False, repr=False)
class ToolCallFunction(betterproto.Message):
    """Message ``chat.ToolCallFunction``: function name and argument string of a ToolCall.

    Field 1 is ``name`` and field 2 is ``arguments``. This is the reverse of
    ``FunctionCall``.
    """

    name: str = betterproto.string_field(1)
    arguments: str = betterproto.string_field(2)


@dataclass(eq=False, repr=False)
class LogProbs(betterproto.Message):
    """Message ``chat.LogProbs``: per-token log-probabilities for content and refusal text."""

    content: List["LogPropsContent"] = betterproto.message_field(1)
    refusal: List["LogPropsContent"] = betterproto.message_field(2)


@dataclass(eq=False, repr=False)
class LogPropsContent(betterproto.Message):
    """Message ``chat.LogPropsContent``: log-probability data for a single token.

    The "LogProps" spelling comes from chat.proto. Renaming it here would
    break importers, so any rename belongs in the proto.
    """

    token: str = betterproto.string_field(1)
    logprob: float = betterproto.double_field(2)
    # Repeated int32 (packed on the wire).
    bytes: List[int] = betterproto.int32_field(3)
    top_logprobs: List["TopLogProbs"] = betterproto.message_field(4)


@dataclass(eq=False, repr=False)
class TopLogProbs(betterproto.Message):
    """Message ``chat.TopLogProbs``: one alternative token and its log-probability."""

    token: str = betterproto.string_field(1)
    logprob: float = betterproto.double_field(2)
    # NOTE(review): repeated *double*, whereas LogPropsContent.bytes is repeated
    # int32 — likely a schema slip. Fixing it is a wire-incompatible change
    # and must be done in chat.proto.
    bytes: List[float] = betterproto.double_field(3)


@dataclass(eq=False, repr=False)
class ChatStorageRequest(betterproto.Message):
    """Message ``chat.ChatStorageRequest``: bundles a request, its response, eval request and timings."""

    chat_completion_request: "ChatCompletionRequest" = betterproto.message_field(1)
    chat_completion_response: "ChatCompletionResponse" = betterproto.message_field(2)
    evaluate_request: "EvaluateRequest" = betterproto.message_field(3)
    # RequestTimingMetric from the sibling ``metric`` proto package.
    timing_metrics: "_metric__.RequestTimingMetric" = betterproto.message_field(4)


@dataclass(eq=False, repr=False)
class EvaluateRequest(betterproto.Message):
    """Message ``chat.EvaluateRequest``: input to the ``chat.EvaluationService`` RPCs.

    It is the request type of ``Evaluate``, ``StreamCorrection`` and
    ``EvaluateAsync``. ``evaluation_content_type`` presumably selects which of
    ``message_content`` / ``text_content`` / ``partial`` is evaluated —
    NOTE(review): confirm.
    """

    id: int = betterproto.int64_field(1)
    date_created: int = betterproto.int64_field(2)
    application_id: int = betterproto.int64_field(3)
    application_ref_name: str = betterproto.string_field(4)
    session_id: str = betterproto.string_field(5)
    reference_id: str = betterproto.string_field(6)
    action_type: str = betterproto.string_field(7)
    # Enum field; defaults to EvaluationContentType.MESSAGE (value 0) when unset.
    evaluation_content_type: "EvaluationContentType" = betterproto.enum_field(8)
    eval_results_set: Optional["EvaluateResponse"] = betterproto.message_field(
        9, optional=True, group="_eval_results_set"
    )
    company_id: int = betterproto.int64_field(10)
    evaluation_context: str = betterproto.string_field(11)
    text_content: str = betterproto.string_field(12)
    message_content: List["ChatMessage"] = betterproto.message_field(13)
    request_id: str = betterproto.string_field(14)
    sentinel_id: int = betterproto.int64_field(15)
    fault_description: str = betterproto.string_field(16)
    chat_completion_request: Optional["ChatCompletionRequest"] = (
        betterproto.message_field(17, optional=True, group="_chat_completion_request")
    )
    chat_completion_response: Optional["ChatCompletionResponse"] = (
        betterproto.message_field(18, optional=True, group="_chat_completion_response")
    )
    # 32-bit float; unit not stated here (seconds?) — NOTE(review): confirm.
    timeout: float = betterproto.float_field(19)
    partial: str = betterproto.string_field(20)
    stream: bool = betterproto.bool_field(21)
    apply_corrections: bool = betterproto.bool_field(22)
    timing_metrics: "_metric__.RequestTimingMetric" = betterproto.message_field(23)
    tools: List["Tool"] = betterproto.message_field(24)


@dataclass(eq=False, repr=False)
class EvaluateResult(betterproto.Message):
    """Message ``chat.EvaluateResult``: the outcome of one sentinel's evaluation."""

    id: int = betterproto.int64_field(1)
    # Free-form string status; the allowed values are not enumerated in this file.
    status: str = betterproto.string_field(2)
    description: str = betterproto.string_field(3)
    # Range not stated here (presumably 0..1) — NOTE(review): confirm.
    confidence: float = betterproto.double_field(4)
    meta: Dict[str, str] = betterproto.map_field(
        5, betterproto.TYPE_STRING, betterproto.TYPE_STRING
    )
    sentinel_id: int = betterproto.int64_field(6)
    # int64 duration; unit not stated here — NOTE(review): confirm.
    eval_time: int = betterproto.int64_field(7)
    # Sentinel message from the sibling ``sentinel`` proto package.
    sentinel: "_sentinel__.Sentinel" = betterproto.message_field(8)
    date_created: int = betterproto.int64_field(9)
    usage: "ChatCompletionUsage" = betterproto.message_field(10)


@dataclass(eq=False, repr=False)
class EvaluateResponse(betterproto.Message):
    """Message ``chat.EvaluateResponse``: aggregated evaluation results.

    It is returned by the ``Evaluate`` and ``Backtest`` RPCs.
    """

    application_id: int = betterproto.int64_field(1)
    session_id: str = betterproto.string_field(2)
    # NOTE(review): int64 here, while every other ``request_id`` in chat.proto
    # is a string — probably a schema inconsistency. Changing the type is
    # wire-incompatible, so fix it in chat.proto.
    request_id: int = betterproto.int64_field(3)
    evaluation_results: List["EvaluateResult"] = betterproto.message_field(4)
    evaluation_request_id: str = betterproto.string_field(5)
    evaluation_request: Optional["EvaluateRequest"] = betterproto.message_field(
        6, optional=True, group="_evaluation_request"
    )
    evaluate_summary: str = betterproto.string_field(7)
    evaluation_time: int = betterproto.int64_field(8)


@dataclass(eq=False, repr=False)
class Turn(betterproto.Message):
    """Message ``chat.Turn``: request, response, eval request and application for one exchange."""

    request: "ChatCompletionRequest" = betterproto.message_field(1)
    response: "ChatCompletionResponse" = betterproto.message_field(2)
    eval_request: "EvaluateRequest" = betterproto.message_field(3)
    # Application message from the sibling ``application`` proto package.
    application: "_application__.Application" = betterproto.message_field(4)


@dataclass(eq=False, repr=False)
class SessionMessage(betterproto.Message):
    """Message ``chat.SessionMessage``: a ChatMessage tagged with session/application identifiers."""

    id: int = betterproto.int64_field(1)
    application_id: int = betterproto.int64_field(2)
    company_id: int = betterproto.int64_field(3)
    application_action_id: int = betterproto.int64_field(4)
    session_id: str = betterproto.string_field(5)
    request_id: str = betterproto.string_field(6)
    role: str = betterproto.string_field(7)
    message: "ChatMessage" = betterproto.message_field(8)
    instructions: str = betterproto.string_field(9)
    date_created: int = betterproto.int64_field(10)


@dataclass(eq=False, repr=False)
class BacktestEvaluateRequest(betterproto.Message):
    """Message ``chat.BacktestEvaluateRequest``: input to the ``Backtest`` RPC.

    Carries a recorded request/response pair and the sentinels to evaluate it with.
    """

    id: int = betterproto.int64_field(1)
    date_created: int = betterproto.int64_field(2)
    run_id: int = betterproto.int64_field(3)
    company_id: int = betterproto.int64_field(4)
    application_id: int = betterproto.int64_field(5)
    backtest_application_id: int = betterproto.int64_field(6)
    session_id: str = betterproto.string_field(7)
    request_id: str = betterproto.string_field(8)
    inference_request_id: str = betterproto.string_field(9)
    action_type: str = betterproto.string_field(10)
    request_messages: List["ChatMessage"] = betterproto.message_field(11)
    response_message: "ChatCompletionMessage" = betterproto.message_field(12)
    sentinels: List["_sentinel__.Sentinel"] = betterproto.message_field(13)
    tools: List["Tool"] = betterproto.message_field(14)


class EvaluationServiceStub(betterproto.ServiceStub):
    """Async grpclib client stub for the ``chat.EvaluationService`` gRPC service.

    Each coroutine targets exactly one RPC route. The keyword-only
    ``timeout``, ``deadline`` and ``metadata`` arguments are per-call
    overrides. They are handed unchanged to betterproto's ``_unary_unary`` /
    ``_unary_stream`` helpers.
    """

    async def evaluate(
        self,
        evaluate_request: "EvaluateRequest",
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional["MetadataLike"] = None
    ) -> "EvaluateResponse":
        """Invoke the unary ``Evaluate`` RPC.

        Returns:
            The decoded ``EvaluateResponse`` sent by the server.
        """
        call_options = dict(timeout=timeout, deadline=deadline, metadata=metadata)
        reply = await self._unary_unary(
            "/chat.EvaluationService/Evaluate",
            evaluate_request,
            EvaluateResponse,
            **call_options,
        )
        return reply

    async def stream_correction(
        self,
        evaluate_request: "EvaluateRequest",
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional["MetadataLike"] = None
    ) -> AsyncIterator["ChatCompletionChunk"]:
        """Invoke the server-streaming ``StreamCorrection`` RPC.

        Yields:
            Each ``ChatCompletionChunk`` in the order the server sends it.
        """
        call_options = dict(timeout=timeout, deadline=deadline, metadata=metadata)
        chunk_stream = self._unary_stream(
            "/chat.EvaluationService/StreamCorrection",
            evaluate_request,
            ChatCompletionChunk,
            **call_options,
        )
        async for chunk in chunk_stream:
            yield chunk

    async def evaluate_async(
        self,
        evaluate_request: "EvaluateRequest",
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional["MetadataLike"] = None
    ) -> "betterproto_lib_google_protobuf.Empty":
        """Invoke the unary ``EvaluateAsync`` RPC.

        Returns:
            ``google.protobuf.Empty``; the server sends back no payload.
        """
        call_options = dict(timeout=timeout, deadline=deadline, metadata=metadata)
        reply = await self._unary_unary(
            "/chat.EvaluationService/EvaluateAsync",
            evaluate_request,
            betterproto_lib_google_protobuf.Empty,
            **call_options,
        )
        return reply

    async def backtest(
        self,
        backtest_evaluate_request: "BacktestEvaluateRequest",
        *,
        timeout: Optional[float] = None,
        deadline: Optional["Deadline"] = None,
        metadata: Optional["MetadataLike"] = None
    ) -> "EvaluateResponse":
        """Invoke the unary ``Backtest`` RPC.

        Returns:
            The decoded ``EvaluateResponse`` sent by the server.
        """
        call_options = dict(timeout=timeout, deadline=deadline, metadata=metadata)
        reply = await self._unary_unary(
            "/chat.EvaluationService/Backtest",
            backtest_evaluate_request,
            EvaluateResponse,
            **call_options,
        )
        return reply


class EvaluationServiceBase(ServiceBase):
    """Server-side base class for the ``chat.EvaluationService`` gRPC service.

    Subclass this and override the public coroutines (``evaluate``,
    ``stream_correction``, ``evaluate_async``, ``backtest``). Any method left
    un-overridden raises ``GRPCError(UNIMPLEMENTED)``.

    For each RPC, a private ``__rpc_*`` adapter receives the grpclib stream,
    reads the single request message, calls the public method, and sends the
    result. ``__mapping__`` maps each route to its adapter for grpclib.
    """

    async def evaluate(self, evaluate_request: "EvaluateRequest") -> "EvaluateResponse":
        """Handle the unary ``Evaluate`` RPC; override in a subclass."""
        raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)

    async def stream_correction(
        self, evaluate_request: "EvaluateRequest"
    ) -> AsyncIterator["ChatCompletionChunk"]:
        """Handle the server-streaming ``StreamCorrection`` RPC; override in a subclass.

        Overrides should be async generators that yield ``ChatCompletionChunk``s.
        """
        raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)
        # Unreachable, but the presence of ``yield`` makes this default an
        # async generator, matching the shape expected of a real override.
        yield ChatCompletionChunk()

    async def evaluate_async(
        self, evaluate_request: "EvaluateRequest"
    ) -> "betterproto_lib_google_protobuf.Empty":
        """Handle the unary ``EvaluateAsync`` RPC (replies with Empty); override in a subclass."""
        raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)

    async def backtest(
        self, backtest_evaluate_request: "BacktestEvaluateRequest"
    ) -> "EvaluateResponse":
        """Handle the unary ``Backtest`` RPC; override in a subclass."""
        raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)

    async def __rpc_evaluate(
        self, stream: "grpclib.server.Stream[EvaluateRequest, EvaluateResponse]"
    ) -> None:
        """grpclib adapter for ``Evaluate``: receive one request, send one response."""
        request = await stream.recv_message()
        response = await self.evaluate(request)
        await stream.send_message(response)

    async def __rpc_stream_correction(
        self, stream: "grpclib.server.Stream[EvaluateRequest, ChatCompletionChunk]"
    ) -> None:
        """grpclib adapter for ``StreamCorrection``: receive one request, stream the replies.

        The ServiceBase helper drives ``self.stream_correction(request)`` and
        sends each yielded chunk on ``stream``.
        """
        request = await stream.recv_message()
        await self._call_rpc_handler_server_stream(
            self.stream_correction,
            stream,
            request,
        )

    async def __rpc_evaluate_async(
        self,
        stream: "grpclib.server.Stream[EvaluateRequest, betterproto_lib_google_protobuf.Empty]",
    ) -> None:
        """grpclib adapter for ``EvaluateAsync``: receive one request, send the Empty reply."""
        request = await stream.recv_message()
        response = await self.evaluate_async(request)
        await stream.send_message(response)

    async def __rpc_backtest(
        self, stream: "grpclib.server.Stream[BacktestEvaluateRequest, EvaluateResponse]"
    ) -> None:
        """grpclib adapter for ``Backtest``: receive one request, send one response."""
        request = await stream.recv_message()
        response = await self.backtest(request)
        await stream.send_message(response)

    def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
        """Return the route -> Handler table that grpclib uses to dispatch requests.

        Each Handler pairs an adapter with its cardinality and its request and
        reply message classes. The ``self.__rpc_*`` references are name-mangled
        to ``_EvaluationServiceBase__rpc_*``, so they must stay inside this class.
        """
        return {
            "/chat.EvaluationService/Evaluate": grpclib.const.Handler(
                self.__rpc_evaluate,
                grpclib.const.Cardinality.UNARY_UNARY,
                EvaluateRequest,
                EvaluateResponse,
            ),
            "/chat.EvaluationService/StreamCorrection": grpclib.const.Handler(
                self.__rpc_stream_correction,
                grpclib.const.Cardinality.UNARY_STREAM,
                EvaluateRequest,
                ChatCompletionChunk,
            ),
            "/chat.EvaluationService/EvaluateAsync": grpclib.const.Handler(
                self.__rpc_evaluate_async,
                grpclib.const.Cardinality.UNARY_UNARY,
                EvaluateRequest,
                betterproto_lib_google_protobuf.Empty,
            ),
            "/chat.EvaluationService/Backtest": grpclib.const.Handler(
                self.__rpc_backtest,
                grpclib.const.Cardinality.UNARY_UNARY,
                BacktestEvaluateRequest,
                EvaluateResponse,
            ),
        }
