# USAGE:
# uvx phototoscan (--images <IMG_DIR> | --image <IMG_PATH>)
# To scan all images in a directory automatically:
# uvx phototoscan --images <IMG_DIR>

# By default, scanned images are written to a directory named 'output' next to the input image.
from enum import Enum, auto
import itertools
import math
import mimetypes
from pathlib import Path
from typing import Union

import cv2
import numpy as np
import puremagic
from pylsd.lsd import lsd
from scipy.spatial import distance as dist

from phototoscan.pyimagesearch import transform
from phototoscan.pyimagesearch import imutils

class OutputFormat(Enum):
    """Forms in which Scanner.scan can hand back its result."""
    # The result is written to disk and its path is returned as a str.
    PATH_STR = 1
    # The result is written to disk and its path is returned as a pathlib.Path.
    FILE_PATH = 2
    # The result is encoded in memory and returned as bytes.
    BYTES = 3
    # The processed image array is returned as-is.
    NP_ARRAY = 4

class ScanMode(Enum):
    """Post-processing styles that Scanner.scan can apply to the warped image."""
    # Keep the colour channels and only sharpen them.
    COLOR = 1
    # Convert to grayscale, sharpen, then apply an adaptive threshold.
    GRAYSCALE = 2

class Scanner(object):
    """An image scanner"""

    def __init__(self, MIN_QUAD_AREA_RATIO=0.25, MAX_QUAD_ANGLE_RANGE=40):
        """
        Args:
            MIN_QUAD_AREA_RATIO (float): A contour will be rejected if its corners 
                do not form a quadrilateral that covers at least MIN_QUAD_AREA_RATIO 
                of the original image. Defaults to 0.25.
            MAX_QUAD_ANGLE_RANGE (int):  A contour will also be rejected if the range 
                of its interior angles exceeds MAX_QUAD_ANGLE_RANGE. Defaults to 40.
        """
        self.MIN_QUAD_AREA_RATIO = MIN_QUAD_AREA_RATIO
        self.MAX_QUAD_ANGLE_RANGE = MAX_QUAD_ANGLE_RANGE        

    def filter_corners(self, corners, min_dist=20):
        """Filters corners that are within min_dist of others"""
        def predicate(representatives, corner):
            return all(dist.euclidean(representative, corner) >= min_dist
                       for representative in representatives)

        filtered_corners = []
        for c in corners:
            if predicate(filtered_corners, c):
                filtered_corners.append(c)
        return filtered_corners

    def angle_between_vectors_degrees(self, u, v):
        """Returns the angle between two vectors in degrees"""
        return np.degrees(
            math.acos(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))))

    def get_angle(self, p1, p2, p3):
        """
        Returns the angle between the line segment from p2 to p1 
        and the line segment from p2 to p3 in degrees
        """
        a = np.radians(np.array(p1))
        b = np.radians(np.array(p2))
        c = np.radians(np.array(p3))

        avec = a - b
        cvec = c - b

        return self.angle_between_vectors_degrees(avec, cvec)

    def angle_range(self, quad):
        """
        Returns the range between max and min interior angles of quadrilateral.
        The input quadrilateral must be a numpy array with vertices ordered clockwise
        starting with the top left vertex.
        """
        tl, tr, br, bl = quad
        ura = self.get_angle(tl[0], tr[0], br[0])
        ula = self.get_angle(bl[0], tl[0], tr[0])
        lra = self.get_angle(tr[0], br[0], bl[0])
        lla = self.get_angle(br[0], bl[0], tl[0])

        angles = [ura, ula, lra, lla]
        return np.ptp(angles)          

    def get_corners(self, img):
        """
        Returns a list of corners ((x, y) tuples) found in the input image. With proper
        pre-processing and filtering, it should output at most 10 potential corners.
        This is a utility function used by get_contour. The input image is expected
        to be rescaled and Canny filtered prior to being passed in.
        """
        lines = lsd(img)

        # massages the output from LSD
        # LSD operates on edges. One "line" has 2 edges, and so we need to combine the edges back into lines
        # 1. separate out the lines into horizontal and vertical lines.
        # 2. Draw the horizontal lines back onto a canvas, but slightly thicker and longer.
        # 3. Run connected-components on the new canvas
        # 4. Get the bounding box for each component, and the bounding box is final line.
        # 5. The ends of each line is a corner
        # 6. Repeat for vertical lines
        # 7. Draw all the final lines onto another canvas. Where the lines overlap are also corners

        corners = []
        if lines is not None:
            # squeeze() collapses a lone detected segment into a flat 1-D row, which
            # would make the per-row unpacking below fail; restore the
            # one-row-per-segment layout in that case.
            segments = lines.squeeze()
            if segments.ndim == 1 and segments.size > 0:
                segments = segments.reshape(1, -1)
            segments = segments.astype(np.int32).tolist()

            # separate out the horizontal and vertical lines, and draw them back onto separate canvases
            horizontal_lines_canvas = np.zeros(img.shape, dtype=np.uint8)
            vertical_lines_canvas = np.zeros(img.shape, dtype=np.uint8)
            for segment in segments:
                x1, y1, x2, y2, _ = segment
                if abs(x2 - x1) > abs(y2 - y1):
                    # mostly horizontal: order left-to-right and lengthen by 5px on each side
                    (x1, y1), (x2, y2) = sorted(((x1, y1), (x2, y2)), key=lambda pt: pt[0])
                    cv2.line(horizontal_lines_canvas, (max(x1 - 5, 0), y1), (min(x2 + 5, img.shape[1] - 1), y2), 255, 2)
                else:
                    # mostly vertical: order top-to-bottom and lengthen by 5px on each side
                    (x1, y1), (x2, y2) = sorted(((x1, y1), (x2, y2)), key=lambda pt: pt[1])
                    cv2.line(vertical_lines_canvas, (x1, max(y1 - 5, 0)), (x2, min(y2 + 5, img.shape[0] - 1)), 255, 2)

            # find the horizontal lines (connected-components -> bounding boxes -> final lines)
            (contours, _) = cv2.findContours(horizontal_lines_canvas, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            # only the two longest components are kept
            contours = sorted(contours, key=lambda c: cv2.arcLength(c, True), reverse=True)[:2]
            horizontal_lines_canvas = np.zeros(img.shape, dtype=np.uint8)
            for contour in contours:
                contour = contour.reshape((contour.shape[0], contour.shape[2]))
                # step 2px inwards from both extremes of the component
                min_x = int(np.amin(contour[:, 0], axis=0)) + 2
                max_x = int(np.amax(contour[:, 0], axis=0)) - 2
                left_y = int(np.average(contour[contour[:, 0] == min_x][:, 1]))
                right_y = int(np.average(contour[contour[:, 0] == max_x][:, 1]))
                # drawn with value 1 so that crossings with a vertical line sum to 2 below
                cv2.line(horizontal_lines_canvas, (min_x, left_y), (max_x, right_y), 1, 1)
                corners.append((min_x, left_y))
                corners.append((max_x, right_y))

            # find the vertical lines (connected-components -> bounding boxes -> final lines)
            (contours, _) = cv2.findContours(vertical_lines_canvas, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            contours = sorted(contours, key=lambda c: cv2.arcLength(c, True), reverse=True)[:2]
            vertical_lines_canvas = np.zeros(img.shape, dtype=np.uint8)
            for contour in contours:
                contour = contour.reshape((contour.shape[0], contour.shape[2]))
                min_y = int(np.amin(contour[:, 1], axis=0)) + 2
                max_y = int(np.amax(contour[:, 1], axis=0)) - 2
                top_x = int(np.average(contour[contour[:, 1] == min_y][:, 0]))
                bottom_x = int(np.average(contour[contour[:, 1] == max_y][:, 0]))
                cv2.line(vertical_lines_canvas, (top_x, min_y), (bottom_x, max_y), 1, 1)
                corners.append((top_x, min_y))
                corners.append((bottom_x, max_y))

            # find the corners: pixels covered by both a horizontal and a vertical final line
            corners_y, corners_x = np.where(horizontal_lines_canvas + vertical_lines_canvas == 2)
            corners += [(int(x), int(y)) for x, y in zip(corners_x, corners_y)]

        # remove corners in close proximity
        corners = self.filter_corners(corners)
        return corners

    def is_valid_contour(self, cnt, IM_WIDTH, IM_HEIGHT):
        """Returns True if the contour satisfies all requirements set at instantitation"""

        return (len(cnt) == 4 and cv2.contourArea(cnt) > IM_WIDTH * IM_HEIGHT * self.MIN_QUAD_AREA_RATIO 
            and self.angle_range(cnt) < self.MAX_QUAD_ANGLE_RANGE)


    def get_contour(self, rescaled_image):
        """
        Returns a numpy array of shape (4, 2) containing the vertices of the four corners
        of the document in the image. It considers the corners returned from get_corners()
        and uses heuristics to choose the four corners that most likely represent
        the corners of the document. Contours found directly in the edge map are
        considered as well, and the valid candidate with the largest area wins.
        If no candidate passes is_valid_contour (too small, or interior angles too
        uneven), the four corners of the whole image are returned instead.

        Args:
            rescaled_image: Image array; either BGR colour or single-channel.
        """

        # these constants are carefully chosen
        MORPH = 9
        CANNY = 84

        # shape[:2] is (height, width) whether or not a channel axis is present
        IM_HEIGHT, IM_WIDTH = rescaled_image.shape[:2]

        # convert the image to grayscale (unless it already is single-channel)
        # and blur it slightly
        if rescaled_image.ndim == 2:
            gray = rescaled_image
        else:
            gray = cv2.cvtColor(rescaled_image, cv2.COLOR_BGR2GRAY)
        gray = cv2.GaussianBlur(gray, (7, 7), 0)

        # morphological closing helps to remove potential holes between edge segments
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (MORPH, MORPH))
        dilated = cv2.morphologyEx(gray, cv2.MORPH_CLOSE, kernel)

        # find edges and mark them in the output map using the Canny algorithm
        edged = cv2.Canny(dilated, 0, CANNY)
        test_corners = self.get_corners(edged)

        approx_contours = []

        if len(test_corners) >= 4:
            quads = []

            # every choice of four detected corners is a candidate quadrilateral
            for quad in itertools.combinations(test_corners, 4):
                points = np.array(quad)
                points = transform.order_points(points)
                # wrap each point so the array has the (4, 1, 2) contour layout
                points = np.array([[p] for p in points], dtype="int32")
                quads.append(points)

            # get top five quadrilaterals by area
            quads = sorted(quads, key=cv2.contourArea, reverse=True)[:5]
            # sort candidate quadrilaterals by their angle range, which helps remove outliers
            quads = sorted(quads, key=self.angle_range)

            approx = quads[0]
            if self.is_valid_contour(approx, IM_WIDTH, IM_HEIGHT):
                approx_contours.append(approx)

        # also attempt to find contours directly from the edged image, which occasionally
        # produces better results
        (cnts, _) = cv2.findContours(edged.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cnts = sorted(cnts, key=cv2.contourArea, reverse=True)[:5]

        # loop over the contours, largest first, and keep the first valid one
        for c in cnts:
            # approximate the contour
            approx = cv2.approxPolyDP(c, 80, True)
            if self.is_valid_contour(approx, IM_WIDTH, IM_HEIGHT):
                approx_contours.append(approx)
                break

        # If we did not find any valid contours, just use the whole image
        if not approx_contours:
            TOP_RIGHT = (IM_WIDTH, 0)
            BOTTOM_RIGHT = (IM_WIDTH, IM_HEIGHT)
            BOTTOM_LEFT = (0, IM_HEIGHT)
            TOP_LEFT = (0, 0)
            screenCnt = np.array([[TOP_RIGHT], [BOTTOM_RIGHT], [BOTTOM_LEFT], [TOP_LEFT]])

        else:
            screenCnt = max(approx_contours, key=cv2.contourArea)

        return screenCnt.reshape(4, 2)

    def scan(
        self,
        img_input: Union[str, Path, bytes, bytearray, np.ndarray],
        output_format: OutputFormat,
        scan_mode: ScanMode = ScanMode.GRAYSCALE,
        output_dir: Union[str, Path, None] = None,
        output_filename: Union[str, Path, None] = None,
        ext: Union[str, None] = None
    ) -> Union[str, Path, bytes, np.ndarray]:
        mime = None

        if isinstance(img_input, (str, Path)):
            assert Path(img_input).exists(), f"File {img_input} does not exist"

        if isinstance(img_input, (bytes, bytearray, np.ndarray)) and output_format in (OutputFormat.PATH_STR, OutputFormat.FILE_PATH):
            assert output_dir is not None, f"output_dir must be provided for {output_format} output type when img_input is {type(img_input)}"
            assert output_filename is not None, f"output_filename must be provided for {output_format} output type when img_input is {type(img_input)}"

        if isinstance(img_input, np.ndarray) and output_format is OutputFormat.BYTES:
            assert ext is not None, f"ext must be provided for {output_format} output type when img_input is {type(img_input)}"

        if output_format in (OutputFormat.BYTES, OutputFormat.NP_ARRAY):
            assert output_dir is None, "output_dir must be None if output_format is BYTES or NP_ARRAY"

        output_filename_provided = False

        if output_filename is not None:
            assert ext is None, "ext must be None if output_filename is provided"
            output_filename_provided = True
            if isinstance(output_filename, str):
                output_filename = Path(output_filename)
            assert output_filename.stem != "" and output_filename.suffix != "", "output_filename must be a valid filename"
            ext = output_filename.suffix
        elif isinstance(img_input, (str, Path)):
            output_filename = Path(img_input) if isinstance(img_input, str) else img_input
            if ext is None:
                mime = puremagic.from_file(output_filename, mime=True)

        if isinstance(img_input, (bytes, bytearray)):
            raw = bytes(img_input) if isinstance(img_input, bytearray) else img_input
            if ext is None:
                mime = puremagic.from_string(raw, mime=True)

        if ext is not None and mime is None:
            ext = ext.lower()
            ext = f".{ext}" if not ext.startswith('.') else ext
            mime, _ = mimetypes.guess_type(f"dummy{ext}")
            assert mime is not None, f"Invalid ext {ext} in output_filename provided" if output_filename_provided else "Invalid ext {ext} provided"
        
        if mime is not None:
            assert mime.startswith('image/'), f"Invalid mime type {mime} for ext {ext} in output_filename provided" if output_filename_provided else f"Invalid mime type {mime} for ext {ext} provided"
            if ext is None:
                ext = mimetypes.guess_extension(mime)

        if isinstance(img_input, (str, Path)):
            image_path_str = str(img_input) if isinstance(img_input, Path) else img_input
            image = cv2.imread(image_path_str)
        elif isinstance(img_input, (bytes, bytearray)):
            arr = np.frombuffer(img_input, np.uint8)
            image = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        elif isinstance(img_input, np.ndarray):
            image = img_input

        # load the image and compute the ratio of the old height
        # to the new height, clone it, and resize it
        RESCALED_HEIGHT = 500.0

        ratio = image.shape[0] / RESCALED_HEIGHT
        orig = image.copy()
        rescaled_image = imutils.resize(image, height = int(RESCALED_HEIGHT))

        # get the contour of the document
        screenCnt = self.get_contour(rescaled_image)

        # apply the perspective transformation
        warped = transform.four_point_transform(orig, screenCnt * ratio)

        if scan_mode is ScanMode.COLOR:
            # split channels
            b, g, r = cv2.split(warped)
            sharpened_channels = []

            # sharpen each channel
            for ch in (b, g, r):
                blur = cv2.GaussianBlur(ch, (0, 0), 3)
                sharpen = cv2.addWeighted(ch, 1.5, blur, -0.5, 0)
                sharpened_channels.append(sharpen)

            # remerge channels
            processed = cv2.merge(sharpened_channels)
        elif scan_mode is ScanMode.GRAYSCALE:
            # convert the warped image to grayscale
            gray = cv2.cvtColor(warped, cv2.COLOR_BGR2GRAY)

            # sharpen image
            sharpen = cv2.GaussianBlur(gray, (0,0), 3)
            sharpen = cv2.addWeighted(gray, 1.5, sharpen, -0.5, 0)

            # apply adaptive threshold to get black and white effect
            processed = cv2.adaptiveThreshold(sharpen, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 21, 15)

        # save the transformed image
        if output_format in (OutputFormat.PATH_STR, OutputFormat.FILE_PATH):
            output_dir = Path(output_dir) if output_dir is not None else Path(img_input).parent / "output"
            output_filepath = output_dir / output_filename.name
            output_dir.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(str(output_filepath), processed)
            return output_filepath if output_format is OutputFormat.FILE_PATH else str(output_filepath)
        elif output_format is OutputFormat.BYTES:
            _, buffer = cv2.imencode(ext, processed)
            return buffer.tobytes()
        elif output_format is OutputFormat.NP_ARRAY:
            return processed