import openai
import time
import uuid
from AIBridge.exceptions import OpenAIException, AIBridgeException
from AIBridge.prompts.prompt_completion import Completion
from AIBridge.ai_services.ai_abstraction import AIInterface
from AIBridge.output_validation.active_validator import ActiveValidator
import json
from AIBridge.constant.common import get_function_from_json, parse_fromat, parse_api_key


class OpenAIService(AIInterface):
    """
    OpenAI chat-completion service.

    Every method is a classmethod (or staticmethod), so the service is used
    without instantiation.  ``generate`` is the public entry point; it either
    queues the request or calls ``get_response`` directly.
    """

    @classmethod
    def generate(
        cls,
        prompts: list[str] = [],
        prompt_ids: list[str] = [],
        prompt_data: list[dict] = [],
        variables: list[dict] = [],
        output_format: list[str] = [],
        format_strcture: list[str] = [],
        model="gpt-3.5-turbo",
        variation_count: int = 1,
        max_tokens: int = 3500,
        temperature: float = 0.5,
        message_queue=False,
        api_key=None,
        output_format_parse=True,
    ):
        """
        Build the final prompts and either enqueue them or run them now.

        Exactly one of ``prompts`` / ``prompt_ids`` must be supplied.

        Args:
            prompts: Raw prompt strings.
            prompt_ids: Ids resolved through ``Completion.create_prompt_from_id``.
            prompt_data: Passed to ``Completion`` as ``prompt_data_list``.
            variables: Passed to ``Completion`` as ``variables_list``.
            output_format: Optional per-prompt format names; when given, must
                have the same length as the prompts.
            format_strcture: Optional per-prompt format structures; when
                given, must have the same length as the prompts.
            model: Model name sent to OpenAI.
            variation_count: Number of completions requested per prompt.
            max_tokens: Completion token limit.
            temperature: Sampling temperature.
            message_queue: When truthy, the request is serialised and pushed
                through ``MessageQ.mq_enque`` instead of being executed.
            api_key: OpenAI API key; ``get_response`` falls back to
                ``parse_api_key("open_ai")`` when it is falsy.
            output_format_parse: When truthy, every prompt is passed through
                ``parse_fromat`` before being sent.

        Returns:
            ``{"response_id": <uuid str>}`` when ``message_queue`` is truthy,
            otherwise the dictionary returned by ``get_response``.

        Raises:
            OpenAIException: On invalid arguments or when any
                ``AIBridgeException`` is raised while processing.
        """
        try:
            if prompts and prompt_ids:
                raise OpenAIException(
                    "please provide either prompts or prompts ids at atime"
                )
            if not prompts and not prompt_ids:
                raise OpenAIException(
                    "Either provide prompts or prompts ids to genrate the data"
                )
            # Past the two guards above exactly one of prompts / prompt_ids
            # is truthy, so a single if/elif/else chain covers every case.
            if prompt_ids:
                prompts_list = Completion.create_prompt_from_id(
                    prompt_ids=prompt_ids,
                    prompt_data_list=prompt_data,
                    variables_list=variables,
                )
            elif prompt_data or variables:
                prompts_list = Completion.create_prompt(
                    prompt_list=prompts,
                    prompt_data_list=prompt_data,
                    variables_list=variables,
                )
            else:
                prompts_list = prompts
            if output_format and len(output_format) != len(prompts_list):
                raise AIBridgeException(
                    "length of output_format must be equal to length of the prompts"
                )
            if format_strcture and len(format_strcture) != len(prompts_list):
                raise AIBridgeException(
                    "length of format_strcture must be equal to length of the prompts"
                )
            updated_prompts = []
            if output_format_parse:
                # enumerate() gives the real position of each prompt;
                # list.index() would return the first match and therefore
                # pick the wrong format when the same prompt appears twice.
                for position, _prompt in enumerate(prompts_list):
                    prompt_format = output_format[position] if output_format else None
                    prompt_structure = (
                        format_strcture[position] if format_strcture else None
                    )
                    updated_prompts.append(
                        parse_fromat(
                            _prompt,
                            format=prompt_format,
                            format_structure=prompt_structure,
                        )
                    )
            if not updated_prompts:
                updated_prompts = prompts_list
            if message_queue:
                response_id = str(uuid.uuid4())
                message_data = {
                    "id": response_id,
                    "prompts": json.dumps(updated_prompts),
                    "model": model,
                    "variation_count": variation_count,
                    "max_tokens": max_tokens,
                    "temperature": temperature,
                    "ai_service": "open_ai",
                    "output_format": json.dumps(output_format),
                    "format_structure": json.dumps(format_strcture),
                    "api_key": api_key,
                }
                message = {"data": json.dumps(message_data)}
                # Imported here rather than at module level, as in the
                # original code (presumably to avoid an import cycle or to
                # keep the queue dependency optional - TODO confirm).
                from AIBridge.queue_integration.message_queue import MessageQ

                MessageQ.mq_enque(message=message)
                return {"response_id": response_id}
            return cls.get_response(
                updated_prompts,
                model,
                variation_count,
                max_tokens,
                temperature,
                output_format,
                format_strcture,
                api_key=api_key,
            )
        except AIBridgeException as e:
            raise OpenAIException(f"Error in generating AI data {e}") from e

    @classmethod
    def execute_text_prompt(
        cls, api_key, model, messages, n, max_tokens=3500, temperature=0.5
    ):
        """
        Run a plain chat completion and return the raw OpenAI response.

        Sets the module-level ``openai.api_key`` to ``api_key`` before the
        call.
        """
        openai.api_key = api_key
        return openai.ChatCompletion.create(
            model=model,
            messages=messages,
            n=n,
            max_tokens=max_tokens,
            temperature=temperature,
        )

    @classmethod
    def execute_prompt_function_calling(
        cls,
        api_key,
        model,
        messages,
        n,
        functions_call,
        max_tokens=3500,
        temperature=0.5,
    ):
        """
        Run a chat completion with function calling enabled
        (``function_call="auto"``) and return the raw OpenAI response.

        Sets the module-level ``openai.api_key`` to ``api_key`` before the
        call.  ``max_tokens`` and ``temperature`` are forwarded to the API;
        previously they were accepted but silently dropped.
        """
        openai.api_key = api_key
        return openai.ChatCompletion.create(
            model=model,
            messages=messages,
            n=n,
            functions=functions_call,
            function_call="auto",
            max_tokens=max_tokens,
            temperature=temperature,
        )

    @staticmethod
    def _message_content(message):
        """
        Return the text payload of one chat-completion message.

        Prefers ``content``; falls back to the ``function_call`` arguments
        when ``content`` is empty.  Returns ``""`` when neither is present
        instead of raising ``KeyError``.
        """
        content = message.get("content")
        if content:
            return content
        function_call = message.get("function_call")
        if function_call:
            return function_call["arguments"]
        return ""

    @classmethod
    def get_response(
        cls,
        prompts,
        model="gpt-3.5-turbo",
        variation_count=1,
        max_tokens=3500,
        temperature=0.5,
        output_format=[],
        format_structure=[],
        api_key=None,
    ):
        """
        Send every prompt to OpenAI and collect the (validated) answers.

        Prompts are sent in order within one growing conversation: each
        request contains all earlier prompts plus the first choice of each
        earlier reply.

        Args:
            prompts: Prompt strings to send.
            model: Model name sent to OpenAI.
            variation_count: Number of choices requested per prompt (``n``).
            max_tokens: Completion token limit forwarded to the API.
            temperature: Sampling temperature forwarded to the API.
            output_format: Per-prompt format names, as a list or a JSON
                string encoding one.  ``"json"`` selects the function-calling
                path; anything other than ``"string"`` is run through the
                validator returned by ``ActiveValidator.get_active_validator``.
            format_structure: Per-prompt schemas, as a list or a JSON string
                encoding one.  Each entry used on the ``"json"`` path is
                itself parsed with ``json.loads``.
            api_key: OpenAI API key; falls back to
                ``parse_api_key("open_ai")`` when falsy.

        Returns:
            ``{"items": {"response": [...], "token_used": int,
            "created_at": float, "ai_service": "open_ai"}}`` where
            ``response[i]["data"][j]`` is choice ``i`` for prompt ``j``.

        Raises:
            OpenAIException: When no prompts are given, when a ``"json"``
                prompt has no matching schema, or when an
                ``AIBridgeException`` is raised while processing.
        """
        try:
            if output_format and isinstance(output_format, str):
                output_format = json.loads(output_format)
            if format_structure and isinstance(format_structure, str):
                format_structure = json.loads(format_structure)
            if not prompts:
                raise OpenAIException("No prompts provided")
            api_key = api_key if api_key else parse_api_key("open_ai")
            openai.api_key = api_key
            message_data = []
            model_output = []
            token_used = 0
            _formatter = "string"
            # enumerate() instead of prompts.index(prompt): index() returns
            # the first match, which selects the wrong format/schema when the
            # same prompt text occurs more than once.
            for prompt_position, prompt in enumerate(prompts):
                if output_format:
                    _formatter = output_format[prompt_position]
                schema_text = (
                    format_structure[prompt_position]
                    if format_structure and prompt_position < len(format_structure)
                    else None
                )
                message_data.append({"role": "user", "content": prompt})
                if _formatter != "json":
                    # Forward the caller's limits; these used to be
                    # hard-coded to 3500 / 0.5 regardless of the arguments.
                    response = cls.execute_text_prompt(
                        api_key,
                        model=model,
                        messages=message_data,
                        n=variation_count,
                        max_tokens=max_tokens,
                        temperature=temperature,
                    )
                else:
                    if schema_text is None:
                        raise OpenAIException(
                            "format_structure is required for prompts "
                            "with 'json' output_format"
                        )
                    functions = [get_function_from_json(json.loads(schema_text))]
                    response = cls.execute_prompt_function_calling(
                        api_key=api_key,
                        model=model,
                        messages=message_data,
                        n=variation_count,
                        functions_call=functions,
                        max_tokens=max_tokens,
                        temperature=temperature,
                    )
                choices = response["choices"]
                # Only the first choice is fed back into the conversation.
                first_message = choices[0]["message"]
                message_data.append(
                    {
                        "role": first_message["role"],
                        "content": cls._message_content(first_message),
                    }
                )
                token_used += response["usage"]["total_tokens"]
                for choice_position, choice in enumerate(choices):
                    content = cls._message_content(choice["message"])
                    if output_format and _formatter != "string":
                        _validate_obj = ActiveValidator.get_active_validator(
                            _formatter
                        )
                        try:
                            content = _validate_obj.validate(
                                content, schema=schema_text
                            )
                        except AIBridgeException as e:
                            # Validation failures are reported in-band so the
                            # remaining prompts/choices are still processed.
                            content = json.dumps(
                                {"error": f"{e}", "ai_response": content}
                            )
                    # One bucket per choice position; each bucket collects
                    # that choice's answer for every prompt in order.
                    if choice_position >= len(model_output):
                        model_output.append({"data": [content]})
                    else:
                        model_output[choice_position]["data"].append(content)
            return {
                "items": {
                    "response": model_output,
                    "token_used": token_used,
                    "created_at": time.time(),
                    "ai_service": "open_ai",
                }
            }
        except AIBridgeException as e:
            raise OpenAIException(f"{e}") from e
