import inspect
import os
 
from typing import Any, Callable, TypeVar, ParamSpec, NamedTuple


ReturnType = TypeVar("ReturnType")  # the callable/awaitable return type
Param = ParamSpec("Param")  # the callable parameters



class NotFullyTypedError(Exception):
    """Signal that a function is missing a type annotation.

    Raised by ``generate_test_case`` when a parameter or the return value of
    the inspected function has no annotation.
    """



def generate_test_case(func: Callable[Param, ReturnType], is_pytest: bool = False) -> str: # type: ignore
    sig: inspect.Signature = inspect.signature(func)

    # check if all parameters are annotated
    for param in sig.parameters.values():
        if param.annotation is inspect.Parameter.empty:
            raise NotFullyTypedError(
                f"Parameter '{param.name}' of function '{func.__name__}' is not typed"
            )
    # Check if return type is annotated
    if sig.return_annotation is inspect.Signature.empty:
        raise NotFullyTypedError(f"Return type of function '{func.__name__}' is not typed")
    
    func_name = func  # ugly hack

    # func used as template
    def test_func():

        class Args(NamedTuple):
            ...

        class Test(NamedTuple):
            name: str
            args: Args
            want: ReturnType

        cases: tuple[Test] = (  # type: ignore
            # TODO: add cases here
        )
        
        print(f"running test cases for func_name function:\n")
        for case in cases:
            if (v := func_name(*case.args)) != case.want:  # type: ignore
                print(f"{case.name} func_name got {v} wanted {case.want}")
        print('Test complited')

    def pytested_test_func():
        import pytest

        class Args(NamedTuple):
            ...


        class Test(NamedTuple):
            name: str
            args: Args
            want: ReturnType



        @pytest.mark.parametrize(
            "test_case",
            [
                Test(name="name of test1", args=Args(), want='want_holder'),  # type: ignore
            ],
        )
        def test_kek(test_case: Test):
            assert func_name(*test_case.args) == test_case.want

    args_txt = ""
    comma_separated_args = ""
    cnt = 0
    for _, v in sig.parameters.items():
        if cnt != 0:
            args_txt += " " * 12 + str(v) + "\n"
            comma_separated_args += str(v.name) + "=, "
        else:
            args_txt = str(v) + "\n"
            cnt += 1
            comma_separated_args += str(v.name) + "=, "

    res = (
        inspect.getsource(pytested_test_func if is_pytest else test_func)
        .replace("  # type: ignore", "")
        .replace("test_func", f"test_{func.__name__}")
        .replace("test_kek", f"test_{func.__name__}")
        .replace(f"{'pytested_test_func' if is_pytest else 'test_func'}", f"{func.__name__}")
        #        .replace("ReturnType", sig.return_annotation)
        .replace("...", args_txt)
        .replace("import pytest", "")
        .replace("Args()",f"Args({comma_separated_args})")
        .replace("'want_holder'",sig.return_annotation.__name__ if hasattr(sig.return_annotation,'__name__') else type(sig.return_annotation).__name__)
        .replace("func_name", str(func.__name__))
    )
    tmp = res.splitlines(keepends=True)
    final_res = ""
    for line in tmp:
        final_res += line[4:] if len(line) > 4 else line
    return final_res

def save_to_file(func: Callable[Param, Any], file_path: str | None = None) -> None:
    """Write a new pytest-style test module for *func*.

    Args:
        func: Fully annotated function to generate the test for.
        file_path: Destination file. Defaults to ``<source stem>_test.py`` in the
            same directory as the file where *func* is defined.

    If *file_path* already exists, nothing is written (merging new cases into
    an existing file is not implemented yet).

    Raises:
        NotFullyTypedError: Propagated from ``generate_test_case``; no file is
            created in that case.
    """
    if file_path is None:
        # Swap only the extension; a ".py" elsewhere in the path is left alone.
        file_path = os.path.splitext(inspect.getfile(func))[0] + "_test.py"
    if os.path.isfile(file_path):
        # TODO: read file, count number of test cases, add new test case
        return

    # Render before opening: if generation raises, no header-only file is left
    # behind for the existence check above to skip forever after.
    body = generate_test_case(func, is_pytest=True)
    # Portable module name: basename without extension (works with any path separator).
    module_name = os.path.splitext(os.path.basename(inspect.getfile(func)))[0]

    with open(file_path, "w", encoding="utf-8") as f:
        f.write('"""This file is generated by test_case_generator"""\n')
        f.write("from typing import NamedTuple\n\n")
        f.write(f"from {module_name} import {func.__name__}\n\n")
        f.write(body)

def append_to_file(file_path: str, func: Callable[Param, Any]) -> None:
    """Append a plain (non-pytest) ``test_<name>`` function for *func* to *file_path*.

    The file is opened in append mode, so it is created if missing. The test
    source is rendered before the file is opened, so a ``NotFullyTypedError``
    leaves the file untouched instead of appending stray blank lines.
    """
    case_source = generate_test_case(func)
    with open(file_path, "a", encoding="utf-8") as f:
        f.write("\n\n" + case_source)



if __name__ == "__main__":

    def kek(foo: int, bar: str) -> str:
        """Toy fully annotated function used to demo the generator."""
        return str(foo) + bar

    # Show both flavours: plain printing test first, then the pytest one.
    for pytest_flavour in (False, True):
        print(generate_test_case(kek, pytest_flavour))
