import inspect

import pytest

from pangeo_forge_recipes.patterns import (
    CombineOp,
    ConcatDim,
    FilePattern,
    MergeDim,
    pattern_from_file_sequence,
    prune_pattern,
)


@pytest.fixture
def concat_pattern():
    """Return a 1-D FilePattern concatenating three inputs along ``time``.

    Keys 0, 1, 2 map to the names ``T_0``, ``T_1``, ``T_2``.
    """

    def format_function(time):
        return f"T_{time}"

    time_dim = ConcatDim(name="time", keys=[0, 1, 2])
    return FilePattern(format_function, time_dim)


def make_concat_merge_pattern(**kwargs):
    """Build a 2-D FilePattern that merges ``variable`` and concatenates ``time``.

    Any ``kwargs`` are forwarded unchanged to ``FilePattern``.

    Returns:
        A tuple ``(pattern, times, varnames, format_function, kwargs)``. The extra
        items let callers rebuild expected filenames and compare pattern
        attributes against the kwargs that were passed in.
    """
    times = [0, 1, 2]
    varnames = ["foo", "bar"]

    def format_function(time, variable):
        return f"T_{time}_V_{variable}"

    time_dim = ConcatDim(name="time", keys=times)
    variable_dim = MergeDim(name="variable", keys=varnames)
    # The merge dim comes first, so it is the leading axis of ``shape``.
    pattern = FilePattern(format_function, variable_dim, time_dim, **kwargs)
    return pattern, times, varnames, format_function, kwargs


@pytest.fixture
def concat_merge_pattern():
    """Provide the concat/merge pattern tuple built without extra FilePattern kwargs."""
    pattern_and_metadata = make_concat_merge_pattern()
    return pattern_and_metadata


@pytest.fixture(params=[dict(fsspec_open_kwargs={"block_size": "foo"}), dict(is_opendap=True)])
def concat_merge_pattern_with_kwargs(request):
    """Provide the concat/merge pattern tuple built with one set of extra FilePattern kwargs.

    ``request.param`` is deep-copied before use. pytest hands every test the
    *same* param object, and unpacking it with ``**`` only copies the top-level
    dict. Without the copy, an in-place update to the nested
    ``fsspec_open_kwargs`` dict would leak into later parametrized runs. Such an
    update can come from a test, or possibly from ``FilePattern`` if it stores the
    dict by reference (NOTE(review): not visible here).
    """
    from copy import deepcopy

    return make_concat_merge_pattern(**deepcopy(request.param))


@pytest.fixture(
    params=[
        {},  # no secrets injected at runtime
        {"fsspec_open_kwargs": {"username": "foo", "password": "bar"}},
        {"query_string_secrets": {"token": "foo"}},
    ]
)
def runtime_secrets(request):
    """Provide one mapping of secrets that tests inject into a FilePattern after construction."""
    secrets = request.param
    return secrets


def test_file_pattern_concat(concat_pattern):
    """A concat-only pattern has a single ``time`` dim and maps each key to its name."""
    expected_names = ["T_0", "T_1", "T_2"]

    assert concat_pattern.dims == {"time": 3}
    assert concat_pattern.shape == (3,)
    assert concat_pattern.merge_dims == []
    assert concat_pattern.concat_dims == ["time"]
    assert concat_pattern.nitems_per_input == {"time": None}
    assert concat_pattern.concat_sequence_lens == {"time": None}

    # Iterate the pattern twice on purpose: once to count, once to check names.
    assert len(list(concat_pattern)) == 3
    for file_key, name in zip(concat_pattern, expected_names):
        assert concat_pattern[file_key] == name


def test_pattern_from_file_sequence():
    """``pattern_from_file_sequence`` builds a 1-D concat pattern over the given paths."""
    paths = ["T_0", "T_1", "T_2"]
    pattern = pattern_from_file_sequence(paths, "time")

    assert pattern.dims == {"time": 3}
    assert pattern.shape == (3,)
    assert pattern.merge_dims == []
    assert pattern.concat_dims == ["time"]
    assert pattern.nitems_per_input == {"time": None}
    assert pattern.concat_sequence_lens == {"time": None}

    for file_key in pattern:
        # Each key holds one dimension index; its position selects the path.
        position = sorted(file_key)[0].index
        assert pattern[file_key] == paths[position]


@pytest.mark.parametrize("pickle", [False, True])
def test_file_pattern_concat_merge(runtime_secrets, pickle, concat_merge_pattern_with_kwargs):
    """Check dims, keys, filenames and kwargs of a concat/merge pattern.

    Runtime secrets are injected after construction, and the pattern is
    optionally round-tripped through cloudpickle before the checks run.
    """
    fp, times, varnames, format_function, kwargs = concat_merge_pattern_with_kwargs

    if "fsspec_open_kwargs" in runtime_secrets:
        if fp.is_opendap:
            pytest.skip(
                "`fsspec_open_kwargs` should never be used in combination with `is_opendap`. "
                "This is checked in `FilePattern.__init__` but not when updating attributes. "
                "Proposed changes to secret handling will obviate the need for runtime updates"
                " to attributes in favor of encryption. So for now, we'll just skip this."
            )
        fp.fsspec_open_kwargs.update(runtime_secrets["fsspec_open_kwargs"])
    if "query_string_secrets" in runtime_secrets:
        fp.query_string_secrets.update(runtime_secrets["query_string_secrets"])

    if pickle:
        # regular pickle doesn't work here because it can't pickle format_function
        from cloudpickle import dumps, loads

        fp = loads(dumps(fp))

    assert fp.dims == {"variable": 2, "time": 3}
    assert fp.shape == (2, 3)
    assert fp.merge_dims == ["variable"]
    assert fp.concat_dims == ["time"]
    assert fp.nitems_per_input == {"time": None}
    assert fp.concat_sequence_lens == {"time": None}
    assert len(list(fp)) == 6
    for key in fp:
        # Look dims up by name so a key missing a dimension fails loudly
        # instead of silently reusing values from the previous iteration.
        by_name = {k.name: k for k in key}
        assert set(by_name) == {"time", "variable"}

        time_key = by_name["time"]
        assert time_key.operation == CombineOp.CONCAT
        assert time_key.sequence_len == 3

        variable_key = by_name["variable"]
        assert variable_key.operation == CombineOp.MERGE
        assert variable_key.sequence_len == 2

        expected_fname = format_function(
            time=times[time_key.index], variable=varnames[variable_key.index]
        )
        assert fp[key] == expected_fname

    if "fsspec_open_kwargs" in kwargs:
        assert fp.is_opendap is False
        # Build the expected mapping without mutating ``kwargs``: its nested dict
        # comes from fixture state, and mutating it leaks secrets across runs.
        expected_fsspec_kwargs = {
            **kwargs["fsspec_open_kwargs"],
            **runtime_secrets.get("fsspec_open_kwargs", {}),
        }
        assert fp.fsspec_open_kwargs == expected_fsspec_kwargs
    if "query_string_secrets" in runtime_secrets:
        assert fp.query_string_secrets == runtime_secrets["query_string_secrets"]
    if "is_opendap" in kwargs:
        assert fp.is_opendap == kwargs["is_opendap"]
        assert fp.is_opendap is True
        assert fp.fsspec_open_kwargs == {}


def test_incompatible_kwargs():
    """``fsspec_open_kwargs`` together with ``is_opendap=True`` is rejected with ValueError."""
    kwargs = dict(fsspec_open_kwargs={"block_size": "foo"}, is_opendap=True)
    with pytest.raises(ValueError):
        make_concat_merge_pattern(**kwargs)


@pytest.mark.parametrize("nkeep", [1, 2])
def test_prune(nkeep, concat_merge_pattern_with_kwargs, runtime_secrets):
    """``prune_pattern`` keeps ``nkeep`` time steps and preserves every other init kwarg."""
    pattern = concat_merge_pattern_with_kwargs[0]

    # An empty ``runtime_secrets`` matches neither key below, so no outer guard is needed.
    if "fsspec_open_kwargs" in runtime_secrets:
        if pattern.is_opendap:
            pytest.skip(
                "`fsspec_open_kwargs` should never be used in combination with `is_opendap`. "
                "This is checked in `FilePattern.__init__` but not when updating attributes. "
                "Proposed changes to secret handling will obviate the need for runtime updates"
                " to attributes in favor of encryption. So for now, we'll just skip this."
            )
        pattern.fsspec_open_kwargs.update(runtime_secrets["fsspec_open_kwargs"])
    if "query_string_secrets" in runtime_secrets:
        pattern.query_string_secrets.update(runtime_secrets["query_string_secrets"])

    pruned = prune_pattern(pattern, nkeep=nkeep)
    assert pruned.dims == {"variable": 2, "time": nkeep}
    assert len(list(pruned.items())) == 2 * nkeep

    def init_kwargs(file_pattern):
        # Read back every ``__init__`` parameter as an attribute, except the
        # combine dims, which pruning is expected to change.
        params = inspect.signature(file_pattern.__init__).parameters
        return {name: getattr(file_pattern, name) for name in params if name != "combine_dims"}

    assert init_kwargs(pattern) == init_kwargs(pruned)
