# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/musique.baseline.ipynb.

# %% auto 0
# Public API of this module, i.e. the names exposed by `from ... import *`.
# Generated from the source notebook (see the header): edit the notebook, not this line.
__all__ = ['make_docs', 'BaselineSingleHop', 'BaselineMultiHop', 'benchmark']

# %% ../../nbs/musique.baseline.ipynb 3
import json
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from typing import Callable

import pandas as pd
from tqdm.auto import tqdm

from ..jerx.reward.llm import QuestionAnsweringResult
from .eval import calculate_metrics, compare_answers

# Import-time side effect: registers tqdm's pandas integration, which adds
# `progress_apply` and related progress-bar methods to pandas objects.
# NOTE(review): nothing visible in this module calls those methods; presumably
# kept for interactive/notebook use. Confirm before removing.
tqdm.pandas()

# %% ../../nbs/musique.baseline.ipynb 4
def make_docs(example):
    """Yield one retrieval document per paragraph of ``example``.

    ``example`` must provide ``"id"`` and ``"paragraphs"``. Each paragraph must
    provide ``"idx"``, ``"title"``, ``"paragraph_text"`` and ``"is_supporting"``.

    Each yielded dict has, in this order:
        text: the paragraph rendered as ``"# <title>\\n<paragraph_text>"``.
        is_supporting: copied from the paragraph.
        parent_id: the ``id`` of the enclosing example.
        id: the paragraph's ``idx``.
    """
    for paragraph in example["paragraphs"]:
        paragraph_id = paragraph["idx"]
        heading = paragraph["title"]
        yield {
            "text": f"# {heading}\n{paragraph['paragraph_text']}",
            "is_supporting": paragraph["is_supporting"],
            "parent_id": example["id"],
            "id": paragraph_id,
        }

# %% ../../nbs/musique.baseline.ipynb 5
class BaselineSingleHop:
    """Single-hop baseline: retrieve paragraphs for the whole question, then answer it once.

    Args:
        qa_func: called as ``qa_func(context=..., question=...)``; its return value
            is returned unchanged as the answer.
        retrieval_func: called as ``retrieval_func(docs, question)`` with the dicts
            produced by ``make_docs``; must return an iterable of dicts that have a
            ``"text"`` key.
    """

    def __init__(self, qa_func, retrieval_func):
        self.qa_func, self.retrieval_func = qa_func, retrieval_func

    def _call(self, example) -> QuestionAnsweringResult:
        """Answer ``example['question']`` from the paragraphs the retriever selects."""
        candidates = list(make_docs(example))
        question = example['question']
        selected = self.retrieval_func(candidates, question)
        # Retrieved documents are separated by a blank line in the prompt context.
        context = "\n\n".join(doc["text"] for doc in selected)
        return self.qa_func(context=context, question=question)

    def __call__(self, example, ignore_errors: bool = False) -> QuestionAnsweringResult:
        """Run the pipeline on one example.

        If ``ignore_errors`` is true, any ``Exception`` is reported with ``print``
        and replaced by a placeholder result: answer ``"N/A"``, empty reasoning,
        and the error text as ``raw_output``. Otherwise the exception propagates.
        """
        try:
            return self._call(example)
        except Exception as exc:
            if not ignore_errors:
                raise
            example_id = example['id']
            print(f"Failed to answer the question {example_id}\n{exc}")
            return QuestionAnsweringResult(reasoning="", answer="N/A", raw_output=str(exc))

# %% ../../nbs/musique.baseline.ipynb 6
class BaselineMultiHop:
    """Two-hop baseline that answers a question through its decomposition.

    The first two entries of ``example["question_decomposition"]`` are answered
    in turn. Each sub-question is used verbatim as its own retrieval query over
    the example's paragraphs and is then sent to ``qa_func``. The first hop's
    answer replaces the ``"#1"`` placeholder in the second sub-question.
    ``example["question"]`` itself is not used. Decomposition steps beyond the
    second are ignored.

    Args:
        qa_func: called as ``qa_func(context=..., question=...)``; its result
            must expose ``.answer`` and ``.reasoning``.
        retrieval_func: called as ``retrieval_func(docs, query)`` with the dicts
            produced by ``make_docs``; must return an iterable of dicts that have
            a ``"text"`` key.
    """

    def __init__(self, qa_func, retrieval_func):
        self.qa_func = qa_func
        self.retrieval_func = retrieval_func

    def _hop(self, docs, question):
        """Retrieve context for one sub-question and answer it.

        Returns:
            ``(result, hop)``. ``result`` is the ``qa_func`` output. ``hop`` is a
            dict recording the question, query, context, answer and reasoning;
            ``_call`` serialises the hop dicts with ``json.dumps``.
        """
        query = question  # the sub-question doubles as the retrieval query
        retrieved = self.retrieval_func(docs, query)
        context = "\n".join(doc["text"] for doc in retrieved)
        result = self.qa_func(context=context, question=question)
        hop = {
            "question": question,
            "query": query,
            "context": context,
            "answer": result.answer,
            "reasoning": result.reasoning,
        }
        return result, hop

    def _call(self, example) -> QuestionAnsweringResult:
        """Answer ``example`` hop by hop.

        ``raw_output`` is a JSON list of the hop records that were executed.
        """
        docs = list(make_docs(example))
        steps = example["question_decomposition"]

        # First hop: the first sub-question, as written.
        result1, hop1 = self._hop(docs, steps[0]["question"])

        # With no first answer there is nothing to substitute into the second
        # sub-question, so stop early. Return a QuestionAnsweringResult (the
        # original returned a plain dict here), so callers such as `benchmark`
        # can rely on `.answer` and `.raw_output`.
        if result1.answer == "N/A":
            return QuestionAnsweringResult(
                answer="N/A",
                reasoning=result1.reasoning,
                raw_output=json.dumps([hop1]),
            )

        # Second hop: substitute the first answer for the "#1" placeholder.
        question2 = steps[1]["question"].replace("#1", result1.answer)
        result2, hop2 = self._hop(docs, question2)
        return QuestionAnsweringResult(
            answer=result2.answer,
            reasoning=result2.reasoning,
            raw_output=json.dumps([hop1, hop2]),
        )

    def __call__(self, example, ignore_errors: bool = False) -> QuestionAnsweringResult:
        """Run the pipeline on one example.

        If ``ignore_errors`` is true, any ``Exception`` is reported with ``print``
        and replaced by a placeholder result: answer ``"N/A"``, empty reasoning,
        and the error text as ``raw_output``. Otherwise the exception propagates.
        """
        try:
            output = self._call(example)
        except Exception as exc:
            if ignore_errors:
                example_id = example['id']
                print(f"Failed to answer the question {example_id}\n{exc}")
                output = QuestionAnsweringResult(reasoning="", answer="N/A", raw_output=str(exc))
            else:
                raise
        return output

# %% ../../nbs/musique.baseline.ipynb 7
def benchmark(
    dataf: pd.DataFrame,
    pipeline: Callable,
    ignore_errors: bool = False,
    n_workers: int = 8,
) -> tuple[pd.DataFrame, dict]:
    """Run ``pipeline`` over every row of ``dataf`` in a thread pool, then score it.

    Args:
        dataf: one example per row. Each row (a pandas Series from ``iterrows``)
            is passed to ``pipeline``.
        pipeline: called as ``pipeline(row, ignore_errors=...)``; its return value
            must expose an ``.answer`` attribute.
        ignore_errors: forwarded unchanged to ``pipeline``.
        n_workers: maximum number of worker threads.

    Returns:
        ``(scored, scores)``.
        ``scored``: the input rows, in their original order, with
        ``predicted_answer`` (``output.answer``) and ``raw_llm_output`` (the whole
        pipeline output object) columns added, then passed through
        ``compare_answers``.
        ``scores``: the dict from ``calculate_metrics``, plus ``"fuzzy_match"``,
        the mean of the ``fuzzy_match`` column.

    Raises:
        Whatever ``pipeline`` raises. Examples that have not started yet are
        cancelled before the exception propagates.
    """

    def process(example):
        output = pipeline(example, ignore_errors=ignore_errors)
        example["predicted_answer"] = output.answer
        example["raw_llm_output"] = output
        return example

    rows = [None] * len(dataf)
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        # Remember each future's input position: results arrive in completion
        # order, but the output frame should follow the input order.
        futures = {
            executor.submit(process, row): position
            for position, (_, row) in enumerate(dataf.iterrows())
        }
        try:
            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing samples"):
                rows[futures[future]] = future.result()
        except BaseException:
            # Without this, leaving the `with` block would wait for every queued
            # example to run (each a pipeline call) before the error surfaces.
            executor.shutdown(wait=False, cancel_futures=True)
            raise

    dataf = pd.DataFrame(rows)
    dataf = compare_answers(dataf)
    scores = calculate_metrics(dataf)
    scores["fuzzy_match"] = dataf["fuzzy_match"].mean()
    return dataf, scores
