"""
JSONL-based executor: Export tables to JSONL files, then merge them.
Simpler approach that reduces CPU usage by avoiding in-memory joins.
"""

import json
import os
import tempfile
from datetime import datetime, date, time
from decimal import Decimal
from pathlib import Path
from .planner import LogicalPlan, JoinInfo
from .operators import ScanIterator, FilterIterator, ProjectIterator, LookupJoinIterator


def _json_serializer(obj):
    """
    Fallback encoder intended for ``json.dumps(..., default=_json_serializer)``.

    Conversion rules, checked in this order:
      * datetime / date / time -> ISO-8601 string via ``isoformat()``
      * Decimal                -> float
      * any object exposing ``__dict__`` -> ``str(obj)``

    Raises:
        TypeError: if none of the rules above applies.
    """
    temporal_types = (datetime, date, time)

    if isinstance(obj, temporal_types):
        return obj.isoformat()

    if isinstance(obj, Decimal):
        return float(obj)

    # Last resort for arbitrary objects; everything else is rejected.
    if not hasattr(obj, '__dict__'):
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    return str(obj)


def execute_plan_jsonl(
    plan,
    sources,
    source_metadata,
    output_dir=None,
    debug=False
):
    """
    Execute plan by exporting tables to JSONL files, then merging.

    This function is a generator: nothing is exported until the first row is
    requested, and the temporary files are removed only after the caller has
    consumed (or closed) the generator. Cleaning up before the rows were read
    would delete the merged file out from under the projection step.

    Steps:
    1. Export root table to JSONL (with WHERE filter if applicable)
    2. Export each joined table to JSONL
    3. Merge JSONL files by reading and joining them
    4. Apply projection
    5. Yield results

    Args:
        plan: Logical execution plan (uses root_table, root_alias, where_expr,
            joins and projections)
        sources: Dictionary mapping table names to source functions
        source_metadata: Dictionary with metadata about sources (accepted for
            interface compatibility; not read by this executor)
        output_dir: Directory for temporary JSONL files (None = use temp dir).
            Unless debug is True the directory is removed when iteration ends.
        debug: Enable debug output and keep the JSONL files for inspection

    Yields:
        Result row dictionaries
    """
    import shutil

    if output_dir is None:
        output_dir = tempfile.mkdtemp(prefix="streaming_sql_")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    if debug:
        print("=" * 60)
        print("JSONL-BASED EXECUTION MODE")
        print("=" * 60)
        print(f"Temporary directory: {output_dir}\n")

    taken_stems = set()

    def _claim_path(stem):
        # A self-join references the same table more than once; without a
        # distinct file per export the later export would overwrite the
        # earlier one. The plain table name is kept whenever it is free.
        candidate, suffix = stem, 1
        while candidate in taken_stems:
            suffix += 1
            candidate = f"{stem}_{suffix}"
        taken_stems.add(candidate)
        return output_path / f"{candidate}.jsonl"

    try:
        # Step 1: Export root table to JSONL
        if debug:
            print(f"[STEP 1] Exporting root table '{plan.root_table}' to JSONL...")

        root_jsonl = _claim_path(plan.root_table)
        root_count = _export_table_to_jsonl(
            sources[plan.root_table],
            plan.root_table,
            plan.root_alias,
            root_jsonl,
            plan.where_expr,  # Apply WHERE to root table
            debug
        )

        if debug:
            print(f"  Exported {root_count:,} rows to {root_jsonl.name}\n")

        # Step 2: Export joined tables to JSONL
        join_jsonls = []
        for i, join_info in enumerate(plan.joins, 1):
            if debug:
                print(f"[STEP {i+1}] Exporting joined table '{join_info.table}' to JSONL...")

            join_jsonl = _claim_path(join_info.table)
            join_count = _export_table_to_jsonl(
                sources[join_info.table],
                join_info.table,
                join_info.alias,
                join_jsonl,
                None,  # No WHERE for joined tables (filtered during merge)
                debug
            )

            join_jsonls.append((join_info, join_jsonl))

            if debug:
                print(f"  Exported {join_count:,} rows to {join_jsonl.name}\n")

        # Step 3: Merge JSONL files
        if debug:
            print(f"[STEP {len(plan.joins) + 2}] Merging JSONL files...")
            print(f"  Merging {len(join_jsonls) + 1} files...\n")

        # Merge files sequentially: each merge output is the left input of the next
        merged_jsonl = root_jsonl
        for join_info, join_jsonl in join_jsonls:
            merged_jsonl = _merge_jsonl_files(
                merged_jsonl,
                join_jsonl,
                join_info,
                output_path,
                debug,
                memory_limit_mb=2000  # Increase to 2GB to use fast in-memory index more often
            )

        # Step 4: Apply projection and yield results
        if debug:
            print(f"[STEP {len(plan.joins) + 3}] Applying projection and yielding results...\n")
            print("-" * 60)

        # Delegate with `yield from` so the cleanup below only runs once the
        # caller has finished reading rows from the merged file.
        yield from _yield_from_jsonl_with_projection(merged_jsonl, plan.projections, debug)

    finally:
        # Cleanup: remove temporary files (kept for inspection when debugging).
        # Best effort: a failed removal must not mask the query result or error.
        if not debug:
            shutil.rmtree(output_dir, ignore_errors=True)


def _export_table_to_jsonl(
    source_fn,
    table_name,
    alias,
    output_file,
    where_expr,  # Expression for filtering
    debug
):
    """
    Dump one table into a JSONL file and return the number of rows written.

    The file is opened in write mode, so any existing file at ``output_file``
    is replaced. Rows come from a ScanIterator over ``source_fn`` (wrapped in a
    FilterIterator when ``where_expr`` is truthy) and are serialized one line
    at a time, so the table is never held in memory as a whole.

    When ``alias`` is empty the table name is used as the alias.
    """
    effective_alias = alias if alias else table_name
    written = 0

    with open(output_file, 'w', encoding='utf-8', buffering=8192) as out:
        rows = ScanIterator(source_fn, table_name, effective_alias, debug=False)
        if where_expr:
            rows = FilterIterator(rows, where_expr, debug=False)

        for written, row in enumerate(rows, start=1):
            encoded = json.dumps(row, ensure_ascii=False, default=_json_serializer)
            out.write(encoded + '\n')

            # Progress heartbeat every 10,000 rows in debug mode
            if debug and written % 10000 == 0:
                print(f"    → Exported {written:,} rows...")

    return written


def _extract_column_from_key(key):
    """Return the text after the first dot in *key* ('alias.column' -> 'column'); keys without a dot are returned unchanged."""
    _prefix, dot, column = key.partition(".")
    return column if dot else key


def _merge_jsonl_files(
    left_jsonl,
    right_jsonl,
    join_info,
    output_dir,
    debug,
    memory_limit_mb=500
):
    """
    Merge two JSONL files based on join condition.

    The right file is loaded into an in-memory index keyed by the join value,
    then the left file is streamed and each left row is combined with its
    matches. Right-side values win when both rows share a key name.

    Join semantics:
      * join_type == "INNER": left rows without a match are dropped.
      * anything else (LEFT JOIN): left rows without a match are kept and
        every column seen in the right file is filled with None, so later
        steps can still reference right-side columns.
      * rows whose join key is null never match.

    Args:
        left_jsonl: Path of the left (already merged / root) JSONL file
        right_jsonl: Path of the right JSONL file to join in
        join_info: Object providing left_key, right_key and join_type
        output_dir: Path of the directory where the merged file is created
        debug: Enable progress output
        memory_limit_mb: Threshold for the estimated size of the right file.
            The in-memory index is used either way; above the threshold a
            warning is printed in debug mode.

    Returns:
        Path to merged JSONL file.
    """
    progress_every = 10000

    # Extract join key columns
    left_key_col = join_info.left_key  # e.g., "sp.sku"
    right_key_col = join_info.right_key  # e.g., "spo.concrete_sku"
    # Bare column name ("alias.column" -> "column"), used as a fallback when
    # the right rows are not keyed by the qualified name.
    right_table_col = right_key_col.split(".", 1)[-1]
    keep_unmatched = join_info.join_type != "INNER"

    # Count right file rows first; the count feeds the size estimate below
    if debug:
        print(f"    Analyzing {right_jsonl.name}...")

    with open(right_jsonl, 'r', encoding='utf-8') as f:
        right_count = sum(1 for _ in f)

    # Estimate memory usage (rough: ~1KB per row average)
    estimated_mb = (right_count * 1024) / (1024 * 1024)

    if debug:
        print(f"    Right table: {right_count:,} rows (~{estimated_mb:.1f} MB estimated)")
        if estimated_mb < memory_limit_mb:
            print(f"    Using in-memory index (table < {memory_limit_mb} MB)...")
        else:
            # The O(n*m) streaming alternative is too slow, so the index is
            # still built in memory; the caller is only warned.
            print(f"    WARNING: Right table is large ({estimated_mb:.1f} MB)")
            print("    Using in-memory index anyway (streaming mode is too slow)")
            print("    If you run out of memory, increase memory_limit_mb parameter")

    # Build the lookup index and remember every right-side column name
    lookup_index = {}
    right_columns = {}  # insertion-ordered set of column names (values are None)

    with open(right_jsonl, 'r', encoding='utf-8') as f:
        for line in f:
            right_row = json.loads(line)

            if not right_row.keys() <= right_columns.keys():
                right_columns.update(dict.fromkeys(right_row))

            key_value = right_row.get(right_key_col)
            if key_value is None:
                key_value = right_row.get(right_table_col)
            if key_value is not None:
                lookup_index.setdefault(key_value, []).append(right_row)

    if debug:
        print(f"    Index built: {right_count:,} rows, {len(lookup_index):,} unique keys")

    # Merge using index - write incrementally
    merged_jsonl = output_dir / f"merged_{left_jsonl.stem}_{right_jsonl.stem}.jsonl"
    merged_count = 0
    next_report = progress_every

    if debug:
        print(f"    Merging into {merged_jsonl.name}...")

    with open(left_jsonl, 'r', encoding='utf-8') as left_f, \
         open(merged_jsonl, 'w', encoding='utf-8') as merged_f:

        for line in left_f:
            left_row = json.loads(line)
            right_matches = lookup_index.get(left_row.get(left_key_col))

            # Rows were parsed from JSON, so they serialize without a
            # custom `default` hook.
            if right_matches:
                for right_row in right_matches:
                    merged_row = {**left_row, **right_row}
                    merged_f.write(json.dumps(merged_row, ensure_ascii=False) + '\n')
                merged_count += len(right_matches)
            elif keep_unmatched:
                # LEFT JOIN without a match: keep the left row and null-fill
                # the right-side columns (existing left values are preserved).
                merged_row = dict(left_row)
                for column in right_columns:
                    merged_row.setdefault(column)
                merged_f.write(json.dumps(merged_row, ensure_ascii=False) + '\n')
                merged_count += 1

            # Report once per threshold crossed, not on every row that leaves
            # the counter sitting on a multiple of the interval.
            if debug and merged_count >= next_report:
                print(f"    Merged {merged_count:,} rows...")
                next_report = (merged_count // progress_every + 1) * progress_every

    if debug:
        print(f"    Merge complete: {merged_count:,} rows\n")

    return merged_jsonl
    
def _yield_from_jsonl_with_projection(
    jsonl_file,
    projections,
    debug
):
    """
    Stream a JSONL file and yield one projected dictionary per line.

    Each entry of ``projections`` contributes one output key:
      * ``exp.Alias``  -> key is the alias, value is the evaluated inner expression
      * ``exp.Column`` -> key is the bare column name, value is read from the
        row under ``table.name`` (or ``name`` when the column has no table);
        a missing key raises KeyError
      * anything else  -> key is ``str(expr)``, value is the evaluated expression

    The file is read line by line and rows are yielded as they are produced,
    so the result set is never buffered as a whole.
    """
    from .evaluator import evaluate_expression
    from sqlglot import expressions as exp

    def project(row):
        projected = {}
        for item in projections:
            if isinstance(item, exp.Alias):
                out_name = item.alias
                projected[out_name] = evaluate_expression(item.this, row)
                continue

            if not isinstance(item, exp.Column):
                projected[str(item)] = evaluate_expression(item, row)
                continue

            # Plain column reference: look it up under its (qualified) name
            if item.table:
                lookup_name = f"{item.table}.{item.name}"
            else:
                lookup_name = item.name
            if lookup_name not in row:
                raise KeyError(f"Column {lookup_name} not found")
            projected[item.name] = row[lookup_name]
        return projected

    with open(jsonl_file, 'r', encoding='utf-8') as source:
        for emitted, raw_line in enumerate(source, start=1):
            record = project(json.loads(raw_line))

            if debug and emitted % 10000 == 0:
                print(f"  Yielded {emitted:,} result rows...")

            yield record

