import time
import json
from rich.console import Console
from rich.panel import Panel
from prompt_toolkit import prompt
from prompt_toolkit.formatted_text import FormattedText
from loguru import logger
import byzerllm
from typing import Optional
from autocoder.common.auto_coder_lang import get_message
from autocoder.common.memory_manager import save_to_memory_file
from autocoder.common.utils_code_auto_generate import stream_chat_with_continue
from autocoder.utils.auto_coder_utils.chat_stream_out import stream_out
from autocoder.common.printer import Printer
from autocoder.common.tokens import count_string_tokens as count_tokens
from autocoder.privacy.model_filter import ModelPathFilter
from autocoder.common.result_manager import ResultManager
from autocoder.events.event_manager_singleton import get_event_manager
from autocoder.events import event_content as EventContentCreator
from autocoder.events.event_types import EventMetadata
from autocoder.common.mcp_tools.server import get_mcp_server
from autocoder.common.mcp_tools.types import McpRequest
from autocoder.utils.llms import get_llm_names
from autocoder.run_context import get_run_context, RunMode
from autocoder.common.action_yml_file_manager import ActionYmlFileManager
from autocoder.common.conversations.get_conversation_manager import get_conversation_manager


class ChatAgent:
    """Drive a single ``chat`` command end to end.

    The agent records the user's query in the current conversation (creating
    one if needed) and assembles context. Context comes from an explicit
    ``args.context`` and/or indexed project sources. The agent then dispatches
    to a backend: plain chat, RAG, MCP, commit review, or learn-from-commit.
    It streams the answer to the console, records it, and finally runs
    follow-up commands such as ``copy`` and ``save``.
    """

    def __init__(self, args, llm, raw_args):
        self.args = args
        self.llm = llm
        self.raw_args = raw_args
        self.console = Console()
        self.result_manager = ResultManager()
        self.conversation_manager = get_conversation_manager()
        # Namespace used to isolate conversations from one another.
        self.namespace = self._generate_namespace()
        # Number of (files, "read") context pairs that the last call to
        # _build_conversations prepended; None until it has run.
        self._source_count: Optional[int] = None

    def _generate_namespace(self) -> Optional[str]:
        """
        Generate the namespace used for conversation isolation.

        Returns:
            Optional[str]: currently always ``None``, i.e. the conversation
            manager's default namespace is used.
        """
        return None

    def run(self):
        """Execute the main logic of the ``chat`` command.

        Steps:
        - Normalize ``args.action`` into a command dict.
        - Handle ``new_session``.
        - Append the user query to the current conversation.
        - Build the context and obtain the response.
        - Stream the response, then record it and run follow-up commands.

        Returns:
            ``{}`` in (non-web) human-as-model mode, otherwise ``None``.
        """
        # Normalized shape:
        # {"command1": {"args": ["arg1", "arg2"], "kwargs": {"key1": "value1", "key2": "value2"}}}
        if isinstance(self.args.action, dict):
            commands_info = self.args.action
        else:
            # A plain list of command names carries no args/kwargs; None means
            # no commands at all.
            commands_info = {command: {} for command in (self.args.action or [])}

        # Start a new session if requested.
        if self.args.new_session:
            self._handle_new_session()
            # Nothing more to do when the query is empty or consists only of
            # the configured prefix/suffix.
            if not self.args.query or (self.args.query_prefix and self.args.query == self.args.query_prefix) or (self.args.query_suffix and self.args.query == self.args.query_suffix):
                return

        # Ensure there is a current conversation; create one if missing.
        current_conversation_id = self.conversation_manager.get_current_conversation_id(self.namespace)
        if not current_conversation_id:
            current_conversation_id = self.conversation_manager.create_conversation(
                name="Chat Session",
                description="Auto-coder chat session"
            )
            self.conversation_manager.set_current_conversation(current_conversation_id, self.namespace)

        # Record the user's message in the current conversation.
        self.conversation_manager.append_message_to_current(
            role="user",
            content=self.args.query,
            namespace=self.namespace
        )

        # Prefer the dedicated chat sub-model when one is configured.
        chat_llm = self.llm.get_sub_client("chat_model") or self.llm

        # Build the conversation context (files + history).
        loaded_conversations = self._build_conversations(commands_info)

        # Human-as-model mode (terminal only).
        if get_run_context().mode != RunMode.WEB and self.args.human_as_model:
            return self._handle_human_as_model(loaded_conversations, commands_info)

        start_time = time.time()
        commit_file_name = None

        # Dispatch to the backend matching the command.
        v = self._get_response(commands_info, loaded_conversations, chat_llm)
        if isinstance(v, tuple):
            # The learn backend returns (stream, commit_file_name).
            v, commit_file_name = v

        # Stream the response to the console.
        model_name = ",".join(get_llm_names(chat_llm))
        assistant_response, last_meta = stream_out(
            v,
            request_id=self.args.request_id,
            console=self.console,
            model_name=model_name,
            args=self.args
        )

        self.result_manager.append(content=assistant_response, meta={
            "action": "chat",
            "input": {
                "query": self.args.query
            }
        })

        # The learn command writes the answer back into its action YAML file.
        if "learn" in commands_info and commit_file_name:
            self._handle_learn_command(commit_file_name, assistant_response)

        if last_meta:
            self._print_stats(last_meta, start_time, model_name)

        # Record the assistant reply in the current conversation.
        self.conversation_manager.append_message_to_current(
            role="assistant",
            content=assistant_response,
            namespace=self.namespace
        )

        self._handle_post_commands(commands_info, assistant_response)

    def _handle_new_session(self):
        """Create a fresh conversation, make it current and report it."""
        conversation_id = self.conversation_manager.create_conversation(
            name="New Chat Session",
            description=f"Chat session started at {time.strftime('%Y-%m-%d %H:%M:%S')}"
        )

        self.conversation_manager.set_current_conversation(conversation_id, self.namespace)

        self.result_manager.add_result(content=get_message("new_session_started"), meta={
            "action": "chat",
            "input": {
                "query": self.args.query
            }
        })
        self.console.print(
            Panel(
                get_message("new_session_started"),
                title="Session Status",
                expand=False,
                border_style="green",
            )
        )

    def _get_current_conversation_messages(self):
        """Return the current conversation as ``[{"role", "content"}, ...]``.

        Returns an empty list when there is no current conversation.
        """
        current_conversation_id = self.conversation_manager.get_current_conversation_id(self.namespace)
        if current_conversation_id:
            messages = self.conversation_manager.get_messages(current_conversation_id)
            return [{"role": msg["role"], "content": msg["content"]} for msg in messages]
        return []

    @staticmethod
    def _wrap_files_as_context(content):
        """Return the (user, assistant) message pair that injects ``content`` as files."""
        return [
            {
                "role": "user",
                "content": f"请阅读下面的代码和文档：\n\n <files>\n{content}\n</files>",
            },
            {"role": "assistant", "content": "read"},
        ]

    def _build_conversations(self, commands_info):
        """Build the message list sent to the model.

        Prepends up to two (files, "read") context pairs:
        - one from ``args.context``: if it is JSON with a ``file_content`` key,
          that value is used; otherwise the raw string is used;
        - one from indexed project sources, unless ``no_context`` is given and
          only if the index produced non-empty text.
        The current conversation's history is appended after them.

        Side effect: stores the number of prepended pairs in
        ``self._source_count`` so human-as-model mode can split context from
        history exactly.
        """
        source_count = 0
        pre_conversations = []

        if self.args.context:
            context_content = self.args.context
            try:
                context = json.loads(self.args.context)
            except (ValueError, TypeError):
                # Not JSON (or not a string): use the raw context as-is.
                context = None
            if isinstance(context, dict) and "file_content" in context:
                context_content = context["file_content"]

            pre_conversations.extend(self._wrap_files_as_context(context_content))
            source_count += 1

        # Index the project and select relevant files.
        if "no_context" not in commands_info:
            from autocoder.index.entry import build_index_and_filter_files
            from autocoder.pyproject import PyProject
            from autocoder.tsproject import TSProject
            from autocoder.suffixproject import SuffixProject
            from autocoder.default_project import DefaultProject

            if self.args.project_type == "ts":
                pp = TSProject(args=self.args, llm=self.llm)
            elif self.args.project_type == "py":
                pp = PyProject(args=self.args, llm=self.llm)
            elif not self.args.project_type or self.args.project_type == "*":
                pp = DefaultProject(args=self.args, llm=self.llm)
            else:
                pp = SuffixProject(args=self.args, llm=self.llm, file_filter=None)
            pp.run()
            sources = pp.sources

            # Drop files the chat model is not allowed to see.
            chat_llm = self.llm.get_sub_client("chat_model") or self.llm
            model_filter = ModelPathFilter.from_model_object(chat_llm, self.args)
            filtered_sources = []
            printer = Printer()
            for source in sources:
                if model_filter.is_accessible(source.module_name):
                    filtered_sources.append(source)
                else:
                    printer.print_in_terminal("index_file_filtered",
                                              style="yellow",
                                              file_path=source.module_name,
                                              model_name=",".join(get_llm_names(chat_llm)))

            s = build_index_and_filter_files(
                llm=self.llm, args=self.args, sources=filtered_sources).to_str()

            if s:
                pre_conversations.extend(self._wrap_files_as_context(s))
                source_count += 1

        self._source_count = source_count

        current_messages = self._get_current_conversation_messages()
        return pre_conversations + current_messages

    def _handle_human_as_model(self, loaded_conversations, commands_info):
        """Let a human answer in place of the model.

        Renders a prompt to ``args.target_file`` and tries to copy it to the
        clipboard. It then reads a multi-line answer from the terminal:
        - ``eof``/``/eof`` ends input;
        - ``/clear`` discards the lines entered so far;
        - ``/break`` aborts with an exception.
        The answer is recorded as the assistant reply.

        Returns:
            An empty dict.
        """
        @byzerllm.prompt()
        def chat_with_human_as_model(
            source_codes, pre_conversations, last_conversation
        ):
            """                    
            {% if source_codes %}                    
            {{ source_codes }}
            {% endif %}                    

            {% if pre_conversations %}
            下面是我们之间的历史对话，假设我是A，你是B。
            <conversations>
            {% for conv in pre_conversations %}
            {{ "A" if conv.role == "user" else "B" }}: {{ conv.content }}
            {% endfor %}
            </conversations>
            {% endif %}


            参考上面的文件以及历史对话，回答用户的问题。
            用户的问题: {{ last_conversation.content }}
            """

        # Use the exact number of context pairs _build_conversations prepended.
        # The index pair is skipped when the index text is empty, so a
        # recomputation here could misclassify history as source code.
        source_count = getattr(self, "_source_count", None)
        if source_count is None:
            source_count = int(bool(self.args.context)) + int("no_context" not in commands_info)

        context_end = source_count * 2
        source_codes = "".join(
            conv["content"]
            for conv in loaded_conversations[:context_end]
            if conv["role"] == "user"
        )

        chat_content = chat_with_human_as_model.prompt(
            source_codes=source_codes,
            pre_conversations=loaded_conversations[context_end:-1],
            last_conversation=loaded_conversations[-1],
        )

        with open(self.args.target_file, "w", encoding="utf-8") as f:
            f.write(chat_content)

        try:
            import pyperclip
            pyperclip.copy(chat_content)
            self.console.print(
                Panel(
                    get_message("chat_human_as_model_instructions"),
                    title="Instructions",
                    border_style="blue",
                    expand=False,
                )
            )
        except Exception:
            logger.warning(get_message("clipboard_not_supported"))
            self.console.print(
                Panel(
                    get_message("human_as_model_instructions_no_clipboard"),
                    title="Instructions",
                    border_style="blue",
                    expand=False,
                )
            )

        lines = []
        while True:
            line = prompt(FormattedText([("#00FF00", "> ")]), multiline=False)
            line_lower = line.strip().lower()
            if line_lower in ["eof", "/eof"]:
                break
            elif line_lower in ["/clear"]:
                lines = []
                print("\033[2J\033[H")  # Clear terminal screen
                continue
            elif line_lower in ["/break"]:
                raise Exception("User requested to break the operation.")
            lines.append(line)

        result = "\n".join(lines)

        self.result_manager.append(content=result,
                                   meta={"action": "chat", "input": {
                                       "query": self.args.query
                                   }})

        self.conversation_manager.append_message_to_current(
            role="assistant",
            content=result,
            namespace=self.namespace
        )

        if "save" in commands_info:
            current_messages = self._get_current_conversation_messages()
            tmp_dir = save_to_memory_file(ask_conversation=current_messages,
                                          query=self.args.query,
                                          response=result)
            printer = Printer()
            printer.print_in_terminal("memory_save_success", style="green", path=tmp_dir)
        return {}

    def _resolve_query(self, command_options):
        """Return the command's first positional arg, falling back to ``args.query``."""
        pos_args = command_options.get("args") or []
        return pos_args[0] if pos_args else self.args.query

    def _get_response(self, commands_info, loaded_conversations, chat_llm):
        """Return the response stream for the requested backend.

        Returns:
        - ``rag``: a generator over the RAG stream.
        - ``mcp``: a single-chunk list ``[[result, None]]``.
        - ``review``: whatever the reviewer's ``review_commit`` returns.
        - ``learn``: whatever the learner's ``learn_from_commit`` returns;
          ``run`` unpacks it when it is a ``(stream, commit_file_name)`` tuple.
        - default: the ``stream_chat_with_continue`` stream.
        """
        if "rag" in commands_info:
            from autocoder.rag.rag_entry import RAGFactory
            self.args.enable_rag_search = True
            self.args.enable_rag_context = False
            rag = RAGFactory.get_rag(llm=chat_llm, args=self.args, path="")
            response = rag.stream_chat_oai(conversations=loaded_conversations)[0]
            return (item for item in response)

        if "mcp" in commands_info:
            mcp_server = get_mcp_server()
            final_query = self._resolve_query(commands_info["mcp"])
            response = mcp_server.send_request(
                McpRequest(
                    query=final_query,
                    model=self.args.inference_model or self.args.model,
                    product_mode=self.args.product_mode
                )
            )
            return [[response.result, None]]

        if "review" in commands_info:
            from autocoder.agent.auto_review_commit import AutoReviewCommit
            reviewer = AutoReviewCommit(llm=chat_llm, args=self.args)
            final_query = self._resolve_query(commands_info["review"])
            kwargs = commands_info["review"].get("kwargs", {})
            commit_id = kwargs.get("commit", None)
            return reviewer.review_commit(query=final_query, conversations=loaded_conversations, commit_id=commit_id)

        if "learn" in commands_info:
            from autocoder.agent.auto_learn_from_commit import AutoLearnFromCommit
            learner = AutoLearnFromCommit(llm=chat_llm, args=self.args)
            final_query = self._resolve_query(commands_info["learn"])
            return learner.learn_from_commit(query=final_query, conversations=loaded_conversations)

        # Plain chat: show an estimate of the input token count first.
        dumped_conversations = json.dumps(loaded_conversations, ensure_ascii=False)
        estimated_input_tokens = count_tokens(dumped_conversations)
        printer = Printer()
        printer.print_in_terminal("estimated_chat_input_tokens", style="yellow",
                                  estimated_input_tokens=estimated_input_tokens)

        return stream_chat_with_continue(
            llm=chat_llm,
            conversations=loaded_conversations,
            llm_config={},
            args=self.args
        )

    def _handle_learn_command(self, commit_file_name, assistant_response):
        """Store the learn answer in the action YAML's ``how_to_reproduce`` field."""
        if commit_file_name:
            action_manager = ActionYmlFileManager(self.args.source_dir)
            if not action_manager.update_yaml_field(commit_file_name, 'how_to_reproduce', assistant_response):
                printer = Printer()
                printer.print_in_terminal("yaml_save_error", style="red", yaml_file=commit_file_name)

    def _print_stats(self, last_meta, start_time, model_name):
        """Print timing/token/cost statistics and emit them as a result event.

        Prices from the model info are treated as cost per million tokens.
        """
        elapsed_time = time.time() - start_time
        printer = Printer()
        # Guard against a zero interval from a coarse clock.
        speed = last_meta.generated_tokens_count / elapsed_time if elapsed_time > 0 else 0.0

        from autocoder.utils import llms as llm_utils
        model_info = llm_utils.get_model_info(model_name, self.args.product_mode) or {}
        input_price = model_info.get("input_price", 0.0)
        output_price = model_info.get("output_price", 0.0)

        input_cost = (last_meta.input_tokens_count * input_price) / 1000000  # Convert to millions
        output_cost = (last_meta.generated_tokens_count * output_price) / 1000000  # Convert to millions

        printer.print_in_terminal("stream_out_stats",
                                  model_name=model_name,
                                  elapsed_time=elapsed_time,
                                  first_token_time=last_meta.first_token_time,
                                  input_tokens=last_meta.input_tokens_count,
                                  output_tokens=last_meta.generated_tokens_count,
                                  input_cost=round(input_cost, 4),
                                  output_cost=round(output_cost, 4),
                                  speed=round(speed, 2))

        get_event_manager(self.args.event_file).write_result(
            EventContentCreator.create_result(content=EventContentCreator.ResultTokenStatContent(
                model_name=model_name,
                elapsed_time=elapsed_time,
                input_tokens=last_meta.input_tokens_count,
                output_tokens=last_meta.generated_tokens_count,
                input_cost=round(input_cost, 4),
                output_cost=round(output_cost, 4),
                speed=round(speed, 2)
            )).to_dict(), metadata=EventMetadata(
                action_file=self.args.file
            ).to_dict())

    def _handle_post_commands(self, commands_info, assistant_response):
        """Run follow-up commands after the response has been produced.

        - ``copy``: copy the response to the clipboard. This is best-effort:
          a missing ``pyperclip`` or an unsupported clipboard only prints a
          notice.
        - ``save``: save the conversation to a memory file. If a positional
          argument is given, also write the raw response to that path.
        """
        if "copy" in commands_info:
            try:
                # Import inside the try so a missing package is handled too.
                import pyperclip
                pyperclip.copy(assistant_response)
            except Exception:
                print("pyperclip not installed or clipboard is not supported, instruction will not be copied to clipboard.")

        if "save" in commands_info:
            current_messages = self._get_current_conversation_messages()
            tmp_dir = save_to_memory_file(ask_conversation=current_messages,
                                          query=self.args.query,
                                          response=assistant_response)
            printer = Printer()
            printer.print_in_terminal("memory_save_success", style="green", path=tmp_dir)

            # List-style actions normalize to {} (no "args" key), so use .get.
            save_args = commands_info["save"].get("args") or []
            if save_args:
                with open(save_args[0], "w", encoding="utf-8") as f:
                    f.write(assistant_response)