from datetime import datetime
import json
from pathlib import Path
import traceback
import functools
import time
import os
import uuid

from loguru import logger
from .constants import (
    _thread_local, 
    _process_default_session_id
)
from .encoder import GrammarEncoder
from typing import Any, Callable, Dict, Optional, List, Tuple, TypeVar, Generic
# Helper functions for LLM call tracking: session-ID management, log-entry construction, and method wrappers

def manage_session_id(session_id):
    """Make ``session_id`` this thread's active session and return the prior one.

    Args:
        session_id: Session ID to install for the upcoming LLM call. A falsy
            value (``None``, ``""``) leaves the thread-local session untouched.

    Returns:
        Whatever ``_thread_local.session_id`` held before this call (``None``
        when it was never set), so the caller can hand it back to
        ``restore_session_id`` afterwards.
    """
    prior = getattr(_thread_local, 'session_id', None)
    if not session_id:
        # Nothing supplied: keep whatever session is already active.
        return prior
    _thread_local.session_id = session_id
    logger.debug(f"Using provided session ID for LLM call: {session_id}")
    return prior

def restore_session_id(previous_session_id):
    """Restore the thread's session ID to the value captured before an LLM call.

    The thread-local value is always reset, including back to ``None`` when no
    session was active before the call. Without that reset, a per-call
    ``session_id`` installed by ``manage_session_id`` would stay on the thread
    and be attributed to every later call made on it.

    Args:
        previous_session_id: The value returned by ``manage_session_id`` for
            this call (may be ``None``).
    """
    _thread_local.session_id = previous_session_id
    if previous_session_id:
        logger.debug(f"Restored previous session ID: {previous_session_id}")

def get_or_create_session_id():
    """Return this thread's session ID, installing the process default if needed.

    If the thread already has a session ID (anything other than ``None``), it
    is returned as-is. Otherwise, when the ``LLM_AUTO_SESSION`` environment
    variable is ``"1"`` (its default), ``_process_default_session_id`` is
    stored on the thread and returned. With auto-session disabled and no
    session set, ``None`` is returned and nothing is stored.
    """
    current = getattr(_thread_local, 'session_id', None)
    if current is not None:
        return current
    if os.getenv("LLM_AUTO_SESSION", "1") != "1":
        return None
    _thread_local.session_id = _process_default_session_id
    return _process_default_session_id

def _first_non_none_arg(mapping, *keys, default=None):
    """Return the value of the first key in ``keys`` whose value in ``mapping`` is not None.

    Unlike chaining ``mapping.get(a) or mapping.get(b)``, this keeps legitimate
    falsy values such as ``0`` / ``0.0`` (e.g. ``top_p=0``) instead of skipping
    past them to the alternate key.
    """
    for key in keys:
        value = mapping.get(key)
        if value is not None:
            return value
    return default


def create_base_log_entry(provider, request_args):
    """
    Create a base log entry with standardized structure for all providers.

    Request fields are normalized across naming conventions: for example,
    ``max_tokens`` / ``maxTokens``, ``top_p`` / ``topP`` and ``model`` /
    ``modelId`` are all read. Response, usage, duration and success fields
    are placeholders for the provider-specific logging functions to fill in.

    Args:
        provider (str): The LLM provider name (openai, anthropic, bedrock)
        request_args (dict): The provider-specific request arguments

    Returns:
        dict: A standardized log entry object with base fields
    """
    # Local import: the module only imports the ``datetime`` class.
    from datetime import timezone

    # Generate a unique log ID
    log_id = str(uuid.uuid4())

    # Prefer the thread's active session; otherwise fall back to the
    # process-wide default (regardless of LLM_AUTO_SESSION).
    session_id = getattr(_thread_local, 'session_id', None) or _process_default_session_id

    # ISO 8601 UTC timestamp without an offset suffix. This produces the same
    # string format as the deprecated ``datetime.utcnow().isoformat()``.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    # Extract and standardize request fields
    standard_request = {
        # Core required fields
        "model": request_args.get("model") or request_args.get("modelId", "unknown"),
        "messages": request_args.get("messages", []),

        # Standard parameters (all providers). A not-None lookup is used
        # so that explicit zero values are recorded rather than dropped.
        "temperature": request_args.get("temperature"),
        "max_tokens": _first_non_none_arg(request_args, "max_tokens", "maxTokens"),
        "top_p": _first_non_none_arg(request_args, "top_p", "topP"),

        # OpenAI-specific parameters (may be null for other providers)
        "frequency_penalty": request_args.get("frequency_penalty"),
        "presence_penalty": request_args.get("presence_penalty"),
        "stop": request_args.get("stop"),
        "n": request_args.get("n"),
    }

    # Tool-related fields are only included when the caller supplied them.
    for optional_key in ("tools", "tool_choice"):
        if optional_key in request_args:
            standard_request[optional_key] = request_args[optional_key]

    # Create base log entry
    log_entry = {
        "log_id": log_id,
        "session_id": session_id,
        "timestamp": timestamp,
        "provider": provider,
        "request": standard_request,

        # Empty fields to be filled in by specific provider logging functions
        "response": "",
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "token_details": {
                "cached_tokens": None,
                "audio_tokens": None,
                "reasoning_tokens": None,
                "accepted_prediction_tokens": None,
                "rejected_prediction_tokens": None
            }
        },
        "duration": 0,
        "success": False,
        "template_substitutions": {}
    }

    return log_entry

def create_generic_method_wrapper(original_method, provider, log_function):
    """Wrap a synchronous LLM API method so every call is timed and logged.

    The returned wrapper:

    * pops an optional ``session_id`` keyword argument, so it is not forwarded
      to ``original_method``, and installs it as the thread's session for the
      duration of the call;
    * calls ``original_method`` and re-raises anything it raises;
    * always calls ``log_function(provider, kwargs, response_data, duration, success)``
      afterwards. ``kwargs`` no longer contains ``session_id``. ``response_data``
      is the method's return value on success, or an ``{"error", "traceback"}``
      dict on failure. ``success`` is True only when the method returned normally;
    * logs, rather than propagates, any error raised by ``log_function``.

    Args:
        original_method: The callable to wrap.
        provider: Provider name passed through to ``log_function`` and log messages.
        log_function: Callback with the signature described above.

    Returns:
        The wrapper, tagged with ``_llm_tracker_patched = True``.
    """
    @functools.wraps(original_method)
    def wrapped_method(*args, **kwargs):

        # Extract session_id if provided directly in kwargs
        session_id = kwargs.pop('session_id', None)
        previous_session_id = manage_session_id(session_id)

        start_time = time.perf_counter()
        # Pessimistic default: only a normal return marks the call successful.
        # Interruptions (KeyboardInterrupt, SystemExit, ...) are therefore not
        # logged as successes.
        success = False
        response_data = None

        try:
            response_data = original_method(*args, **kwargs)
        except Exception as e:
            response_data = {"error": str(e), "traceback": traceback.format_exc()}
            logger.error(f"LLM call to {provider} failed: {e}")
            raise
        except BaseException as e:
            # Not an ordinary error (e.g. KeyboardInterrupt). Record it for the
            # tracker and let it propagate unchanged.
            response_data = {"error": str(e) or type(e).__name__, "traceback": traceback.format_exc()}
            raise
        else:
            success = True
            logger.debug(f"Original method for {provider} completed successfully")
            return response_data
        finally:
            duration = time.perf_counter() - start_time
            try:
                # Debug logging
                logger.debug(f"Calling log function for {provider} with success={success}")
                # Call the provider-specific log function
                log_function(provider, kwargs, response_data, duration, success)
            except Exception as e:
                # Tracking is best-effort: never mask the real call's outcome.
                logger.error(f"Error during LLM tracking for {provider}: {e}")
                logger.error(traceback.format_exc())

            restore_session_id(previous_session_id)

    # Mark the patched function so we don't patch it again
    wrapped_method._llm_tracker_patched = True
    return wrapped_method

def create_async_method_wrapper(original_async_method, provider, log_function):
    """Wrap an async LLM API method so every awaited call is timed and logged.

    The behavior mirrors the synchronous wrapper:

    * pops an optional ``session_id`` keyword argument and installs it as the
      thread's session while the call runs;
    * awaits ``original_async_method`` and re-raises anything it raises;
    * always calls ``log_function(provider, kwargs, response_data, duration, success)``
      afterwards. ``success`` is True only when the coroutine returned normally,
      so a cancelled call (``asyncio.CancelledError``, a ``BaseException``) is
      logged as a failure rather than a success;
    * logs, rather than propagates, any error raised by ``log_function``.

    Args:
        original_async_method: The coroutine function to wrap.
        provider: Provider name passed through to ``log_function`` and log messages.
        log_function: Callback with the signature described above.

    Returns:
        The async wrapper, tagged with ``_llm_tracker_patched = True``.
    """
    @functools.wraps(original_async_method)
    async def async_wrapped_method(*args, **kwargs):
        # Extract session_id if provided directly in kwargs
        session_id = kwargs.pop('session_id', None)
        previous_session_id = manage_session_id(session_id)

        start_time = time.perf_counter()
        # Pessimistic default: only a normal return marks the call successful.
        success = False
        response_data = None

        try:
            response_data = await original_async_method(*args, **kwargs)
        except Exception as e:
            response_data = {"error": str(e), "traceback": traceback.format_exc()}
            logger.error(f"Async LLM call to {provider} failed: {e}")
            raise
        except BaseException as e:
            # Typically cancellation (timeouts, task.cancel()). Record it for
            # the tracker and let it propagate unchanged.
            response_data = {"error": str(e) or type(e).__name__, "traceback": traceback.format_exc()}
            raise
        else:
            success = True
            return response_data
        finally:
            duration = time.perf_counter() - start_time
            try:
                # Call the provider-specific log function
                log_function(provider, kwargs, response_data, duration, success)
            except Exception as e:
                # Tracking is best-effort: never mask the real call's outcome.
                logger.error(f"Error during LLM tracking for {provider}: {e}")
                logger.error(traceback.format_exc())

            restore_session_id(previous_session_id)

    # Mark the patched function so we don't patch it again
    async_wrapped_method._llm_tracker_patched = True
    return async_wrapped_method

def add_stack_trace_to_log(log_entry, stack_info):
    """Attach ``stack_info`` to ``log_entry`` under ``"stack_trace"`` if enabled.

    Stack traces are recorded only while the ``LLM_TRACKER_STACK_TRACE``
    environment variable equals ``"1"``, which is its default when unset.
    The entry is modified in place and also returned for convenience.
    """
    enabled = os.environ.get("LLM_TRACKER_STACK_TRACE", "1") == "1"
    if not enabled:
        return log_entry
    log_entry["stack_trace"] = stack_info
    return log_entry