# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/lang.qdecomp.ipynb.

# %% auto 0
# Public names of this module, i.e. what `from <module> import *` exports.
__all__ = [
    'QUESTION_DECOMPOSITION_SYSTEM_PROMPT_TEMPLATE',
    'make_chat_prompt_template',
    'parse_sub_questions',
    'make_question_decomposer',
]

# %% ../../nbs/lang.qdecomp.ipynb 3
import re

from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

# %% ../../nbs/lang.qdecomp.ipynb 4
# System prompt for the question decomposer. It asks the model to split a question
# into two sub-questions, the second referring to the first one's answer as `#1`,
# and shows two few-shot examples in the "Sub-questions:" + numbered-list format
# that `parse_sub_questions` reads back. The examples are what the model imitates,
# so they must be well-formed questions.
# This text is passed to `SystemMessagePromptTemplate.from_template`, so it must not
# contain literal `{` or `}` (they would be read as template variables).
QUESTION_DECOMPOSITION_SYSTEM_PROMPT_TEMPLATE = """
Decompose the given question into 2 sub-questions such that, once the sub-questions are answered, the original question can be answered correctly.
The second sub-question must refer to the answer of the first sub-question by `#1`, as in the examples below. Do not create open-ended sub-questions like "Who is ..." or "How is ...".
Answer in the same format as the examples: a `Sub-questions:` line followed by the numbered sub-questions.

Question: What year saw the creation of the region where the county of Hertfordshire is located?
Sub-questions:
1. In which region is Hertfordshire located?
2. When was #1 created?

Question: When was the institute that owned The Collegian founded?
Sub-questions:
1. Which institute owned The Collegian?
2. When was #1 founded?

""".strip()


def make_chat_prompt_template() -> ChatPromptTemplate:
    """Build the chat prompt used by the question decomposer.

    The prompt consists of two messages, in order:

    1. a system message holding ``QUESTION_DECOMPOSITION_SYSTEM_PROMPT_TEMPLATE``
       (instructions plus few-shot examples);
    2. a human message ``"Question: {question}"``.

    Returns:
        A ``ChatPromptTemplate`` to be formatted with a ``question`` input.
    """
    system_message = SystemMessagePromptTemplate.from_template(
        QUESTION_DECOMPOSITION_SYSTEM_PROMPT_TEMPLATE
    )
    human_message = HumanMessagePromptTemplate.from_template("Question: {question}")
    return ChatPromptTemplate.from_messages([system_message, human_message])


# A line introducing the sub-question list, e.g. "Sub-questions:". Also tolerates
# "Sub questions" / "Subquestions" and leading markdown such as "**" or "###".
_SUB_QUESTIONS_HEADER_RE = re.compile(r"^\W*sub[\s-]?questions\b", re.IGNORECASE)
# A leading list marker: "1.", "2)" or a "-" / "*" bullet.
_LIST_MARKER_RE = re.compile(r"^(?:\d+\s*[.)]|[-*])\s*")


def parse_sub_questions(output: str):
    """Yield the sub-questions listed in an LLM decomposition response.

    The expected response mirrors the few-shot examples of the system prompt::

        Sub-questions:
        1. Which institute owned The Collegian?
        2. When was #1 founded?

    Every non-blank line after a ``Sub-questions`` header is yielded with its
    leading list marker ("1.", "2)", "-", "*") removed. Repeated header lines
    are skipped. If the response contains no header at all, only lines that
    start with a list marker are used, so a bare numbered list is still parsed.

    Args:
        output: Raw text returned by the model.

    Yields:
        Each sub-question as a stripped, non-empty string, in order. Nothing
        is yielded when no sub-questions can be found.
    """
    lines = [line.strip() for line in output.splitlines()]
    start = next(
        (i for i, line in enumerate(lines) if _SUB_QUESTIONS_HEADER_RE.match(line)),
        None,
    )

    if start is None:
        # No header: accept only explicitly enumerated lines so that any
        # surrounding commentary from the model is not mistaken for a question.
        for line in lines:
            marker = _LIST_MARKER_RE.match(line)
            if marker:
                question = line[marker.end():].strip()
                if question:
                    yield question
        return

    for line in lines[start + 1:]:
        if not line or _SUB_QUESTIONS_HEADER_RE.match(line):
            continue
        # Remove only a leading list marker. Splitting on the first "." would
        # mangle unnumbered questions such as "Where is St. Louis?".
        question = _LIST_MARKER_RE.sub("", line, count=1).strip()
        if question:
            yield question


def make_question_decomposer(llm=None):
    """Create a function that decomposes a question into sub-questions.

    Args:
        llm: Chat model placed between the prompt and the string output parser
            in the chain. When omitted, ``ChatOpenAI`` is used with
            ``temperature=0`` and ``model="gpt-3.5-turbo-1106"``.

    Returns:
        A callable that takes a question string and returns the list of
        sub-questions that ``parse_sub_questions`` extracts from the model's
        reply.
    """
    model = llm if llm is not None else ChatOpenAI(temperature=0, model="gpt-3.5-turbo-1106")

    # Prompt -> chat model -> plain-text reply.
    pipeline = make_chat_prompt_template() | model | StrOutputParser()

    def decompose(question):
        """Run the chain on ``question`` and parse the reply into sub-questions."""
        reply = pipeline.invoke({"question": question})
        return list(parse_sub_questions(reply))

    return decompose
