# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/jerx.dataset.docred.ipynb.

# %% auto 0
# Names exported by ``from <module> import *``.
__all__ = [
    "puncts",
    "join_tokens",
    "extract_sentences",
    "extract_text",
    "extract_triplets",
    "transform_docred",
]

# %% ../../../nbs/jerx.dataset.docred.ipynb 3
import string

# %% ../../../nbs/jerx.dataset.docred.ipynb 4
# Single-character ASCII punctuation marks; tokens found here are glued to the
# preceding text when a token list is joined back into a string.
puncts = set(string.punctuation)

# Opening brackets attach to the token that FOLLOWS them, not the one before.
_opening_brackets = frozenset("([{")

def join_tokens(tokens):
    """Join a sequence of string tokens back into a single string.

    Spacing rules:

    * a token that is a single punctuation character (a member of ``puncts``)
      is appended with no space before it;
    * an opening bracket (``(``, ``[`` or ``{``) is the exception: it gets a
      space before it and the next token is attached directly after it, so
      ``["John", "(", "born", "1950", ")"]`` becomes ``"John (born 1950)"``
      instead of ``"John( born 1950)"``;
    * every other token is preceded by a single space.

    Leading and trailing whitespace of the result is stripped. An empty
    ``tokens`` iterable yields ``""``.
    """
    pieces = []
    attach_next = False  # True immediately after an opening bracket
    for token in tokens:
        if token in _opening_brackets:
            # Consecutive openers ("([") stay glued to each other.
            pieces.append(token if attach_next else " " + token)
            attach_next = True
            continue
        if attach_next or token in puncts:
            pieces.append(token)
        else:
            pieces.append(" " + token)
        attach_next = False
    return "".join(pieces).strip()

def extract_sentences(example):
    """Lazily yield one detokenized string per entry of ``example['sents']``.

    Each entry is passed to ``join_tokens``; afterwards every ``"- "`` in the
    result is collapsed to ``"-"``. ``join_tokens`` leaves a space after a
    hyphen token (``"well- known"``), so this step produces ``"well-known"``.
    Note that it applies to every ``"- "`` occurrence in the sentence.
    """
    detokenized = map(join_tokens, example['sents'])
    for sentence in detokenized:
        yield sentence.replace("- ", "-")

def extract_text(example):
    """Return all sentences from ``extract_sentences(example)`` joined by single spaces."""
    sentences = extract_sentences(example)
    return ' '.join(sentences)

def extract_triplets(example):
    """Lazily yield ``[head_name, relation_text, tail_name]`` lists for an example.

    ``example['labels']`` must provide the parallel sequences ``'head'``,
    ``'relation_text'`` and ``'tail'``; they are walked in lockstep with
    ``zip``, so iteration stops at the shortest one. Head and tail values are
    used as indices into ``example['vertexSet']``, and the name is read from
    the ``'name'`` key of the first item stored at that index (presumably the
    entity's first mention -- confirm against the dataset schema).
    """
    labels = example['labels']
    rows = zip(labels['head'], labels['relation_text'], labels['tail'])
    for head_idx, relation, tail_idx in rows:
        entities = example['vertexSet']
        yield [entities[head_idx][0]['name'], relation, entities[tail_idx][0]['name']]

def transform_docred(example, delimiter="|"):
    """Convert one example into a ``{'text': ..., 'triplets': ...}`` dict.

    ``'triplets'`` is a list of strings, each made by joining the three items
    of a triplet from ``extract_triplets`` with ``delimiter``. ``'text'`` is
    the string returned by ``extract_text``.
    """
    joined = []
    for triplet in extract_triplets(example):
        joined.append(delimiter.join(triplet))
    return {'text': extract_text(example), 'triplets': joined}
