import abc
import hashlib
import typing

from ideas import exceptions


class ETagChecksumStrategy(abc.ABC):
    """
    Interface for reproducing an S3 ETag locally from an object's data.

    Implementations are fed the object's bytes chunk by chunk through ``add_chunk_md5`` and then
    produce, through ``calculate_final_md5``, the hex string that is compared with the remote ETag.
    """

    @abc.abstractmethod
    def add_chunk_md5(self, file_data: bytes) -> None:
        """Incorporate the next chunk of object data into the checksum."""

    @abc.abstractmethod
    def calculate_final_md5(self) -> str:
        """Return the locally computed checksum as a hex string."""


class SimpleETagChecksum(ETagChecksumStrategy):
    """
    Reproduce the ETag of an object uploaded in a single PUT request.

    For such objects the ETag is the MD5 hexdigest of the whole object. Chunks are therefore fed
    into one running MD5 hash rather than hashed independently. Calling ``add_chunk_md5`` once
    with the entire file gives the same result as calling it once per chunk, in order.
    """

    def __init__(self) -> None:
        super().__init__()
        # Running MD5 over every chunk added so far, in the order they were added. Previously
        # each call replaced the digest, so only the last chunk was reflected in the result.
        self._hasher = hashlib.md5()

    @property
    def md5(self) -> str:
        """Hex MD5 digest of all chunks added so far (the MD5 of b"" if none were added)."""
        return self._hasher.hexdigest()

    def add_chunk_md5(self, file_data: bytes) -> None:
        """Feed the next chunk of the object into the running MD5 hash."""
        self._hasher.update(file_data)

    def calculate_final_md5(self) -> str:
        """Return the hex MD5 digest of the concatenation of all chunks added."""
        return self.md5


class IncrementalETagChecksum(ETagChecksumStrategy):
    """
    Reproduce the ETag of an object uploaded via an S3 multipart upload.

    It's more complicated than just using `hashlib.md5().update()` on each chunk, as AWS
    calculates it differently. It is calculated as:

        hexmd5( md5( part1 ) + md5( part2 ) + ... )-{ number of parts }

    See: https://teppen.io/2018/10/23/aws_s3_verify_etags/

    Each call to ``add_chunk_md5`` is treated as one part, so chunks must be added in part order
    and sized like the original upload's parts. ``calculate_final_md5`` returns only the hex
    digest, without the ``-{number of parts}`` suffix.

    Not safe for concurrent ``add_chunk_md5`` calls: storing a digest under ``self.index`` and
    then incrementing it are separate, non-atomic steps.
    """

    # Part position (0-based, in the order chunks were added) -> that chunk's raw md5 digest.
    md5_sums: typing.Dict[int, bytes]
    # Number of chunks added so far; also the key the next chunk's digest is stored under.
    index: int

    def __init__(self) -> None:
        super().__init__()
        # Per-instance state. A class-level `{}` default would be shared by every instance,
        # mixing digests from unrelated files and keeping them alive indefinitely.
        self.md5_sums = {}
        self.index = 0

    def add_chunk_md5(self, file_data: bytes) -> None:
        """Record the raw md5 digest of the next part."""
        self.md5_sums[self.index] = hashlib.md5(file_data).digest()
        self.index += 1

    def calculate_final_md5(self) -> str:
        """Return hexmd5 of the concatenated per-part digests (no "-N" suffix)."""
        total_digest = hashlib.md5()
        for part_number in range(self.index):
            total_digest.update(self.md5_sums[part_number])

        return total_digest.hexdigest()


class ETagChecksum:
    """
    Verify a locally computed MD5 against the ETag AWS S3 reports for an object.

    We need to calculate the md5 hash of the file to verify against what AWS S3 returns once the
    download/upload is completed. Unfortunately, there are multiple strategies needed for this:

    1. Object was uploaded in a single PUT request: the entire ETag is the MD5 hexdigest of the
       file. This happens when AWS S3 CLI uploads smaller files to S3.

    2. Object was uploaded in a multipart upload: the ETag is the MD5 hexdigest of each part's MD5
       digest concatenated together, followed by the number of parts separated by a dash.

    The strategy is chosen from the shape of the ETag given to the constructor. An ETag with a
    single ``-{parts}`` suffix (parts != 0) uses strategy 2; anything else uses strategy 1.

    References:
    - https://teppen.io/2018/06/23/aws_s3_etags/
    - https://teppen.io/2018/10/23/aws_s3_verify_etags/
    """

    strategy: ETagChecksumStrategy
    # Remote ETag with surrounding double quotes and any "-{parts}" suffix removed.
    etag: str
    # Part count parsed from the ETag suffix; 0 when the ETag has no suffix.
    parts: int

    def __init__(self, etag: str):
        """
        :param etag: ETag reported by S3, with or without the surrounding double quotes it
            carries as an HTTP entity-tag.
        :raises ValueError: if the text after the dash is not an integer.
        """
        # HTTP entity-tags are quoted strings ('"<hex>"' / '"<hex>-<parts>"'). Left in place, the
        # quotes would break the int() parse below and make verify() compare against '"<hex>"',
        # which can never equal an unquoted hexdigest.
        etag = etag.strip('"')
        try:
            etag, parts = etag.split("-")
        except ValueError:
            # No dash (or more than one): keep the ETag as-is and treat it as a single-PUT ETag.
            parts = 0

        self.etag = etag
        self.parts = int(parts)

        if self.parts == 0:
            self.strategy = SimpleETagChecksum()
        else:
            self.strategy = IncrementalETagChecksum()

    def add_chunk_md5(self, file_data: bytes) -> None:
        """Forward the next chunk of object data to the selected strategy."""
        self.strategy.add_chunk_md5(file_data)

    def calculate_final_md5(self) -> str:
        """Return the checksum computed so far by the selected strategy."""
        return self.strategy.calculate_final_md5()

    def verify(self) -> None:
        """
        Compare the locally computed checksum with the remote ETag.

        :raises exceptions.MismatchedChecksumError: if the two values differ.
        """
        local_etag = self.calculate_final_md5()
        if local_etag != self.etag:
            raise exceptions.MismatchedChecksumError(
                f"Local checksum {local_etag} != remote checksum {self.etag}"
            )
