#     The Certora Prover
#     Copyright (C) 2025  Certora Ltd.
#
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation, version 3 of the License.
#
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.
#
#     You should have received a copy of the GNU General Public License
#     along with this program.  If not, see <https://www.gnu.org/licenses/>.

import re
import time
import hashlib
import sys
from pathlib import Path
import logging
from typing import Any, Dict, List, Optional, Set, Tuple
try:
    from typing import TypeAlias
except ImportError:
    from typing_extensions import TypeAlias

# Prepend the directory two levels above this file to the import search path.
# NOTE(review): presumably this is what lets the CertoraProver / Shared imports
# below resolve when the module is loaded directly — confirm against the package layout.
scripts_dir_path = Path(__file__).parent.parent.resolve()  # parent of this file's own directory
sys.path.insert(0, str(scripts_dir_path))

from CertoraProver.certoraBuildDataClasses import ContractInSDC
from CertoraProver import erc7201
from Shared import certoraUtils as Util
from CertoraProver.certoraBuildDataClasses import SDC

# (type, namespace): the name of an annotated struct type and its namespace string.
NameSpacedStorage: TypeAlias = Tuple[str, str]
# Storage entries to add to a layout; this module reads each entry's "label" and "slot" keys.
NewStorageFields = List[Dict[str, Any]]
# Type descriptions to merge into a layout's "types" section, keyed by type identifier.
NewStorageTypes  = Dict[str, Any]
# What a single extension contributes: (fields, types).
NewStorageInfo   = Tuple[NewStorageFields, NewStorageTypes]

# Logger used by every function in this module.
storage_extension_logger = logging.getLogger("storage_extension")


def erc7201_of_node(n: Dict[str, Any]) -> Optional[NameSpacedStorage]:
    """
    Extract the ERC-7201 storage-location annotation of a struct definition.

    If n is a StructDefinition node, look for a
    ``@custom:storage-location erc7201:<namespace>`` tag in the text of its
    StructuredDocumentation node.

    Args:
        n: An AST node, given as a dictionary.

    Returns:
        (type, namespace) where 'type' is the node's canonicalName and 'namespace'
        is the captured namespace string. None is returned when n is not a
        StructDefinition, has no canonicalName, has no StructuredDocumentation
        node carrying a string 'text', or when the text does not contain the tag.
    """
    if n.get("nodeType") != "StructDefinition":
        return None
    typeName = n.get("canonicalName")
    if typeName is None:
        return None

    doc = n.get("documentation")
    # Only a documentation *node* (a dict) is inspected. Any other shape, such as
    # a bare string, is treated as "no annotation" instead of crashing on .get().
    if not isinstance(doc, dict) or doc.get("nodeType") != "StructuredDocumentation":
        return None

    text = doc.get("text")
    if not isinstance(text, str):
        # re.search() raises TypeError when handed None or another non-string.
        return None

    # NOTE(review): the capture stops at the first character outside [a-zA-Z.0-9],
    # so a namespace id containing e.g. '_' or '-' is captured only up to that
    # character — confirm that this restriction is intended.
    storage_location_regex = r'@custom:storage-location erc7201:([a-zA-Z.0-9]+)'
    match = re.search(storage_location_regex, text)
    if match is None:
        return None
    return (typeName, match.group(1))


def generate_harness_name(original_file: str) -> str:
    """
    Build a harness contract name derived from the original file's name.

    The result has the shape ``<stem>_<salt>_Harness`` where ``<stem>`` is the
    (possibly truncated) file stem and ``<salt>`` is the first 8 hex digits of
    a SHA-1 over the file path concatenated with the current time. Because the
    time is part of the salt, repeated calls yield different names.

    Args:
        original_file (str): The path to the original file.
    Returns:
        str: A name of at most 31 characters, consisting only of ASCII letters,
             digits and underscores, whose first character is a letter.
    """
    salt = hashlib.sha1(f"{original_file}{time.time()}".encode()).hexdigest()[:8]
    tail = f"_{salt}_Harness"

    # Trim the stem so that stem + tail never exceeds 31 characters.
    base = Path(original_file).stem[:31 - len(tail)]

    # Replace every character outside [A-Za-z0-9_] with an underscore.
    candidate = re.sub(r'[^A-Za-z0-9_]', '_', f"{base}{tail}")

    # Swap a non-letter leading character for 'H'.
    if candidate[0].isalpha():
        return candidate
    return f"H{candidate[1:]}"


def get_next_file_index(self_file_to_sdc_name: Dict[Path, str], max_index: int = 1000) -> int:
    """
    Pick the next free numeric index given the SDC names already in use.

    Every SDC name containing an underscore is inspected; the text after its
    last underscore is parsed as an integer. The result is the largest such
    integer plus one. Names without an underscore, or whose last segment is not
    an integer, are ignored.

    Args:
        self_file_to_sdc_name (Dict[Path, str]): A dictionary mapping file paths to their SDC names.
        max_index (int): Fallback value returned when the dictionary is empty, when no
                         name carries a numeric suffix, or when anything unexpected goes
                         wrong. Default is 1000.
    Returns:
        int: The highest index found plus one, or `max_index` as described above.
    """
    try:
        # Nothing registered yet: start from the fallback value.
        if not self_file_to_sdc_name:
            return max_index

        highest: Optional[int] = None
        for name in self_file_to_sdc_name.values():
            _, underscore, last_segment = name.rpartition("_")
            if not underscore:
                # No underscore at all, so there is no index segment to read.
                continue
            try:
                candidate = int(last_segment)
            except ValueError:
                # The last segment is not a number; ignore this name.
                continue
            if highest is None or candidate > highest:
                highest = candidate

        if highest is None:
            storage_extension_logger.debug(f"No valid indices found in file_to_sdc_name, using default value of {max_index}")
            return max_index

        return highest + 1
    except Exception:
        # Best effort: any unexpected failure falls back to the default.
        storage_extension_logger.debug(f"Error determining next file index, using default value of {max_index}")
        return max_index


def write_harness_contract(tmp_file: Any,
                           harness_name: str,
                           ns_storage: Set[NameSpacedStorage]) -> Dict[str, str]:
    """
    Emit a harness contract that declares one dummy field per namespaced storage.

    The contract is written piecewise to `tmp_file`: the opening line, one
    ``<type> <variable>;`` declaration per entry of `ns_storage`, and the
    closing brace.

    Args:
        tmp_file: Object with a `write` method that receives the contract text
        harness_name: Name of the harness contract
        ns_storage: Set of (type name, namespace) declarations

    Returns:
        Dict[str, str]: Maps each declared variable name to its slot, as the
        string form of the value computed by `erc7201.erc7201` for the namespace
    """
    slots_by_variable: Dict[str, str] = {}

    tmp_file.write(f"contract {harness_name} {{\n")
    for struct_type, ns in ns_storage:
        # Dots are not allowed in an identifier, so the variable is named
        # after the namespace with every '.' turned into '_'.
        variable = ns.replace('.', '_')

        # The slot comes from the ERC-7201 helper applied to the UTF-8 bytes of the namespace.
        slots_by_variable[variable] = str(erc7201.erc7201(ns.encode('utf-8')))

        tmp_file.write(f"\t{struct_type} {variable};\n")
    tmp_file.write("}\n")

    return slots_by_variable


def extract_harness_contract_layout(sdcs: List[SDC], harness_name: str) -> Dict[str, Any]:
    """
    Look up the harness contract among the SDCs and return its storage layout.

    Args:
        sdcs: List of SDCs to search, in order; the first one whose
              `find_contract` yields a truthy result wins
        harness_name: Name of the harness contract
    Returns:
        Dict[str, Any]: The storage layout of the harness contract
    Raises:
        RuntimeError: If no SDC contains the harness contract, or if its storage
                      layout is empty or has no 'storage' key
    """
    # Lazily query each SDC and stop at the first hit.
    lookups = (sdc.find_contract(harness_name) for sdc in sdcs)
    found = next((contract for contract in lookups if contract), None)
    if not found:
        raise RuntimeError(f"Could not find harness contract {harness_name} in compiled output")

    storage_layout = found.storage_layout
    if storage_layout and 'storage' in storage_layout:
        return storage_layout
    raise RuntimeError(f"Invalid storage layout for harness contract {harness_name}")


def remapped_fields_from_layout(layout: Dict[str, Any], var_to_slot: Dict[str, str]) -> NewStorageFields:
    """
    Produce copies of the layout's storage entries, relocated to their mapped slots.

    Each entry of ``layout['storage']`` whose label appears in `var_to_slot` is
    shallow-copied and the copy's "slot" is replaced by the mapped value. Entries
    with an unmapped label are skipped with a warning. The input layout is left
    untouched.

    Args:
        layout: The storage layout of the harness contract
        var_to_slot: Mapping from variable names to their slots

    Returns:
        List[Dict[str, Any]]: The relocated entries, ordered by the integer
        value of their slot
    """
    def slot_number(entry: Dict[str, Any]) -> int:
        # Slots are stored as strings; order them numerically, not lexically.
        return int(entry["slot"])

    relocated: NewStorageFields = []
    for entry in layout['storage']:
        duplicate = entry.copy()
        label = duplicate["label"]
        if label not in var_to_slot:
            storage_extension_logger.warning(f"Skipping adding variable {label} not found in variable to slot mapping")
            continue
        duplicate["slot"] = var_to_slot[label]
        relocated.append(duplicate)

    return sorted(relocated, key=slot_number)


def get_namespace_storage_from_ast(def_node: Dict[str, Any]) -> Set[NameSpacedStorage]:
    """
    Collect the namespaced storage declarations found directly under an AST node.

    Every entry of the node's "nodes" list is passed to `erc7201_of_node`; each
    non-None result is logged and collected.

    Args:
        def_node: The AST node representing the contract definition.

    Returns:
        Set[NameSpacedStorage]: The (type, namespace) pairs that were found; empty
        when the node has no "nodes" entry or that entry is empty.
    """
    collected: Set[NameSpacedStorage] = set()

    # A missing or empty "nodes" entry simply yields nothing to scan.
    children = def_node.get("nodes") or []
    for child in children:
        annotation = erc7201_of_node(child)
        if annotation is None:
            continue
        storage_extension_logger.debug(f"Found namespaced storage: {annotation}")
        collected.add(annotation)

    return collected


def apply_extensions(target_contract: ContractInSDC,
                     extensions: Set[str],
                     to_add: Dict[str, NewStorageInfo]) -> None:
    """
    Merge the storage fields and types of each extension into the target
    contract's storage layout, in place.

    If the target has no (or an empty) storage layout, a warning is logged and
    nothing is changed. Missing 'storage' / 'types' sections are created empty
    (with a warning). Every new field is checked with `validate_new_fields`
    before it is recorded; types whose id already exists in the target layout
    are kept as they are.

    @param target_contract contract to which to apply extensions
    @param extensions set of extension contract names
    @param to_add maps extension name in extensions to (storage layouts, new types)
    """
    layout = target_contract.storage_layout
    if not layout:
        storage_extension_logger.warning(f"Target contract {target_contract.name} has no storage layout")
        return

    # Make sure both sections we are about to extend exist.
    if "storage" not in layout:
        storage_extension_logger.warning(f"Target contract {target_contract.name} storage layout does not contain 'storage' key")
        layout["storage"] = []
    if "types" not in layout:
        storage_extension_logger.warning(f"Target contract {target_contract.name} storage layout does not contain 'types' key")
        layout["types"] = {}

    # Slots and variable names the target itself occupies, captured before
    # any extension is merged.
    target_slots = set(entry["slot"] for entry in layout["storage"])
    target_vars = set(entry["label"] for entry in layout["storage"])

    # Which extension contributed each slot / variable so far; used to detect
    # two extensions claiming the same slot or name.
    slot_owner: Dict[str, str] = {}
    var_owner: Dict[str, str] = {}

    for ext in extensions:
        if ext not in to_add:
            storage_extension_logger.warning(f"Extension {ext} not found in to_add mapping")
            continue

        new_fields, new_types = to_add[ext]
        for field in new_fields:
            slot, var = field["slot"], field["label"]
            # Raises if the slot or name clashes with the target or an earlier field.
            validate_new_fields(target_contract, ext, slot, var,
                                slot_owner, var_owner, target_slots, target_vars)
            slot_owner[slot] = ext
            var_owner[var] = ext

        # All fields of this extension are valid: append them to the layout.
        layout["storage"].extend(new_fields)
        storage_extension_logger.debug(f"Added {len(new_fields)} fields from extension {ext} to contract {target_contract.name}: {[field['label'] for field in new_fields]}")

        # Add only the types the layout does not know yet.
        known_types = layout["types"]
        for type_id, type_desc in new_types.items():
            known_types.setdefault(type_id, type_desc)


def validate_new_fields(
        target_contract: ContractInSDC,
        ext: str,
        slot: str,
        var: str,
        added_slots: Dict[str, str],
        added_vars: Dict[str, str],
        target_slots: Set[str],
        target_vars: Set[str]) -> None:
    """
    Reject a new field whose slot or variable name is already taken.

    The checks run in this order and the first conflict wins: slot already
    added by an extension, variable already added by an extension, slot already
    used by the target contract, variable already declared by the target contract.

    Args:
        target_contract: The target contract to which the fields are being added
        ext: The name of the extension contract
        slot: The slot being added
        var: The variable being added
        added_slots: Maps each slot already added to the extension that added it
        added_vars: Maps each variable already added to the extension that added it
        target_slots: Set of slots in the target contract
        target_vars: Set of variables in the target contract

    Raises:
        Util.CertoraUserInputError: If any of the conflicts above is detected
    """
    if slot in added_slots:
        problem = f"Slot {slot} added to {target_contract.name} by {ext} was already added by {added_slots[slot]}"
    elif var in added_vars:
        problem = f"Var '{var}' added to {target_contract.name} by {ext} was already added by {added_vars[var]}"
    elif slot in target_slots:
        problem = f"Slot {slot} added to {target_contract.name} by {ext} is already mapped by {target_contract.name}"
    elif var in target_vars:
        problem = f"Var '{var}' added to {target_contract.name} by {ext} is already declared by {target_contract.name}"
    else:
        # No conflict: the field may be added.
        return

    raise Util.CertoraUserInputError(problem)
