# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/jerx.prompt.ipynb.

# %% auto 0
__all__ = ['log', 'DEFAULT_SYSTEM_PROMPT_TEMPLATE', 'DEFAULT_RELATION_SET_PROMPT_TEMPLATE',
           'DEFAULT_SIMPLE_SYSTEM_PROMPT_TEMPLATE', 'JERXChatFormatter', 'JERXSimpleChatFormatter']

# %% ../../nbs/jerx.prompt.ipynb 3
import random
from dataclasses import dataclass
from typing import Iterable, Callable, Any, Generator

from ..logging import get_logger

# Module-level logger, obtained from the package's own `get_logger` helper
# and keyed by this module's dotted name.
log = get_logger(__name__)

# %% ../../nbs/jerx.prompt.ipynb 6
DEFAULT_SYSTEM_PROMPT_TEMPLATE = """You are a helpful assistant that extracts up to {max_triplets} entity-relation-entity triplets from given text. Use '{delimiter}' as delimiter and provide one triplet per line. The entities in a triplet must be different.
{relation_set_prompt}
""".strip()

DEFAULT_RELATION_SET_PROMPT_TEMPLATE = """Here are the list of relations that you can use:
{relation_set}
""".strip()

@dataclass
class JERXChatFormatter:
    system_prompt_template: str = DEFAULT_SYSTEM_PROMPT_TEMPLATE
    relation_set_prompt_template: str = DEFAULT_RELATION_SET_PROMPT_TEMPLATE
    relation_set: set|None = None
    max_triplets_margin: int = 0
    delimiter: str = " | "

    def __post_init__(self):
        if self.relation_set:
            self.relation_set = sorted(self.relation_set)

    def format(self, batch: list[dict], max_triplets: int | None = None):
        if "triplets" not in batch[0]:
            assert len(batch) == 1, "Only one example is allowed when 'triplets' is not present"
        if max_triplets is None:
            if "triplets" in batch[0]:
                max_triplets = max([len(example['triplets']) for example in batch]) + self.max_triplets_margin
            else:
                max_triplets = random.randint(15, 20)
        messages = [
            self.make_system_message(max_triplets),
            *[message for example in batch for message in self.make_messages(example)],
        ]
        return {'messages': messages}

    def make_system_message(self, max_triplets: int) -> str:
        rsp = self.relation_set_prompt_template.format(relation_set=','.join(self.relation_set)) if self.relation_set else ""
        content = self.system_prompt_template.format(max_triplets=max_triplets, delimiter=self.delimiter, relation_set_prompt=rsp)
        return {"role": "system", "content": content}

    def make_messages(self, example: dict) -> Generator[dict, None, None]:
        yield {"role": "user", "content": example["text"]}
        if "triplets" in example:
            yield {"role": "assistant", "content": self._format_triplets(example["triplets"])}

    def _format_triplets(self, triplets: Iterable[str]) -> str:
        return '\n'.join(triplets)

# %% ../../nbs/jerx.prompt.ipynb 9
# Default system prompt for `JERXSimpleChatFormatter`; the only placeholder is
# {delimiter}, the separator the model is told to use inside a triplet.
DEFAULT_SIMPLE_SYSTEM_PROMPT_TEMPLATE = """You are a helpful assistant that extracts entityA-relation-entityB triplets from given text. Use '{delimiter}' as delimiter and provide one triplet per line.
""".strip()

@dataclass
class JERXSimpleChatFormatter:
    """Format examples as a chat conversation using a fixed system prompt.

    Each example is a dict with a ``"text"`` key and, optionally, a
    ``"triplets"`` key holding an iterable of already-formatted triplet
    strings. Unlike a budgeted prompt, the system message here mentions only
    the delimiter.
    """

    # Template for the system message; must accept a `delimiter` placeholder.
    system_prompt_template: str = DEFAULT_SIMPLE_SYSTEM_PROMPT_TEMPLATE
    # Delimiter advertised in the system prompt. It is not applied to the
    # triplets here -- presumably they are already joined with it upstream
    # (TODO confirm against the data pipeline).
    delimiter: str = " | "

    def format(self, batch: list[dict]) -> dict:
        """Turn a batch of examples into a ``{'messages': [...]}`` conversation.

        Args:
            batch: Non-empty list of example dicts. If the first example has
                no ``"triplets"`` key the batch must contain exactly one example.

        Returns:
            A dict with a single ``"messages"`` key: the system message
            followed by the per-example user/assistant messages.

        Raises:
            IndexError: If ``batch`` is empty.
            AssertionError: If the first example is unlabelled and the batch
                holds more than one example (only while assertions are enabled).
            KeyError: If an example lacks the ``"text"`` key.
        """
        if "triplets" not in batch[0]:
            assert len(batch) == 1, "Only one example is allowed when 'triplets' is not present"
        messages = [self.make_system_message()]
        for example in batch:
            messages.extend(self.make_messages(example))
        return {'messages': messages}

    def make_system_message(self) -> dict:
        """Return the system message dict with the delimiter filled into the template."""
        content = self.system_prompt_template.format(delimiter=self.delimiter)
        return {"role": "system", "content": content}

    def make_messages(self, example: dict) -> Generator[dict, None, None]:
        """Yield the user message and, if the example is labelled, the assistant message."""
        yield {"role": "user", "content": example["text"]}
        if "triplets" in example:
            yield {"role": "assistant", "content": self._format_triplets(example["triplets"])}

    def _format_triplets(self, triplets: Iterable[str]) -> str:
        """Join triplet strings one per line."""
        return '\n'.join(triplets)
