# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/ml.kg.cons.ipynb.

# %% auto 0
# Public names exported by `from <module> import *`. Like the rest of this file it is
# autogenerated from the source notebook, so edit the notebook rather than this list.
__all__ = ['Entity', 'Relation', 'Triplet', 'DEFAULT_RELATION_SET_PROMPT_TEMPLATE', 'DEFAULT_FEW_SHOT_EXAMPLES_PROMPT_TEMPLATE',
           'DEFAULT_SYSTEM_PROMPT_TEMPLATE', 'evaluate_joint_er_extraction', 'evaluate_joint_er_extractions',
           'parse_triplet_strings', 'parse_triplets', 'format_few_shot_example', 'format_few_shot_examples',
           'ERXFormatter']

# %% ../../../nbs/ml.kg.cons.ipynb 3
from dataclasses import dataclass
from functools import cache
from typing import TypeAlias, Iterable, List, Set, Tuple, Callable, Any, Dict
import numpy as np
from ..llm.utils import LLAMA2_CHAT_PROMPT_TEMPLATE

# %% ../../../nbs/ml.kg.cons.ipynb 4
# An entity is either a bare string or a (text, type) pair, e.g. ('John', 'PERSON').
Entity: TypeAlias = str | tuple[str, str]
# A relation is a plain string label, e.g. 'works_at'.
Relation: TypeAlias = str
# A (head entity, relation, tail entity) triple. Triplets are put into sets during
# evaluation, so concrete values must be hashable (tuples, not lists).
Triplet: TypeAlias = tuple[Entity, Relation, Entity]

# %% ../../../nbs/ml.kg.cons.ipynb 5
def evaluate_joint_er_extraction(*, reference: Iterable[Triplet], prediction: Iterable[Triplet]):
    """
    Score one set of predicted triplets against one set of reference triplets.

    Triplets are compared by exact equality after conversion to sets, so repeated
    predictions count once and every triplet must be hashable (e.g. a tuple).

    Example triplet: (('John', 'PERSON'), 'works_at', ('Google', 'ORG'))

    Args:
        reference: gold triplets; must not contain duplicates. Any iterable,
            including a one-shot generator, is accepted.
        prediction: predicted triplets.

    Returns:
        dict with keys 'precision', 'recall' and 'f1'. A metric is 0.0 when its
        denominator is zero (e.g. no predictions, or no references).

    Raises:
        AssertionError: if `reference` contains duplicate triplets.
    """
    # Materialize first: `reference` may be a one-shot iterator, and the duplicate
    # check needs both its length and its set form.
    reference = list(reference)
    reference_set = set(reference)
    prediction_set = set(prediction)
    # Explicit raise instead of `assert` so the check also runs under `python -O`.
    # AssertionError is kept so existing callers see the same exception type.
    if len(reference) != len(reference_set):
        raise AssertionError("Duplicates found in references")

    tp = len(reference_set & prediction_set)
    fp = len(prediction_set - reference_set)
    fn = len(reference_set - prediction_set)

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1_score,
    }

def evaluate_joint_er_extractions(*, references: Iterable[Iterable[Triplet]], predictions: Iterable[Iterable[Triplet]]):
    """
    Macro-average per-example precision/recall/F1 over paired references and predictions.

    Each (reference, prediction) pair is scored with `evaluate_joint_er_extraction`,
    and each metric is then averaged across examples with `np.mean`.

    Args:
        references: one iterable of gold triplets per example.
        predictions: one iterable of predicted triplets per example, in the same
            order as `references`.

    Returns:
        dict with keys 'mean_precision', 'mean_recall' and 'mean_f1'.

    Raises:
        ValueError: if `references` and `predictions` differ in length, or if
            there are no examples at all.
    """
    # strict=True: plain zip would silently drop the unmatched tail when the two
    # inputs differ in length, and the averages would cover fewer examples.
    score_dicts = [
        evaluate_joint_er_extraction(reference=reference, prediction=prediction)
        for reference, prediction in zip(references, predictions, strict=True)
    ]
    if not score_dicts:
        # Without this, score_dicts[0] below would fail with an opaque IndexError.
        raise ValueError("At least one (reference, prediction) pair is required")
    return {f'mean_{key}': np.mean([scores[key] for scores in score_dicts]) for key in score_dicts[0]}

# %% ../../../nbs/ml.kg.cons.ipynb 8
def parse_triplet_strings(text: str, delimiter: str=" | ") -> List[str]:
    """
    Return the lines of `text` that look like triplets, in their original order.

    A line is kept when it is non-empty and contains `delimiter` exactly twice,
    i.e. it can be split into head, relation and tail. All other lines, such as
    headings or blank lines, are discarded unchanged.
    """
    def _looks_like_triplet(candidate: str) -> bool:
        return bool(candidate) and candidate.count(delimiter) == 2

    return [candidate for candidate in text.splitlines() if _looks_like_triplet(candidate)]

def parse_triplets(text: str, delimiter: str=" | ") -> List[Triplet]:
    """
    Parse `text` into (head, relation, tail) tuples of strings.

    Only lines accepted by `parse_triplet_strings` (those with exactly two
    `delimiter`s) are used. Each line is split on `delimiter` without any
    whitespace trimming, so every tuple has exactly three string elements.
    """
    parsed = []
    for triplet_line in parse_triplet_strings(text, delimiter=delimiter):
        parts = triplet_line.split(delimiter)
        parsed.append(tuple(parts))
    return parsed

# %% ../../../nbs/ml.kg.cons.ipynb 10
def format_few_shot_example(example, text_prefix="# Text\n", triplets_prefix="# Triplets\n"):
    """
    Render one few-shot example as a text section followed by a triplets section.

    `example['text']` is inserted after `text_prefix`. `example['triplets']` must be
    an iterable of already-formatted triplet strings, which are joined one per line
    after `triplets_prefix`.
    """
    body = example['text']
    triplet_lines = "\n".join(example['triplets'])
    return "{}{}\n{}{}".format(text_prefix, body, triplets_prefix, triplet_lines)

def format_few_shot_examples(examples):
    """
    Render each example with `format_few_shot_example` using its default prefixes,
    and separate consecutive examples with a blank line.
    """
    rendered = map(format_few_shot_example, examples)
    return "\n\n".join(rendered)

# %% ../../../nbs/ml.kg.cons.ipynb 11
# Optional system-prompt section listing the allowed relations. `{relation_set}` is
# filled by ERXFormatter.make_system_prompt with the sorted, ','-joined relation names.
# The trailing blank line separates it from whatever follows in the system prompt.
# NOTE(review): "Here are the list" reads like a grammar slip ("Here is the list").
# Editing it changes the prompt seen at inference, which would then differ from the
# text that models trained via ERXFormatter.format_for_train were trained on.
DEFAULT_RELATION_SET_PROMPT_TEMPLATE = """Here are the list of relations that you can use:
{relation_set}

"""

# Optional system-prompt section wrapping the rendered few-shot examples
# (`{few_shot_examples}`), also ending with a blank line.
DEFAULT_FEW_SHOT_EXAMPLES_PROMPT_TEMPLATE = """Here are a few examples:
{few_shot_examples}

"""

# Top-level system prompt. `{relation_set_prompt}` and `{few_shot_prompt}` are each
# either "" or a rendered section template. `.strip()` is applied to the template
# itself (before formatting), so it only removes the template's own trailing newline.
# NOTE(review): "You are helpful assistant" lacks an article. The last line refers to
# "examples provided above" even when no few-shot section is included. As above,
# changing this text changes the prompt relative to models trained on it.
DEFAULT_SYSTEM_PROMPT_TEMPLATE = """You are helpful assistant that extracts entity-relation-entity triplets from given text.
{relation_set_prompt}{few_shot_prompt}
Use the same format for triplets as in examples provided above. No explanation needed, just output the triplets starting from the next line.
""".strip()


@dataclass
class ERXFormatter:
    """
    Format examples into chat prompts for entity-relation-entity triplet extraction.

    The system prompt is rendered once, in `__post_init__`, from
    `system_prompt_template`. It optionally includes a relation-set section and a
    few-shot section, and the same string is reused for every example.

    Fields (all templates are filled with `str.format`):
        chat_prompt_template: outer chat template with `{system_prompt}` and
            `{user_message}` fields.
        system_prompt_template: template with `{relation_set_prompt}` and
            `{few_shot_prompt}` fields.
        few_shot_examples_prompt_template: template with a `{few_shot_examples}` field.
        few_shot_examples: dicts with 'text' and 'triplets' keys, rendered by
            `format_few_shot_examples`; the section is omitted when None or empty.
        relation_set_prompt_template: template with a `{relation_set}` field.
        relation_set: allowed relation names, inserted sorted and joined with ','
            (no space); the section is omitted when None or empty.

    NOTE: reassigning fields after construction does not refresh `system_prompt`.
    Create a new instance (e.g. with `dataclasses.replace`) instead.
    """

    chat_prompt_template: str = LLAMA2_CHAT_PROMPT_TEMPLATE
    system_prompt_template: str = DEFAULT_SYSTEM_PROMPT_TEMPLATE
    few_shot_examples_prompt_template: str = DEFAULT_FEW_SHOT_EXAMPLES_PROMPT_TEMPLATE
    few_shot_examples: List[Dict] | None = None
    relation_set_prompt_template: str = DEFAULT_RELATION_SET_PROMPT_TEMPLATE
    relation_set: set | None = None

    def __post_init__(self):
        # Render once up front; the format_* methods reuse this cached string.
        self._system_prompt = self.make_system_prompt()

    @property
    def system_prompt(self) -> str:
        """The system prompt rendered at construction time."""
        return self._system_prompt

    def make_system_prompt(self) -> str:
        """Render `system_prompt_template` with the optional relation-set and few-shot sections."""
        relation_section = ""
        if self.relation_set:
            allowed_relations = ','.join(sorted(self.relation_set))
            relation_section = self.relation_set_prompt_template.format(relation_set=allowed_relations)

        few_shot_section = ""
        if self.few_shot_examples:
            rendered_examples = format_few_shot_examples(self.few_shot_examples)
            few_shot_section = self.few_shot_examples_prompt_template.format(few_shot_examples=rendered_examples)

        return self.system_prompt_template.format(
            relation_set_prompt=relation_section,
            few_shot_prompt=few_shot_section,
        )

    def format_for_inference(self, example: Dict):
        """
        Replace `example['text']` with the full chat prompt, in place.

        The user message is the original `example['text']`. The same dict is
        mutated and returned.
        """
        message = example['text']
        wrapped = self.chat_prompt_template.format(
            system_prompt=self.system_prompt,
            user_message=message,
        )
        example['text'] = wrapped
        return example

    def format_for_train(self, example: Dict):
        """
        Build a training sample in place from `example`.

        The result is the inference prompt, then a single space, then the triplet
        strings from `example['triplets']` joined one per line. The same dict is
        mutated and returned.
        """
        prompt = self.format_for_inference(example)['text']
        target = '\n'.join(example['triplets'])
        example['text'] = prompt + " " + target
        return example

