"""Connection abstraction for interacting with test hosts."""
from __future__ import annotations

import abc
import shlex
import tempfile
import typing as t

from .io import (
    read_text_file,
)

from .config import (
    EnvironmentConfig,
)

from .util import (
    Display,
    OutputStream,
    SubprocessError,
    retry,
)

from .util_common import (
    run_command,
)

from .docker_util import (
    DockerInspect,
    docker_exec,
    docker_inspect,
    docker_network_disconnect,
)

from .ssh import (
    SshConnectionDetail,
    ssh_options_to_list,
)

from .become import (
    Become,
)


class Connection(metaclass=abc.ABCMeta):
    """Base class for connecting to a host.

    Subclasses implement `run`; the archive helpers below are built on top of it.
    """

    @abc.abstractmethod
    def run(
        self,
        command: list[str],
        capture: bool,
        interactive: bool = False,
        data: t.Optional[str] = None,
        stdin: t.Optional[t.IO[bytes]] = None,
        stdout: t.Optional[t.IO[bytes]] = None,
        output_stream: t.Optional[OutputStream] = None,
    ) -> tuple[t.Optional[str], t.Optional[str]]:
        """Run the specified command and return the result."""

    @staticmethod
    def _stream_position(stream: t.IO[bytes]) -> t.Optional[int]:
        """Return the current position of a seekable stream, or None when the stream cannot be repositioned."""
        try:
            if stream.seekable():
                return stream.tell()
        except (AttributeError, OSError, ValueError):
            # Pipes, closed files and minimal file-like objects end up here; they are used as-is, exactly as before.
            pass

        return None

    def extract_archive(
        self,
        chdir: str,
        src: t.IO[bytes],
    ) -> None:
        """Extract the given archive file stream in the specified directory.

        The stream is passed as stdin to `tar oxzf - -C <chdir>`, executed through `run` and wrapped in `retry`.
        When `src` is seekable, it is rewound to its starting position before every attempt after the first,
        so a repeated attempt receives the whole archive instead of whatever a failed attempt left unread.
        Streams which are not seekable are passed through unchanged.
        """
        tar_cmd = ['tar', 'oxzf', '-', '-C', chdir]

        start = self._stream_position(src)
        attempted = False

        def extract() -> tuple[t.Optional[str], t.Optional[str]]:
            """Run a single extraction attempt, rewinding the source first when this is a repeat attempt."""
            nonlocal attempted

            # The first attempt is left untouched so the stream is consumed exactly as the caller prepared it.
            if attempted and start is not None:
                src.seek(start)

            attempted = True

            return self.run(tar_cmd, stdin=src, capture=True)

        # NOTE(review): `retry` is defined elsewhere; presumably it may invoke the callable more than once -- confirm.
        retry(extract)

    def create_archive(
        self,
        chdir: str,
        name: str,
        dst: t.IO[bytes],
        exclude: t.Optional[str] = None,
    ) -> None:
        """Create the specified archive file stream from the specified directory, including the given name and optionally excluding the given name.

        The archive is produced by a `tar ... | gzip` shell pipeline executed through `run`, with `dst` as stdout, wrapped in `retry`.
        When `dst` is seekable, it is rewound to its starting position and truncated there before every attempt after the first,
        so output left behind by a failed attempt is not kept in front of the complete archive.
        Streams which are not seekable are passed through unchanged.
        """
        tar_cmd = ['tar', 'cf', '-', '-C', chdir]
        gzip_cmd = ['gzip']

        if exclude:
            tar_cmd += ['--exclude', exclude]

        tar_cmd.append(name)

        # Using gzip to compress the archive allows this to work on all POSIX systems we support.
        commands = [tar_cmd, gzip_cmd]

        # Each command is quoted on its own and the results are joined into one pipeline for `sh -c`.
        sh_cmd = ['sh', '-c', ' | '.join(shlex.join(command) for command in commands)]

        start = self._stream_position(dst)
        attempted = False

        def create() -> tuple[t.Optional[str], t.Optional[str]]:
            """Run a single archive attempt, discarding earlier partial output when this is a repeat attempt."""
            nonlocal attempted

            # The first attempt is left untouched; only repeat attempts reset the destination.
            if attempted and start is not None:
                dst.seek(start)
                dst.truncate()

            attempted = True

            return self.run(sh_cmd, stdout=dst, capture=True)

        # NOTE(review): `retry` is defined elsewhere; presumably it may invoke the callable more than once -- confirm.
        retry(create)


class LocalConnection(Connection):
    """Connection for localhost, implemented by delegating to `run_command`."""

    def __init__(self, args: EnvironmentConfig) -> None:
        # Kept so it can be handed to `run_command` on every invocation.
        self.args = args

    def run(
        self,
        command: list[str],
        capture: bool,
        interactive: bool = False,
        data: t.Optional[str] = None,
        stdin: t.Optional[t.IO[bytes]] = None,
        stdout: t.Optional[t.IO[bytes]] = None,
        output_stream: t.Optional[OutputStream] = None,
    ) -> tuple[t.Optional[str], t.Optional[str]]:
        """Run the specified command through `run_command` and return its result unchanged."""
        # Every argument is forwarded as given; nothing is added or rewritten for the local case.
        forwarded = dict(
            cmd=command,
            capture=capture,
            data=data,
            stdin=stdin,
            stdout=stdout,
            interactive=interactive,
            output_stream=output_stream,
        )

        return run_command(args=self.args, **forwarded)


class SshConnection(Connection):
    """Connection which runs commands on a host through the `ssh` command line client."""

    def __init__(self, args: EnvironmentConfig, settings: SshConnectionDetail, become: t.Optional[Become] = None) -> None:
        self.args = args
        self.settings = settings
        self.become = become

        # Defaults applied to every connection; entries from `settings.options` replace defaults with the same key.
        merged_options: dict[str, t.Union[int, str]] = {
            'BatchMode': 'yes',
            'StrictHostKeyChecking': 'no',
            'UserKnownHostsFile': '/dev/null',
            'ServerAliveInterval': 15,
            'ServerAliveCountMax': 4,
        }
        merged_options.update(settings.options)

        # Options shared by every invocation: the identity file first, then the merged option list.
        self.options = ['-i', settings.identity_file, *ssh_options_to_list(merged_options)]

    def run(
        self,
        command: list[str],
        capture: bool,
        interactive: bool = False,
        data: t.Optional[str] = None,
        stdin: t.Optional[t.IO[bytes]] = None,
        stdout: t.Optional[t.IO[bytes]] = None,
        output_stream: t.Optional[OutputStream] = None,
    ) -> tuple[t.Optional[str], t.Optional[str]]:
        """Run the specified command over SSH and return the result.

        The command is first passed through `become.prepare_command` when a `become` instance was given.
        A temporary log file receives the `-vvv` debug output of `ssh`, which `capture_log_details` reads
        from the error callback handed to `run_command`.
        """
        remote_command = self.become.prepare_command(command) if self.become else command

        # Work on a fresh list so the shared per-instance options are never modified.
        ssh_cmd = ['ssh', *self.options, '-q']

        if interactive:
            ssh_cmd.append('-tt')

        with tempfile.NamedTemporaryFile(prefix='distronode-test-ssh-debug-', suffix='.log') as ssh_logfile:
            # The debug log is written to the temporary file, which only exists for the duration of this block.
            ssh_cmd += ['-vvv', '-E', ssh_logfile.name]

            if self.settings.port:
                ssh_cmd += ['-p', str(self.settings.port)]

            # The destination comes next, followed by the remote command quoted into a single argument.
            ssh_cmd += [
                f'{self.settings.user}@{self.settings.host}',
                shlex.join(remote_command),
            ]

            def error_callback(ex: SubprocessError) -> None:
                """Attach details from the SSH debug log to the given error."""
                self.capture_log_details(ssh_logfile.name, ex)

            return run_command(
                args=self.args,
                cmd=ssh_cmd,
                capture=capture,
                data=data,
                stdin=stdin,
                stdout=stdout,
                interactive=interactive,
                output_stream=output_stream,
                error_callback=error_callback,
            )

    @staticmethod
    def capture_log_details(path: str, ex: SubprocessError) -> None:
        """Read the specified SSH debug log and append its relevant tail to the message of the provided exception.

        Nothing is done unless the exception status is 255. Otherwise the log lines from the last line starting
        with one of the known markers through the end of the log are appended (the whole log when no marker is found).
        """
        # NOTE(review): 255 is presumably the status reported by ssh for its own failures -- confirm against ssh(1).
        if ex.status != 255:
            return

        markers = (
            'debug1: Connection Established',
            'debug1: Authentication successful',
            'debug1: Entering interactive session',
            'debug1: Sending command',
            'debug2: PTY allocation request accepted',
            'debug2: exec request accepted',
        )

        log_lines = read_text_file(path).splitlines()

        # Scan backwards for the most recent marker line; without one, the entire log is kept.
        first_kept = 0

        for index in range(len(log_lines) - 1, -1, -1):
            if log_lines[index].startswith(markers):
                first_kept = index
                break

        message = '\n'.join(log_lines[first_kept:])
        details = '%s%s\n' % (message.strip(), Display.clear)

        ex.message += '>>> SSH Debug Output\n' + details


class DockerConnection(Connection):
    """Connection which runs commands inside a container through `docker_exec`."""

    def __init__(self, args: EnvironmentConfig, container_id: str, user: t.Optional[str] = None) -> None:
        self.args = args
        self.container_id = container_id
        # Optional user passed to `docker_exec` via `--user`; omitted from the options when not set.
        self.user: t.Optional[str] = user

    def run(
        self,
        command: list[str],
        capture: bool,
        interactive: bool = False,
        data: t.Optional[str] = None,
        stdin: t.Optional[t.IO[bytes]] = None,
        stdout: t.Optional[t.IO[bytes]] = None,
        output_stream: t.Optional[OutputStream] = None,
    ) -> tuple[t.Optional[str], t.Optional[str]]:
        """Run the specified command in the container and return the result."""
        # Each optional group contributes nothing when its condition does not apply.
        user_options = ['--user', self.user] if self.user else []
        tty_options = ['-it'] if interactive else []

        return docker_exec(
            args=self.args,
            container_id=self.container_id,
            cmd=command,
            options=user_options + tty_options,
            capture=capture,
            data=data,
            stdin=stdin,
            stdout=stdout,
            interactive=interactive,
            output_stream=output_stream,
        )

    def inspect(self) -> DockerInspect:
        """Return the DockerInspect result produced by `docker_inspect` for this container."""
        return docker_inspect(self.args, self.container_id)

    def disconnect_network(self, network: str) -> None:
        """Disconnect this container from the named network using `docker_network_disconnect`."""
        docker_network_disconnect(self.args, self.container_id, network)
