#!/usr/bin/env python3
# arxa/arxa/cli.py

import os
import sys
import argparse
import tempfile
import logging
import subprocess
import json
import hashlib

# Import rich’s logging handler for better log formatting.
try:
    from rich.logging import RichHandler
except ImportError:
    RichHandler = None

from . import __version__  # Import version from __init__.py
from .config import load_config
from .arxiv_utils import search_arxiv_by_id_list
from .pdf_utils import download_pdf_from_arxiv, extract_text_from_pdf, sanitize_filename
from .research_review import generate_research_review
from .repo_utils import extract_github_url, clone_repo
from .prompts import PROMPT_PREFIX, PROMPT_SUFFIX  # For printing the prompt template

def configure_logging(quiet: bool) -> None:
    """
    Configure root logging for the CLI.

    Args:
        quiet: When False, log at DEBUG level and use rich's ``RichHandler`` if the
            ``rich`` package could be imported (otherwise a plain
            ``logging.StreamHandler``). When True, log at INFO level through a plain
            ``logging.StreamHandler``.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has handlers,
        so this only takes effect on the first configuration of the process.
    """
    if not quiet and RichHandler is not None:
        # RichHandler draws its own timestamp and level columns; the format string must
        # carry only the message, otherwise time and level are printed twice per line.
        handler = RichHandler(markup=True)
        fmt = "%(message)s"
    else:
        handler = logging.StreamHandler()
        fmt = "%(asctime)s - %(levelname)s - %(message)s"

    logging.basicConfig(
        # Verbose (DEBUG) by default to aid troubleshooting; --quiet drops to INFO.
        level=logging.INFO if quiet else logging.DEBUG,
        format=fmt,
        handlers=[handler],
    )

logger = logging.getLogger(__name__)

# Extra attributes for log records whose text comes from the API/LLM response.
# rich's RichHandler (created with markup=True when not quiet) honours a per-record
# "markup" attribute; turning it off keeps stray "[...]" sequences in the content from
# being parsed as console markup. Plain logging handlers ignore the attribute.
_NO_MARKUP = {"markup": False}

# How a single field of a research-notes section is rendered.
_INLINE = "inline"            # "- **Label:** value  "
_INLINE_LIST = "inline_list"  # "- **Label:** a, b, c  "
_LINK = "link"                # "- **Label:** [url](url)  "
_BLOCK = "block"              # label line, then the value on an indented line
_NUMBERED = "numbered"        # label line, then "  1. item" per item
_BULLETS = "bullets"          # label line, then "  - item" per item

# (research_notes key, heading, ((field key, label, style), ...)) in output order.
# Headings are numbered 1..N in this order.
_SECTIONS = (
    ("paper_information", "Paper Information", (
        ("title", "Title", _INLINE),
        ("authors", "Authors", _INLINE_LIST),
        ("arxiv_link", "ArXiv Link", _LINK),
        ("date_of_submission", "Date of Submission", _INLINE),
        ("field_of_study", "Field of Study", _INLINE),
        ("keywords", "Keywords", _INLINE_LIST),
        ("code_repository", "Code Repository", _LINK),
    )),
    ("summary", "Summary", (
        ("problem_statement", "Problem Statement", _BLOCK),
        ("main_contributions", "Main Contributions", _NUMBERED),
        ("key_findings", "Key Findings", _BULLETS),
        ("methodology_overview", "Methodology Overview", _BLOCK),
        ("conclusion", "Conclusion", _BLOCK),
    )),
    ("background_and_related_work", "Background & Related Work", (
        ("prior_work_referenced", "Prior Work Referenced", _BLOCK),
        ("how_it_differs", "How It Differs", _BLOCK),
        ("gaps_addressed", "Gaps Addressed", _BLOCK),
    )),
    ("methodology", "Methodology", (
        ("approach_taken", "Approach Taken", _BLOCK),
        ("key_techniques_used", "Key Techniques Used", _BULLETS),
        ("datasets_benchmarks_used", "Datasets / Benchmarks Used", _BLOCK),
        ("implementation_details", "Implementation Details", _BLOCK),
        ("reproducibility", "Reproducibility", _BLOCK),
    )),
    ("experimental_evaluation", "Experimental Evaluation", (
        ("evaluation_metrics", "Evaluation Metrics", _BLOCK),
        ("results_summary", "Results Summary", _BLOCK),
        ("baseline_comparisons", "Baseline Comparisons", _BLOCK),
        ("ablation_studies", "Ablation Studies", _BLOCK),
        ("limitations_noted", "Limitations Noted", _BLOCK),
    )),
    ("strengths", "Strengths", (
        ("novelty_innovation", "Novelty & Innovation", _BLOCK),
        ("technical_soundness", "Technical Soundness", _BLOCK),
        ("clarity_and_organization", "Clarity & Organization", _BLOCK),
        ("impact_applications", "Impact & Applications", _BLOCK),
    )),
    ("weaknesses_and_critiques", "Weaknesses & Critiques", (
        ("unaddressed_assumptions_flaws", "Unaddressed Assumptions / Flaws", _BLOCK),
        ("possible_biases_limitations", "Possible Biases / Limitations", _BLOCK),
        ("reproducibility_concerns", "Reproducibility Concerns", _BLOCK),
        ("presentation_issues", "Presentation Issues", _BLOCK),
    )),
    ("future_work_and_open_questions", "Future Work & Open Questions", (
        ("suggested_improvements", "Suggested Improvements", _BLOCK),
        ("potential_extensions", "Potential Extensions", _BLOCK),
        ("open_problems", "Open Problems", _BLOCK),
    )),
    ("personal_review", "Personal Review", (
        ("overall_impression", "Overall Impression", _BLOCK),
        ("significance_contributions", "Significance of Contributions", _BLOCK),
        ("clarity_and_organization", "Clarity & Organization", _BLOCK),
        ("methodological_rigor", "Methodological Rigor", _BLOCK),
        ("reproducibility", "Reproducibility", _BLOCK),
    )),
    ("additional_notes", "Additional Notes", (
        ("key_takeaways", "Key Takeaways", _BLOCK),
        ("interesting_insights", "Interesting Insights", _BLOCK),
        ("personal_thoughts", "Personal Thoughts", _BLOCK),
    )),
)


def _debug_json(template, obj):
    """Log *obj* as indented JSON at DEBUG level; serialize only when DEBUG is enabled."""
    if logger.isEnabledFor(logging.DEBUG):
        # default=str: this is diagnostic output only and must never raise.
        logger.debug(template, json.dumps(obj, indent=2, default=str), extra=_NO_MARKUP)


def _as_list(value):
    """Return *value* as a list; anything that is not a list/tuple becomes one item.

    Keeps a lone string (e.g. ``"authors": "A. Author"``) from being iterated
    character by character.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _research_notes_from(data):
    """Return the ``research_notes`` mapping carried by *data*, or ``{}`` if absent.

    Raises:
        json.JSONDecodeError: if ``data["review"]`` is a string that is not valid JSON.
    """
    if not isinstance(data, dict):
        return {}
    review = data.get("review")
    if review is None:
        # No "review" wrapper: research_notes may sit at the top level.
        review = data
    elif isinstance(review, (str, bytes, bytearray)):
        review = json.loads(review)
    if not isinstance(review, dict):
        return {}
    notes = review.get("research_notes")
    return notes if isinstance(notes, dict) else {}


def _render_field(label, value, style):
    """Return the Markdown line(s) for one non-empty field rendered in *style*."""
    if style == _INLINE:
        return [f"- **{label}:** {value}  "]
    if style == _INLINE_LIST:
        joined = ", ".join(str(item) for item in _as_list(value))
        return [f"- **{label}:** {joined}  "]
    if style == _LINK:
        return [f"- **{label}:** [{value}]({value})  "]
    if style == _NUMBERED:
        return [f"- **{label}:**"] + [
            f"  {idx}. {item}" for idx, item in enumerate(_as_list(value), start=1)
        ]
    if style == _BULLETS or isinstance(value, (list, tuple)):
        # A list supplied for a prose (_BLOCK) field is shown as bullets rather than
        # as its Python repr.
        return [f"- **{label}:**"] + [f"  - {item}" for item in _as_list(value)]
    # _BLOCK: the label line ends with two spaces (a Markdown hard line break) and the
    # value follows on an indented continuation line.
    return [f"- **{label}:**  \n  {value}"]


def json_to_markdown(data: dict) -> str:
    """
    Convert a review API response into a Markdown research review.

    Accepted shapes for *data*:
      * ``{"review": "<JSON text>"}`` - the string is decoded with ``json.loads``;
      * ``{"review": {...}}`` - an already-decoded review object;
      * ``{"research_notes": {...}}`` - research notes at the top level.

    The review object is expected to look like:
      {
         "research_notes": {
             "paper_information": { ... },
             "summary": { ... },
             "background_and_related_work": { ... },
             "methodology": { ... },
             "experimental_evaluation": { ... },
             "strengths": { ... },
             "weaknesses_and_critiques": { ... },
             "future_work_and_open_questions": { ... },
             "personal_review": { ... },
             "additional_notes": { ... }
         }
      }

    Every section heading is always emitted; missing or empty fields are skipped.

    Returns:
        The Markdown document, with a blank line after each section.

    Raises:
        json.JSONDecodeError: if ``data["review"]`` is a string that is not valid JSON.
    """
    _debug_json("Inside json_to_markdown; received data: %s", data)
    notes = _research_notes_from(data)
    if not notes:
        logger.warning("No 'research_notes' key found in JSON data!")

    md_lines = ["# Research Review\n"]
    for number, (section_key, heading, fields) in enumerate(_SECTIONS, start=1):
        section = notes.get(section_key, {})
        _debug_json(f"{heading}: %s", section)
        md_lines.append(f"## {number}. {heading}")
        if isinstance(section, dict) and section:
            for field_key, label, style in fields:
                value = section.get(field_key)
                if value:
                    md_lines.extend(_render_field(label, value, style))
        elif section:
            # Not an object of named fields; keep its content instead of dropping it.
            md_lines.extend(f"- {item}" for item in _as_list(section))
        else:
            logger.debug("No %s section found.", section_key)
        md_lines.append("")  # blank line between sections

    final_md = "\n".join(md_lines)
    logger.debug("Final Markdown output:\n%s", final_md, extra=_NO_MARKUP)
    return final_md

# (connect, read) timeouts in seconds for the remote review API. Without a timeout,
# requests waits indefinitely on an unreachable/unresponsive server. The read timeout
# is generous because the server returns only after the whole review is generated.
# NOTE(review): raise it if reviews legitimately take longer.
_REMOTE_TIMEOUT = (30, 600)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the ``arxa`` command."""
    parser = argparse.ArgumentParser(
        description="arxa: Generate research reviews from arXiv papers or PDFs, or start the FastAPI server."
    )
    # Flag to start the server.
    parser.add_argument("--server", action="store_true",
                        help="Start the FastAPI server instead of processing a paper/PDF.")

    # -aid and -pdf are alternative paper sources; at most one may be given.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-aid", help="arXiv ID of the paper (e.g. 1234.5678)")
    group.add_argument("-pdf", help="Path to a local PDF file")

    parser.add_argument("-o", "--output", help="Output markdown file for the review (ignored when --server is used)")
    parser.add_argument(
        "-p",
        "--provider",
        default="api.arxa.ai",  # default provider uses the remote server
        choices=["api.arxa.ai", "anthropic", "openai", "ollama", "deepseek", "fireworks"],
        help="LLM provider to use (default: api.arxa.ai)"
    )
    parser.add_argument(
        "-m",
        "--model",
        help="Model identifier/version (e.g., 'o3-mini'). When using the remote server, this will be ignored.",
        default="o3-mini"
    )
    parser.add_argument("-g", "--github", action="store_true",
                        help="Enable GitHub cloning if a GitHub URL is found in the review")
    parser.add_argument("-c", "--config", help="Path to configuration YAML file")
    parser.add_argument("--quiet", action="store_true", help="Disable rich output formatting (sets log level to INFO)")
    return parser


def _prepare_paper(args, papers_dir):
    """Resolve the paper to review and return ``(pdf_path, paper_info)``.

    With ``-aid`` the arXiv metadata is looked up and the PDF is downloaded into
    *papers_dir* unless a file with the expected name already exists there. With
    ``-pdf`` the local file is used and only its file name is known. Exits the
    process with status 1 when the paper, the download, or the file is unavailable.
    """
    if args.aid:
        aid = args.aid.strip()
        results = search_arxiv_by_id_list([aid])
        if not results:
            logger.error("Paper with arXiv ID %s not found.", aid)
            sys.exit(1)
        paper = results[0]
        paper_info = {
            "title": paper.title,
            "authors": [author.name for author in paper.authors],
            "abstract": paper.summary,
            "doi": paper.doi,
            "journal_ref": paper.journal_ref,
            "published": paper.published.isoformat() if paper.published else None,
            "arxiv_link": paper.entry_id,
        }
        pdf_path = os.path.join(papers_dir, sanitize_filename(f"{aid}.pdf"))
        if os.path.exists(pdf_path):
            logger.info("Using existing PDF file at %s", pdf_path)
            return pdf_path, paper_info

        # The configured papers directory may not exist yet.
        os.makedirs(papers_dir, exist_ok=True)
        logger.info("Downloading PDF for arXiv ID %s ...", aid)
        try:
            download_pdf_from_arxiv(paper, pdf_path)
        except Exception as e:
            logger.error("Failed to download PDF for arXiv ID %s: %s", aid, str(e))
            # If the download left a partial file, remove it: only the file's existence
            # is checked above, so it would otherwise be reused on the next run.
            if os.path.exists(pdf_path):
                os.remove(pdf_path)
            sys.exit(1)
        return pdf_path, paper_info

    pdf_path = args.pdf
    if not os.path.exists(pdf_path):
        logger.error("PDF file %s not found.", pdf_path)
        sys.exit(1)
    paper_info = {
        "title": os.path.basename(pdf_path),
        "authors": [],
        "abstract": "",
        "arxiv_link": "",
    }
    return pdf_path, paper_info


def _remote_review(pdf_text, paper_info, args):
    """Request a review from the api.arxa.ai server and return it as Markdown.

    Exits the process with status 1 when ``requests`` is missing, the HTTP call
    fails, or the response cannot be parsed.
    """
    try:
        import requests
    except ImportError:
        logger.error("The requests library is required for remote calls. Install it with pip install requests")
        sys.exit(1)

    # Deterministic request key derived from the paper text and metadata; sent as "aid".
    key_data = pdf_text + json.dumps(paper_info, sort_keys=True)
    request_aid = hashlib.sha256(key_data.encode("utf-8")).hexdigest()

    endpoint = "https://api.arxa.ai/generate-review"
    payload = {
        "pdf_text": pdf_text,
        "paper_info": paper_info,
        "provider": args.provider,
        "model": args.model,
        "aid": request_aid
    }
    logger.info("Sending review generation request to remote server at %s with aid: %s", endpoint, request_aid)
    try:
        # Connection failures and timeouts are handled here too, not only HTTP errors.
        response = requests.post(endpoint, json=payload, timeout=_REMOTE_TIMEOUT)
        response.raise_for_status()
    except requests.RequestException as e:
        logger.error("Remote API call failed: %s", str(e))
        sys.exit(1)

    try:
        data = response.json()
    except ValueError as e:
        logger.error("Failed to parse JSON response: %s", e)
        sys.exit(1)
    if logger.isEnabledFor(logging.DEBUG):
        # markup off: the response carries LLM text that rich must not parse as markup.
        logger.debug("JSON response received: %s", json.dumps(data, indent=2),
                     extra={"markup": False})

    try:
        return json_to_markdown(data)
    except Exception as e:
        logger.error("Failed to convert the review response to Markdown: %s", str(e))
        sys.exit(1)


def _make_llm_client(provider):
    """Return the ``llm_client`` passed to generate_research_review for *provider*.

    *provider* is the lower-cased ``--provider`` value. ``None`` is returned for
    ollama and fireworks. Exits the process with status 1 when a required client
    library or API-key environment variable is missing, or the provider is unknown.
    """
    if provider == "anthropic":
        try:
            from anthropic import Anthropic
        except ImportError:
            logger.error("Anthropic client library not installed.")
            sys.exit(1)
        api_key = os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            logger.error("ANTHROPIC_API_KEY environment variable not set.")
            sys.exit(1)
        return Anthropic(api_key=api_key)

    if provider in ("openai", "deepseek"):
        try:
            import openai
        except ImportError:
            logger.error("OpenAI client library not installed.")
            sys.exit(1)
        env_var = "OPENAI_API_KEY" if provider == "openai" else "DEEPSEEK_API_KEY"
        api_key = os.getenv(env_var)
        if not api_key:
            logger.error("%s environment variable not set.", env_var)
            sys.exit(1)
        openai.api_key = api_key
        if provider == "deepseek":
            # NOTE(review): module-level `api_base` is the pre-1.0 openai SDK setting
            # (1.x uses `base_url`); confirm against how generate_research_review uses it.
            openai.api_base = "https://api.deepseek.com"
        # The configured module itself is handed over as the client.
        return openai

    if provider in ("ollama", "fireworks"):
        return None

    logger.error("Unsupported provider: %s", provider)
    sys.exit(1)


def _emit_review(review, args):
    """Write the review to ``--output`` (if given) and show it on the console."""
    if args.output:
        try:
            with open(args.output, "w", encoding="utf-8") as f:
                f.write(review)
        except OSError as e:
            logger.error("Failed to write review to file: %s", str(e))
            sys.exit(1)
        logger.info("Review written to %s", args.output)
    elif args.quiet:
        # Only in quiet mode is the review not rendered below, so log it here;
        # otherwise it would be shown twice.
        logger.info("Generated review:\n%s", review)

    if args.quiet:
        return
    try:
        from rich.console import Console
        from rich.markdown import Markdown
    except ImportError:
        print(review)
        return
    console = Console()
    console.rule("[bold green]Generated Research Review")
    console.print(Markdown(review))
    console.rule()


def _clone_github_repo(review, output_dir):
    """Clone the GitHub repository whose URL extract_github_url finds in *review*.

    Failures are logged and do not stop the program.
    """
    try:
        github_url = extract_github_url(review)
    except Exception as e:
        logger.error("Error extracting GitHub URL: %s", str(e))
        github_url = None
    if not github_url:
        logger.info("No GitHub URL found in the review.")
        return
    try:
        clone_repo(github_url, output_dir)
    except Exception as e:
        logger.error("Error during GitHub cloning: %s", str(e))
        return
    logger.info("Repository cloned from %s", github_url)


def main():
    """Entry point of the ``arxa`` command.

    Parses arguments and configures logging. Then it either starts the FastAPI server
    (``--server``) or reviews one paper: it resolves the PDF, extracts its text,
    generates the review (remote API or a local provider client), writes/shows it,
    and optionally clones the GitHub repository mentioned in it.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    configure_logging(args.quiet)

    logger.info("Starting arxa version %s", __version__)
    logger.info("Arguments: provider=%s, model=%s", args.provider, args.model)
    if args.config:
        logger.info("Config file provided: %s", args.config)
        try:
            # `or {}` guards against the loader returning None (e.g. for an empty file).
            config = load_config(args.config) or {}
        except Exception as e:
            logger.error("Error loading config: %s", str(e))
            sys.exit(1)
        logger.info("Configuration loaded successfully.")
    else:
        logger.info("No configuration file specified; using defaults.")
        config = {}

    template_preview = (PROMPT_PREFIX + "\n" + PROMPT_SUFFIX).split("\n")[:10]
    logger.info("Using prompt template (first 10 lines):\n%s", "\n".join(template_preview))

    if args.server:
        try:
            import uvicorn
        except ImportError:
            logger.error("uvicorn must be installed to run the server. Install it with pip install uvicorn")
            sys.exit(1)
        logger.info("Starting FastAPI server on port 8000 with provider hard-coded to openai/o3-mini ...")
        uvicorn.run("arxa.server:app", host="0.0.0.0", port=8000, reload=True)
        return

    if not (args.aid or args.pdf):
        parser.error("You must specify either -aid or -pdf when not running in --server mode.")

    papers_dir = config.get("papers_directory", tempfile.gettempdir())
    output_dir = config.get("output_directory", os.getcwd())

    pdf_path, paper_info = _prepare_paper(args, papers_dir)

    try:
        pdf_text = extract_text_from_pdf(pdf_path)
    except Exception as e:
        logger.error("Failed to extract text from PDF: %s", str(e))
        sys.exit(1)

    provider = args.provider.lower()
    if provider == "api.arxa.ai":
        review = _remote_review(pdf_text, paper_info, args)
    else:
        llm_client = _make_llm_client(provider)
        try:
            review = generate_research_review(
                pdf_text,
                paper_info,
                provider=args.provider,
                model=args.model,
                llm_client=llm_client
            )
        except Exception as e:
            logger.error("Error generating research review: %s", str(e))
            sys.exit(1)

    _emit_review(review, args)

    if args.github:
        _clone_github_repo(review, output_dir)


if __name__ == "__main__":
    main()
