# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/jerx.reward.heuristic.ipynb.

# %% auto 0
# Public names exported by this module (generated by nbdev from the notebook).
__all__ = ['log', 'compute_heuristic_reward']

# %% ../../../nbs/jerx.reward.heuristic.ipynb 3
from ...logging import get_logger

# Module-level logger obtained from the project's logging helper, keyed by this
# module's name. NOTE(review): not referenced by compute_heuristic_reward below.
log = get_logger(__name__)

# %% ../../../nbs/jerx.reward.heuristic.ipynb 4
def compute_heuristic_reward(generation: str, delimiter: str = "|") -> float:
    """Score a generation by how much it looks like a list of triplets.

    Each line of ``generation`` that splits on ``delimiter`` into exactly three
    fields, all non-empty after stripping whitespace, is treated as a
    ``subject <delimiter> relation <delimiter> object`` triplet. Triplets that
    differ only in surrounding whitespace are counted once.

    The reward is the sum of:

    * 0.3 if more than 5 distinct entities (subjects and objects) appear,
    * 0.5 if more than 5 distinct relations appear,
    * 0.1 if there are more than 5 distinct triplets, else 0.05 if there are
      at least 3,
    * 0.05 if distinct triplets make up more than 80% of all lines.

    Args:
        generation: Raw generated text, one candidate triplet per line.
        delimiter: Field separator used within a line.

    Returns:
        The reward as a float. ``0.0`` if the generation has fewer than two
        lines or more than 30 distinct triplets.

    Raises:
        ValueError: If ``delimiter`` is empty and the generation has at least
            two lines (raised by ``str.split``).
    """
    lines = generation.splitlines()
    if len(lines) < 2:
        return 0.0

    # Deduplicate on the stripped fields so "a|b|c" and "a | b | c" count once.
    triplets: set[tuple[str, str, str]] = set()
    for line in lines:
        parts = line.split(delimiter)
        if len(parts) != 3:
            continue
        subj, relation, obj = (part.strip() for part in parts)
        # Lines such as "||" or "a | | b" have the right shape but no usable
        # content; they must not count as triplets or add "" as an entity.
        if subj and relation and obj:
            triplets.add((subj, relation, obj))

    # More than 30 distinct triplets earns no reward at all
    # (presumably to discourage overly long outputs).
    if len(triplets) > 30:
        return 0.0

    entities = {field for subj, _, obj in triplets for field in (subj, obj)}
    relations = {relation for _, relation, _ in triplets}

    reward = 0.0
    if len(entities) > 5:
        reward += 0.3
    if len(relations) > 5:
        reward += 0.5

    if len(triplets) > 5:
        reward += 0.1
    elif len(triplets) >= 3:
        reward += 0.05

    # The denominator counts every line, including blank and non-triplet ones.
    if len(triplets) / len(lines) > 0.8:
        reward += 0.05

    return reward


