# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/jerx.offline.llm.ipynb.

# %% auto 0
# Public names exported by this module (generated from the source notebook).
__all__ = ['log', 'parse_triplet_response', 'make_kg_triplet_extract_fn']

# %% ../../../nbs/jerx.offline.llm.ipynb 3
import json
from pathlib import Path
from ..utils import parse_triplets
from ...logging import get_logger

# Module-level logger named after this module; exported via ``__all__``.
log = get_logger(__name__)

# %% ../../../nbs/jerx.offline.llm.ipynb 5
def parse_triplet_response(response: str, *args, **kwargs) -> list[tuple[str, str, str]]:
    """Parse a raw LLM response into ``(entity, relation, entity)`` triplets.

    Leading and trailing whitespace is stripped from ``response`` before it is
    handed to ``parse_triplets``. If a triplet's two entities are identical
    strings, ``"(obj)"`` is appended to the second one so that the two ends
    differ. Presumably this avoids self-loops in the resulting graph
    (TODO confirm).

    Any extra positional or keyword arguments are accepted and ignored.
    Presumably this lets the function fit a generic parser-callback
    signature; verify against callers.

    Returns:
        A list of 3-tuples, in the order produced by ``parse_triplets``.
    """
    cleaned = response.strip()
    result = []
    for head, relation, tail in parse_triplets(cleaned):
        if head == tail:
            # Disambiguate a self-referential triplet by tagging the object side.
            tail = tail + "(obj)"
        result.append((head, relation, tail))
    return result


def make_kg_triplet_extract_fn(inference_cache_filepath: Path):
    """Build a triplet extractor backed by a cache of precomputed LLM generations.

    The cache file is read eagerly, in full, as JSON Lines. Each non-blank
    line must be a JSON object with a ``"text"`` key (the input text) and a
    ``"generation"`` key (the raw model output for that text). If the same
    text appears in more than one record, the last record wins.

    Args:
        inference_cache_filepath: Path to the JSONL cache file. A plain ``str``
            path also works, because it is passed straight to ``open``.

    Returns:
        A function ``extract_kg_triplets(text)``. It looks up the cached
        generation for ``text`` and parses it with ``parse_triplet_response``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks the ``"text"`` or ``"generation"`` key.
    """
    mapping = {}
    # Use explicit UTF-8. The platform default encoding is not guaranteed to be
    # UTF-8 (e.g. on Windows), and non-ASCII cached text would be mis-decoded.
    with open(inference_cache_filepath, "r", encoding="utf-8") as f:
        for raw_line in f:
            stripped = raw_line.strip()
            # Skip blank lines (e.g. stray trailing newlines) instead of
            # failing with a JSONDecodeError on an empty string.
            if not stripped:
                continue
            record = json.loads(stripped)
            mapping[record["text"]] = record["generation"]

    def extract_kg_triplets(text: str) -> list[tuple[str, str, str]]:
        """Return the triplets parsed from the cached generation for ``text``.

        Raises:
            KeyError: If ``text`` has no entry in the cache.
        """
        generation = mapping[text]
        return parse_triplet_response(generation)

    return extract_kg_triplets
