import ast
import contextlib
import importlib
import importlib.util
import inspect
import json
import os
import sys
import tempfile
from types import ModuleType
from typing import Any, Dict, List, Optional, Set, TextIO

import astor

from beam import App

BEAM_MODULE_NAME = "beam"


class AppExtractor(ast.NodeVisitor):
    """AST visitor that selects the statements a Beam app definition needs.

    While walking a module it records the names imported from the Beam
    package (``beam_imports``), the names Beam objects are assigned to or
    called with (``dependencies``), and the AST nodes that must be emitted
    (``nodes_to_keep``).  ``dump_source`` renders the kept nodes back to
    source text, ordered by their position in the original file.
    """

    def __init__(self):
        # Names bound by ``import beam`` / ``from beam import ...``.
        self.beam_imports: Set[str] = set()
        # Rendered source of each kept node; rebuilt by ``dump_source``.
        self.output_module_source: List[str] = []
        # Names that Beam objects are assigned to or receive as arguments.
        self.dependencies: Set[str] = set()
        # Nodes whose source is emitted by ``dump_source``.
        self.nodes_to_keep: Set[ast.AST] = set()
        # Nodes re-checked in ``dump_source`` once all names are known.
        self.unresolved_nodes: Set[ast.AST] = set()
        # Maps a simple assigned name to the Assign node that last bound it.
        self.name_parents: Dict[str, ast.AST] = dict()
        # Node currently being visited; becomes ``parent`` of its children.
        self.current_node: Any = None

    @staticmethod
    def _import_names(node: ast.AST) -> Set[str]:
        """Return every name an import statement can be referred to by.

        Includes the imported name itself, its ``as`` alias when present,
        and, for a plain ``import a.b``, the top-level package ``a`` (which
        is the name such a statement binds).
        """
        names: Set[str] = set()
        for alias in node.names:
            names.add(alias.name)
            if alias.asname:
                names.add(alias.asname)
            elif isinstance(node, ast.Import):
                names.add(alias.name.partition(".")[0])
        return names

    def _is_tracked(self, name: Any) -> bool:
        """Return True if ``name`` is a Beam import or a recorded dependency."""
        return name in self.beam_imports or name in self.dependencies

    def _collect_call_arguments(self, node: ast.Call) -> None:
        """Record bare-name positional and keyword arguments as dependencies."""
        for arg in node.args:
            if isinstance(arg, ast.Name):
                self.dependencies.add(arg.id)

        for keyword in node.keywords:
            if isinstance(keyword.value, ast.Name):
                self.dependencies.add(keyword.value.id)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._check_node(node)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        if node.module == BEAM_MODULE_NAME:
            # Track aliases too, so ``from beam import App as A`` makes ``A``
            # recognisable as a Beam name.
            self.beam_imports.update(self._import_names(node))
            self.nodes_to_keep.add(node)
            return

        # Store the import as a dependency to check later if it's used by Beam
        self.unresolved_nodes.add(node)

    def visit_Import(self, node: ast.Import) -> None:
        for alias in node.names:
            if alias.name == BEAM_MODULE_NAME:
                self.beam_imports.add(alias.name)
                if alias.asname:
                    # ``import beam as b``: code refers to ``b``, not ``beam``.
                    self.beam_imports.add(alias.asname)
                self.nodes_to_keep.add(node)
                self.generic_visit(node)
                return

        # Store the import as a dependency to check later if it's used by Beam
        self.unresolved_nodes.add(node)
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        if isinstance(node.value, ast.Name) and node.value.id in self.dependencies:
            self.nodes_to_keep.add(node)

        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        # A use of a dependency keeps the assignment that defined it.
        if node.id in self.dependencies and node.id in self.name_parents:
            self.nodes_to_keep.add(self.name_parents[node.id])

        self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign) -> None:
        simple_targets = [
            target.id for target in node.targets if isinstance(target, ast.Name)
        ]
        for name in simple_targets:
            self.name_parents[name] = node

        if any(self._beam_depends_on_node(child) for child in ast.walk(node)):
            self.nodes_to_keep.add(node)

            # Every simple name in a chained assignment (``a = b = App()``)
            # is bound to the Beam-related value, so track all of them.
            self.dependencies.update(simple_targets)
        else:
            self.unresolved_nodes.add(node)

        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        # ``parent`` is attached by ``visit``; tolerate nodes that lack it.
        method_of_kept_node = (
            isinstance(node.func, ast.Attribute)
            and getattr(node, "parent", None) in self.nodes_to_keep
        )
        if method_of_kept_node or self._beam_depends_on_node(node.func):
            self._collect_call_arguments(node)

        self.generic_visit(node)

    def _beam_depends_on_node(self, node: ast.AST) -> bool:
        """Return True if ``node`` refers to a Beam import or a dependency."""
        if isinstance(node, ast.Name):
            return self._is_tracked(node.id)

        elif isinstance(node, ast.Attribute):
            # The recursive check on ``node.value`` also covers the case of a
            # plain name that is a recorded dependency.
            return self._is_tracked(node.attr) or self._beam_depends_on_node(
                node.value
            )

        elif isinstance(node, ast.FunctionDef):
            return node.name in self.dependencies

        elif isinstance(node, ast.Call):
            return self._beam_depends_on_node(node.func)

        elif isinstance(node, ast.Import):
            return any(self._is_tracked(name) for name in self._import_names(node))

        elif isinstance(node, ast.ImportFrom):
            return self._is_tracked(node.module) or any(
                self._is_tracked(name) for name in self._import_names(node)
            )

        elif isinstance(node, ast.Assign):
            return any(
                self._beam_depends_on_node(target) for target in node.targets
            ) or self._beam_depends_on_node(node.value)

        return False

    def _check_node(self, node: ast.AST) -> None:
        """Keep ``node`` if anything inside it is Beam-related, else defer it."""
        if any(self._beam_depends_on_node(child) for child in ast.walk(node)):
            self.nodes_to_keep.add(node)
        else:
            self.unresolved_nodes.add(node)

    def _append_source(self, node: ast.AST, end: str = "\n") -> None:
        """Render ``node`` with astor and append it to the output buffer."""
        source = astor.to_source(node).strip() + end
        self.output_module_source.append(source)

    def visit(self, node: ast.AST) -> None:
        """Visit ``node``, tagging it with a ``parent`` attribute first."""
        if not hasattr(node, "parent"):
            node.parent = self.current_node

        self.current_node = node
        super().visit(node)
        self.current_node = node.parent

    def dump_source(self) -> str:
        """Return the source of all kept nodes, in original file order.

        Nodes that were undecided during the walk are re-checked first, now
        that the full set of dependencies is known.
        """
        for node in self.unresolved_nodes:
            if self._beam_depends_on_node(node):
                self.nodes_to_keep.add(node)

        # Sort on (line, column) so the output does not depend on set order.
        ordered_nodes = sorted(
            self.nodes_to_keep,
            key=lambda kept: (kept.lineno, getattr(kept, "col_offset", 0)),
        )

        # Start from an empty buffer so repeated calls do not duplicate output.
        self.output_module_source = []
        for node in ordered_nodes:
            self._append_source(node)

        return "\n".join(self.output_module_source)


class AppBuilder:
    """Extract a Beam app from a module, execute it, and print its config."""

    @staticmethod
    def _find_app_in_module(app_module: ModuleType) -> str:
        """Return the JSON-encoded result of calling the module's ``App``.

        The first member of ``app_module`` that is an ``App`` instance is
        called and its return value is serialized with ``json.dumps``.

        Raises:
            RuntimeError: if the module has no ``App`` instance.
        """
        for _, member_value in inspect.getmembers(app_module):
            if isinstance(member_value, App):
                return json.dumps(member_value())

        raise RuntimeError("Beam app not found")

    @staticmethod
    def _load_module(module_name: str, file_path: str) -> ModuleType:
        """Execute the file at ``file_path`` as a module, silencing stdout.

        Raises:
            ImportError: if no import spec/loader can be created for the file.
        """
        spec = importlib.util.spec_from_file_location(module_name, file_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Unable to load module from {file_path}")

        module = importlib.util.module_from_spec(spec)
        # Anything the module prints while loading is discarded; the context
        # managers restore stdout and close devnull even if execution fails.
        with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
            spec.loader.exec_module(module)

        return module

    @staticmethod
    def build(*, module_path: str, func_or_app_name: Optional[str]) -> str:
        """Build the app config for ``module_path`` and write it to stdout.

        The module is reduced to its Beam-related statements with
        ``AppExtractor``, written to a temporary file, and executed.  The
        config comes from calling ``func_or_app_name`` in that module, or,
        when it is ``None``, from the first ``App`` instance found.

        Args:
            module_path: Path of the Python file that defines the app.
            func_or_app_name: Attribute of the module to call, or ``None``.

        Returns:
            The config text that was written to stdout.

        Raises:
            FileNotFoundError: if ``module_path`` does not exist.
            AttributeError: if ``func_or_app_name`` is not in the module.
            RuntimeError: if no name was given and no ``App`` is found.
        """
        if not os.path.exists(module_path):
            raise FileNotFoundError(f"Beam module not found: {module_path}")

        # Read bytes so ``ast.parse`` applies Python's own source-encoding
        # rules instead of the platform's locale encoding.
        with open(module_path, "rb") as f:
            module_source = f.read()

        # Extract app from module source
        tree = ast.parse(module_source, filename=module_path)
        extractor = AppExtractor()
        extractor.visit(tree)
        processed_module_source = extractor.dump_source()

        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".py", encoding="utf-8"
        ) as tmp:
            tmp.write(processed_module_source)
            tmp_module_path = tmp.name

        stdout = sys.stdout
        try:
            app_module = AppBuilder._load_module(module_path, tmp_module_path)

            if func_or_app_name is None:
                config = AppBuilder._find_app_in_module(app_module)
            else:
                _callable = getattr(app_module, func_or_app_name)
                config = json.dumps(_callable())

            return AppBuilder._print_config(
                module_path, tmp_module_path, stdout, config
            )
        finally:
            # The temp file was created with delete=False, so remove it on
            # every exit path, including failures.
            os.remove(tmp_module_path)

    @staticmethod
    def _print_config(
        module_path: str, tmp_module_path: str, stdout: TextIO, config: str
    ) -> str:
        """Write ``config`` to ``stdout`` with the temp path replaced.

        Occurrences of the temporary module path, stripped of leading
        slashes, are replaced by ``module_path``.
        NOTE(review): presumably the config stores file paths without a
        leading slash -- confirm against what ``App`` emits.

        Returns:
            The text that was written.
        """
        config = str(config)
        tmp_module_path = tmp_module_path.lstrip("/")
        config = config.replace(tmp_module_path, module_path)
        stdout.write(config)
        stdout.flush()
        # Kept from the original: make the given stream the process stdout.
        sys.stdout = stdout
        return config


if __name__ == "__main__":
    # Usage:
    #     python3 -m beam.build <module_name.py>:<func_name>
    #         or
    #     python3 -m beam.build <module_name.py>:<app_name>
    #         or
    #     python3 -m beam.build <module_name.py>
    if len(sys.argv) < 2:
        sys.exit(
            "usage: python3 -m beam.build <module_name.py>[:<func_or_app_name>]"
        )

    app_handler = sys.argv[1]
    module_path = app_handler
    func_or_app_name = None

    # If the argument as a whole names an existing file, it is a bare module
    # path (this also covers paths that themselves contain ':', such as
    # Windows drive letters). Otherwise split on the LAST ':' so that only
    # the trailing component is treated as the function/app name.
    if not os.path.exists(app_handler):
        path_part, separator, name_part = app_handler.rpartition(":")
        if separator and path_part:
            module_path = path_part
            # An empty name ("module.py:") means "find the app automatically".
            func_or_app_name = name_part or None

    AppBuilder.build(module_path=module_path, func_or_app_name=func_or_app_name)
