import contextlib
import dataclasses
import warnings

import alembic.migration
from alembic.autogenerate import produce_migrations, render_python_code
from sqlalchemy import MetaData
from sqlalchemy.engine import Connection

from pytest_alembic.config import duplicate_alembic_config
from pytest_alembic.plugin.error import AlembicTestFailure
from pytest_alembic.runner import MigrationContext
from pytest_alembic.tests.default import NOT_IMPLEMENTED_WARNING


def test_downgrade_leaves_no_trace(alembic_runner: MigrationContext):
    """Assert that an upgrade followed by its downgrade leaves the schema unchanged.

    For each revision, two migrations are autogenerated and compared.

    1. The diff between the state of the database before the revision's upgrade and
       the `MetaData` reflected after performing that upgrade.

       This approximates the migration alembic would have autogenerated to produce
       your upgraded database state itself.

    2. The diff between the state of the database after an upgrade -> downgrade
       cycle of that revision and the same `MetaData` used in the first comparison.

       This approximates what alembic would have autogenerated had you **actually**
       performed the downgrade on your database.

    If these two autogenerated migrations differ, the upgrade -> downgrade cycle
    leaves the database in a state that is different (enough for alembic to detect)
    from the state it would be in had the migration never been run at all.

    **note** this isn't perfect! Alembic autogeneration will not detect many
    kinds of changes! If you encounter a scenario in which this does not detect
    a change you'd expect it to, alembic already has extensive ability to
    customize and extend its autogeneration capabilities.
    """
    executor = alembic_runner.connection_executor
    executor.run_task(_test_downgrade_leaves_no_trace, alembic_runner=alembic_runner)


def _test_downgrade_leaves_no_trace(connection: Connection, alembic_runner: MigrationContext):
    """Walk the revision history, checking each upgrade -> downgrade cycle along the way.

    This is the body of `test_downgrade_leaves_no_trace`, executed through
    `connection_executor.run_task`, which presumably supplies `connection`.

    Args:
        connection: The connection every migration performed by this check runs on.
        alembic_runner: The session's runner. A copy of it, rebound to `connection`,
            is used for all migrations performed here.

    Raises:
        AlembicTestFailure: Propagated from `check_revision_cycle` when a downgrade
            does not undo its upgrade.
    """
    # Route both the alembic config and the connection executor through the one
    # `connection`, so the nested transactions opened in `check_revision_cycle` can roll
    # the migrations back midway through. `WrappingConnection.connect()` yields this
    # same connection, letting it stand in wherever an Engine-like object is expected.
    wrapped_connection = WrappingConnection(connection)

    # Mutate a duplicate rather than the session's shared alembic config.
    alembic_config = duplicate_alembic_config(alembic_runner.command_executor.alembic_config)
    alembic_config.attributes["connection"] = wrapped_connection

    alembic_runner = dataclasses.replace(
        alembic_runner,
        connection_executor=dataclasses.replace(
            alembic_runner.connection_executor,
            # Use the wrapper rather than the raw `connection`, for the same reason as the
            # config above. NOTE(review): presumably the executor calls `.connect()` on
            # this like an Engine. A bare `Connection.connect()` ("branching") is
            # deprecated in SQLAlchemy 1.4 and removed in 2.0. TODO confirm.
            connection=wrapped_connection,
        ),
        command_executor=dataclasses.replace(
            alembic_runner.command_executor,
            alembic_config=alembic_config,
        ),
    )

    history = alembic_runner.history
    # The final history entry is dropped (presumably the trailing "heads" marker;
    # confirm against the history implementation). If only one entry remains, there is
    # nothing to check.
    revisions = history.revisions[:-1]
    if len(revisions) == 1:
        return

    minimum_downgrade_revision = alembic_runner.config.minimum_downgrade_revision
    below_minimum = minimum_downgrade_revision is not None
    for revision in revisions:
        if below_minimum and revision == minimum_downgrade_revision:
            below_minimum = False

        # While below `minimum_downgrade_revision` (if set), we only upgrade without
        # checking. When it is not set, every revision is checked.
        if not below_minimum:
            # Leaves the database in its previous state, to avoid subtle upgrade -> downgrade issues.
            check_revision_cycle(alembic_runner, connection, revision)

        # Advance the database before checking the next entry.
        # NOTE(review): `check_revision_cycle` is handed `revision` as the label for its
        # starting point, yet the database is only migrated up to `revision` *after* that
        # check (assuming `migrate_up_to` stops at its argument). Confirm whether the
        # failure labels and the `minimum_downgrade_revision` boundary are off by one.
        alembic_runner.migrate_up_to(revision)


def check_revision_cycle(alembic_runner, connection, original_revision):
    """Check that upgrading one revision and then downgrading it leaves no schema trace.

    Starting from the database's current revision, this:

    1. Upgrades by one revision inside a nested transaction (savepoint), reflects the
       resulting schema into a `MetaData`, and rolls the savepoint back.
    2. In a fresh savepoint, autogenerates the diff from the current database to that
       `MetaData` (the "expected" upgrade). It then performs the upgrade -> downgrade
       cycle and autogenerates the diff again, from the post-downgrade database to the
       same `MetaData`.
    3. Fails if the two rendered diffs differ.

    The second savepoint is always rolled back, so the database ends in the state it
    started in.

    Args:
        alembic_runner: Runner used to perform the migrations. Assumed to operate on
            `connection`, so that the savepoints cover its work.
        connection: Connection the savepoints are opened on and the schema is
            reflected from.
        original_revision: Used only to label the failure output.

    Raises:
        AlembicTestFailure: If the downgrade does not precisely undo the upgrade's DDL.

    If the downgrade raises `NotImplementedError`, a warning built from
    `NOT_IMPLEMENTED_WARNING` is emitted instead and the comparison is skipped.
    """
    migration_context = alembic.migration.MigrationContext.configure(connection)

    # We first need to produce a `MetaData` which represents the state of the database
    # we're trying to get to.
    with connection.begin_nested() as trans:
        alembic_runner.migrate_up_one()
        upgrade_revision = alembic_runner.current

        upgrade_metadata = MetaData()
        upgrade_metadata.reflect(connection)

        # Having procured the target `MetaData`, we need the database back in its original state.
        trans.rollback()

    with connection.begin_nested() as trans:
        # Produce a canonically autogenerated upgrade relative to the original.
        autogenerated_upgrade = produce_migrations(migration_context, upgrade_metadata)
        rendered_autogenerated_upgrade = render_python_code(autogenerated_upgrade.upgrade_ops)

        # Now, we can perform the upgrade -> downgrade cycle!
        alembic_runner.migrate_up_one()
        try:
            alembic_runner.migrate_down_one()
        except NotImplementedError:
            # In the event of a `NotImplementedError`, we should have the same semantics
            # as if `minimum_downgrade_revision` was specified, but we'll emit a warning
            # to suggest that feature is used instead.
            warnings.warn(NOT_IMPLEMENTED_WARNING.format(revision=upgrade_revision), stacklevel=1)

        else:
            # Diff the post-downgrade database against the *same* target `MetaData`. A
            # downgrade that fully undoes the upgrade yields an identical diff. (No
            # separate reflection of the post-downgrade schema is needed; autogeneration
            # inspects the database itself.)
            autogenerated_post_downgrade = produce_migrations(migration_context, upgrade_metadata)
            rendered_autogenerated_post_downgrade = render_python_code(
                autogenerated_post_downgrade.upgrade_ops
            )

            if rendered_autogenerated_upgrade != rendered_autogenerated_post_downgrade:
                message = (
                    f"There is a difference between the pre-'{upgrade_revision}'-upgrade `MetaData`, "
                    f"and the post-'{upgrade_revision}'-downgrade `MetaData`. This implies that the "
                    "upgrade performs some set of DDL changes which the downgrade does not "
                    "precisely undo."
                )
                raise AlembicTestFailure(
                    message,
                    context=[
                        (
                            f"DDL diff for {original_revision} -> {upgrade_revision}",
                            rendered_autogenerated_upgrade,
                        ),
                        (
                            f"DDL diff after performing the {upgrade_revision} -> {original_revision} downgrade",
                            rendered_autogenerated_post_downgrade,
                        ),
                    ],
                )
        finally:
            # **This** rollback is to ensure we leave the database back in its original state for the next revision.
            trans.rollback()


@dataclasses.dataclass
class WrappingConnection:
    """Expose an existing ``Connection`` through an Engine-style ``connect()`` method.

    ``connect()`` hands back the wrapped connection itself instead of opening a new one.
    Any attribute not defined on the wrapper is looked up on the wrapped connection.
    """

    # The connection that `connect()` yields and that attribute lookups fall through to.
    connection: Connection

    def __getattr__(self, attr):
        # Only reached for names the wrapper itself lacks (i.e. not `connection` or
        # `connect`), so everything else is delegated.
        return getattr(self.connection, attr)

    @contextlib.contextmanager
    def connect(self):
        # Nothing is closed, committed, or rolled back on exit; the same connection is
        # simply yielded.
        yield self.connection
