# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/musique.baseline.ipynb.

# %% auto 0
# Public API of this module (autogenerated from the exported notebook cells).
__all__ = ['make_docs', 'format_question', 'BaselineMHQA', 'benchmark']

# %% ../../nbs/musique.baseline.ipynb 3
from typing import Callable

import pandas as pd
from tqdm.auto import tqdm

from ..jerx.reward.llm import QuestionAnsweringResult
from .eval import calculate_metrics, compare_answers

# Register tqdm's pandas integration at import time so that
# `DataFrame.progress_apply` (used by `benchmark` below) is available.
tqdm.pandas()

# %% ../../nbs/musique.baseline.ipynb 4
def make_docs(example, only_supporting=False):
    """Yield one document dict per paragraph of *example*.

    Each paragraph in ``example["paragraphs"]`` must provide the keys
    ``idx``, ``title``, ``paragraph_text`` and ``is_supporting``.

    Args:
        example: mapping with a ``paragraphs`` list and an ``id`` (the id is
            only read when at least one document is yielded).
        only_supporting: when truthy, skip paragraphs whose
            ``is_supporting`` flag is falsy.

    Yields:
        ``{"text": "# <title>\\n<paragraph_text>", "metadata": {...}}``.
        The metadata holds ``parent_id`` (the example id), ``idx`` and
        ``is_supporting``.
    """
    for paragraph in example["paragraphs"]:
        keep = not only_supporting or paragraph["is_supporting"]
        if not keep:
            continue
        position, title, body, supporting = (
            paragraph["idx"],
            paragraph["title"],
            paragraph["paragraph_text"],
            paragraph["is_supporting"],
        )
        yield {
            "text": f"# {title}\n{body}",
            "metadata": {
                "parent_id": example["id"],
                "idx": position,
                "is_supporting": supporting,
            },
        }

# %% ../../nbs/musique.baseline.ipynb 5
def format_question(example):
    """Return the question text that is passed to the QA function.

    Only the top-level ``question`` field of *example* is used. An earlier
    experiment that formatted the ``question_decomposition`` sub-questions
    instead was left behind as unreachable commented-out code and has been
    removed.
    """
    return example['question']

class BaselineMHQA:
    """Baseline multi-hop QA: answer a question from its concatenated paragraphs.

    For each example, the paragraphs (optionally only the supporting ones)
    are rendered with `make_docs` and joined with blank lines into a single
    context string. That context is passed to ``qa_func`` together with the
    question from `format_question`.
    """

    def __init__(self, qa_func: Callable[..., QuestionAnsweringResult], only_supporting: bool = True):
        # qa_func is called as qa_func(context=..., question=...).
        self.qa_func = qa_func
        self.only_supporting = only_supporting

    def _answer(self, example) -> QuestionAnsweringResult:
        """Build the context for *example* and run ``qa_func`` on it."""
        documents = make_docs(example, only_supporting=self.only_supporting)
        context = "\n\n".join(doc["text"] for doc in documents)
        return self.qa_func(context=context, question=format_question(example))

    def answer(self, example, ignore_errors: bool = False) -> QuestionAnsweringResult:
        """Answer *example*, optionally replacing failures with a placeholder.

        Args:
            example: the example to answer (indexed by key, e.g. a dict or a
                pandas row).
            ignore_errors: if False (default), any exception propagates. If
                True, the failure is printed and a result is returned with
                ``answer="N/A"``, an empty ``reasoning`` and the exception
                text as ``raw_output``.

        Returns:
            Whatever ``qa_func`` returns, or the placeholder result described
            above.
        """
        try:
            return self._answer(example)
        except Exception as exc:
            if not ignore_errors:
                raise
            # `.get` so an example without an id cannot turn this handler
            # into a KeyError that hides the original failure.
            example_id = example.get('id', '<unknown>')
            print(f"Failed to answer the question {example_id}\n{exc}")
            return QuestionAnsweringResult(reasoning="", answer="N/A", raw_output=str(exc))

# %% ../../nbs/musique.baseline.ipynb 6
def benchmark(
    dataf: pd.DataFrame,
    qa_func: Callable,
    only_supporting: bool = True,
    ignore_errors: bool = False,
) -> tuple[pd.DataFrame, dict]:
    """Run the baseline QA over every row of *dataf* and score the predictions.

    Args:
        dataf: one example per row. Each row is passed to
            `BaselineMHQA.answer`.
        qa_func: QA callable forwarded to `BaselineMHQA`.
        only_supporting: if True, only supporting paragraphs go into the
            context.
        ignore_errors: forwarded to `BaselineMHQA.answer`. When True, a
            failing example gets an "N/A" placeholder result instead of
            aborting the whole run. The default False keeps the original
            fail-fast behaviour.

    Returns:
        ``(dataf, scores)`` where:
            dataf: the frame returned by `compare_answers`, applied to the
                per-row results. Before that call, ``predicted_answer`` and
                ``raw_llm_output`` columns are added; the returned frame
                must include a ``fuzzy_match`` column.
            scores: the dict from `calculate_metrics`, extended with the
                mean ``fuzzy_match``.
    """
    mhqa = BaselineMHQA(qa_func, only_supporting=only_supporting)

    def process(example):
        output = mhqa.answer(example, ignore_errors=ignore_errors)
        # pandas does not support mutating the object passed to `apply`,
        # so extend a copy of the row instead of the row itself.
        row = example.copy()
        row['predicted_answer'] = output.answer
        row['raw_llm_output'] = output
        return row

    dataf = dataf.progress_apply(process, axis=1)
    dataf = compare_answers(dataf)
    scores = calculate_metrics(dataf)
    scores['fuzzy_match'] = dataf['fuzzy_match'].mean()
    return dataf, scores
