"""Finite field (Galois field) implementations"""

from itertools import product, zip_longest

from .rings import IntegerModRing, Polynomial


class FiniteField:
    """
    Finite field GF(p^n)

    For prime p and positive integer n:
    - If n = 1: field is Z/pZ (integers modulo p); elements are ints.
    - If n > 1: field is constructed as Z/pZ[x] / (f(x))
      where f(x) is an irreducible polynomial of degree n; elements are
      Polynomial objects whose ``coeffs`` list is little-endian
      (``coeffs[i]`` multiplies x^i, ``coeffs[-1]`` is the leading term).

    Extension-field arithmetic is carried out on plain coefficient lists by
    the private ``_coeff_*`` helpers; only ``Polynomial(coeffs, base_ring)``
    and the ``.coeffs`` attribute of Polynomial are relied on. The field
    operations (add, mul, neg, inv, power, elements) build their results
    from canonical coefficient lists: entries in [0, p), reduced modulo
    ``self.modulus``, no zero leading coefficient, zero written as ``[0]``.
    """

    def __init__(self, p, n=1, modulus=None):
        """
        Initialize finite field GF(p^n)

        Parameters:
        -----------
        p : int
            Prime characteristic
        n : int
            Degree of extension (default 1)
        modulus : Polynomial, optional
            Irreducible polynomial of degree n for the field extension
            (ignored when n == 1). It does not need to be monic.
            If None, one will be found automatically

        Raises:
        -------
        ValueError
            If p is not prime, n < 1, or the given modulus is reducible or
            does not have degree n.
        """
        if not self._is_prime(p):
            raise ValueError("Characteristic must be prime")

        if n < 1:
            raise ValueError("Extension degree must be at least 1")

        self.p = p
        self.n = n
        self.base_ring = IntegerModRing(p)

        if n == 1:
            # Field is just Z/pZ
            self.modulus = None
        elif modulus is None:
            self.modulus = self._find_irreducible(n)
        else:
            if not self._is_irreducible(modulus):
                raise ValueError("Modulus must be irreducible")
            # A modulus of another degree d would silently build GF(p^d).
            if len(self._coeffs(modulus)) - 1 != n:
                raise ValueError(f"Modulus must have degree {n}")
            self.modulus = modulus

    def _is_prime(self, n):
        """Return True if n is a prime integer (trial division by odd numbers)."""
        # Reject non-integral values such as 7.5, which trial division
        # with the % operator would otherwise accept.
        if n < 2 or n != int(n):
            return False
        n = int(n)
        if n % 2 == 0:
            return n == 2
        # Exact integer bound (no float square root rounding issues).
        divisor = 3
        while divisor * divisor <= n:
            if n % divisor == 0:
                return False
            divisor += 2
        return True

    def _is_irreducible(self, poly):
        """
        Check if polynomial is irreducible over Z/pZ

        Uses Ben-Or's test: f of degree d >= 2 is irreducible over GF(p)
        if and only if gcd(f(x), x^(p^i) - x) = 1 for every i = 1, ..., d // 2.
        (x^(p^i) - x is the product of all monic irreducibles whose degree
        divides i, and a reducible f always has a factor of degree <= d/2.)
        The i = 1 step is exactly the "no roots in GF(p)" check.

        Parameters:
        -----------
        poly : Polynomial, list of int, or int
            Candidate polynomial (coefficients little-endian).

        Returns:
        --------
        bool
        """
        f = self._coeffs(poly)
        degree = len(f) - 1
        if degree <= 0:
            return False
        if degree == 1:
            return True

        x = [0, 1]
        frobenius = x  # x^(p^i) mod f, starting at i = 0
        for _ in range(degree // 2):
            frobenius = self._coeff_powmod(frobenius, self.p, f)
            common = self._coeff_gcd(f, self._coeff_sub(frobenius, x))
            if len(common) > 1:  # non-constant common factor => reducible
                return False
        return True

    def _find_irreducible(self, degree):
        """
        Find an irreducible polynomial of given degree over Z/pZ

        Returns a monic Polynomial; see ``_find_irreducible_coeffs`` for the
        search order. Raises ValueError if degree < 1.
        """
        return self._make(self._find_irreducible_coeffs(degree))

    def _find_irreducible_coeffs(self, degree):
        """
        Return the coefficient list of a monic irreducible polynomial of the
        given degree over Z/pZ.

        Search order:
        - degree 1: x
        - degree > 3: sparse trinomials x^n + c1*x + c0 first
        - then every monic polynomial, in lexicographic order of
          (c_{n-1}, ..., c_1, c_0). For degrees 2 and 3 this is the only
          stage (e.g. x^2 + x + 1 and x^3 + x + 1 over GF(2)).

        Raises:
        -------
        ValueError
            If degree < 1.
        """
        if degree < 1:
            raise ValueError("Degree must be at least 1")
        if degree == 1:
            return [0, 1]  # x

        if degree > 3:
            for c0 in range(1, self.p):
                for c1 in range(self.p):
                    coeffs = [c0, c1] + [0] * (degree - 2) + [1]
                    if self._is_irreducible(coeffs):
                        return coeffs

        for lower in product(range(self.p), repeat=degree):
            coeffs = list(reversed(lower)) + [1]
            # A zero constant term means x divides the polynomial.
            if coeffs[0] and self._is_irreducible(coeffs):
                return coeffs

        # Irreducible polynomials exist in every degree, so this is a safeguard.
        raise ValueError(f"Could not find irreducible polynomial of degree {degree} over GF({self.p})")

    def _test_irreducible_deg2(self, poly):
        """Test if degree 2 polynomial is irreducible"""
        # For degree 2: irreducible iff it has no roots in Z/pZ
        coeffs = self._coeffs(poly)
        return all(self._coeff_eval(coeffs, x) != 0 for x in range(self.p))

    def order(self):
        """Return number of elements in field"""
        return self.p ** self.n

    def characteristic(self):
        """Return characteristic of field"""
        return self.p

    def add(self, a, b):
        """Add elements in the field"""
        if self.n == 1:
            return self.base_ring.add(a, b)

        # Extension field: coefficient-wise sum mod p, reduced mod the modulus.
        total = self._coeff_add(self._coeffs(a), self._coeffs(b))
        return self._make(self._coeff_mod(total))

    def mul(self, a, b):
        """Multiply elements in the field"""
        if self.n == 1:
            return self.base_ring.mul(a, b)

        # Extension field: multiply, then reduce modulo the irreducible polynomial.
        prod_coeffs = self._coeff_mul(self._coeffs(a), self._coeffs(b))
        return self._make(self._coeff_mod(prod_coeffs))

    def _reduce_modulo(self, poly):
        """
        Reduce polynomial modulo the irreducible polynomial.

        Returns the remainder of polynomial division as a Polynomial; works
        for non-monic moduli as well.
        """
        return self._make(self._coeff_mod(self._coeffs(poly)))

    def zero(self):
        """Additive identity"""
        if self.n == 1:
            return 0
        return Polynomial([0], self.base_ring)

    def one(self):
        """Multiplicative identity"""
        if self.n == 1:
            return 1
        return Polynomial([1], self.base_ring)

    def neg(self, a):
        """Additive inverse"""
        if self.n == 1:
            return self.base_ring.neg(a)

        # Negate all coefficients mod p
        return self._make(self._coeff_mod([-c for c in self._coeffs(a)]))

    def inv(self, a):
        """
        Multiplicative inverse using extended Euclidean algorithm

        For GF(p): use modular inverse
        For GF(p^n): use extended Euclidean algorithm for polynomials
        """
        if self.n == 1:
            return self.base_ring.inverse(a)

        # Extended Euclidean algorithm for polynomials
        # Find s such that a*s = 1 (mod modulus)
        return self._poly_inverse(a)

    def _poly_inverse(self, poly):
        """
        Find multiplicative inverse of polynomial in the field

        Uses the Extended Euclidean Algorithm for polynomials over GF(p) and
        finds s(x) such that poly(x) * s(x) ≡ 1 (mod modulus(x)).

        Raises:
        -------
        ValueError
            "Cannot invert zero" if poly ≡ 0 (mod modulus).
        """
        return self._make(self._coeff_inverse(self._coeffs(poly)))

    def _coeff_inverse(self, coeffs):
        """
        Inverse of a coefficient list modulo ``self.modulus`` (list in, list out).

        Raises ValueError if the input is zero modulo the modulus, or if the
        gcd with the modulus is not constant (only possible for a reducible
        modulus).
        """
        modulus = self._coeffs(self.modulus)
        a = self._coeff_divmod(coeffs, modulus)[1]
        if not a:
            raise ValueError("Cannot invert zero")

        # Invariants: old_r ≡ old_s * a and r ≡ s * a (mod modulus).
        old_r, r = a, modulus
        old_s, s = [1], []
        while r:
            quotient, remainder = self._coeff_divmod(old_r, r)
            old_r, r = r, remainder
            old_s, s = s, self._coeff_sub(old_s, self._coeff_mul(quotient, s))

        # old_r = gcd(a, modulus); for a unit it is a nonzero constant c,
        # so old_s * a ≡ c and the inverse is old_s * c^-1.
        if len(old_r) != 1:
            raise ValueError("Inverse not found")
        scale = self._scalar_inv(old_r[0])
        return self._coeff_mod([c * scale for c in old_s])

    def _poly_divmod(self, a, b):
        """
        Polynomial division: returns (quotient, remainder) as Polynomials.

        Raises ValueError("Division by zero polynomial") if b ≡ 0.
        """
        quotient, remainder = self._coeff_divmod(self._coeffs(a), self._coeffs(b))
        return self._make(quotient), self._make(remainder)

    def _poly_sub(self, a, b):
        """Subtract two polynomials (coefficients mod p, not reduced mod modulus)"""
        return self._make(self._coeff_sub(self._coeffs(a), self._coeffs(b)))

    def _poly_mul_simple(self, a, b):
        """
        Multiply two polynomials (coefficients mod p, not reduced mod modulus).

        An int operand is treated as a constant polynomial.
        """
        return self._make(self._coeff_mul(self._coeffs(a), self._coeffs(b)))

    def _generate_polynomials(self, max_degree):
        """
        Generate all nonzero polynomials of degree <= max_degree.

        Yields canonical little-endian coefficient lists (no zero leading
        coefficient). Nothing is yielded when max_degree < 0.
        """
        if max_degree < 0:
            return
        for coeffs in product(range(self.p), repeat=max_degree + 1):
            trimmed = self._trim(coeffs)
            if trimmed:
                yield trimmed

    def power(self, a, k):
        """
        Compute a^k in the field by square-and-multiply.

        For negative k, a is first inverted with ``inv`` (so inverting zero
        raises whatever ``inv`` raises).
        """
        if k == 0:
            return self.one()

        if k < 0:
            a = self.inv(a)
            k = -k

        result = self.one()
        base = a

        while k > 0:
            if k % 2 == 1:
                result = self.mul(result, base)
            base = self.mul(base, base)
            k //= 2

        return result

    def elements(self):
        """
        Generate all elements of the field

        For GF(p): returns [0, 1, ..., p-1]
        For GF(p^n): returns all polynomials of degree < n, as canonical
        Polynomials (so the zero and one elements match zero() and one()).
        The constant term varies slowest and the x^(n-1) term fastest.
        """
        if self.n == 1:
            return list(range(self.p))

        return [self._make(self._trim(coeffs))
                for coeffs in product(range(self.p), repeat=self.n)]

    def is_field(self):
        """Finite fields are always fields"""
        return True

    def __repr__(self):
        if self.n == 1:
            return f"FiniteField(GF({self.p}))"
        return f"FiniteField(GF({self.p}^{self.n}))"

    # ------------------------------------------------------------------
    # Coefficient-list helpers. A coefficient list is little-endian,
    # holds ints in [0, p), has no zero leading coefficient, and the
    # zero polynomial is the empty list [].
    # ------------------------------------------------------------------

    def _trim(self, coeffs):
        """Reduce coefficients mod p and drop zero leading terms (zero -> [])."""
        out = [c % self.p for c in coeffs]
        while out and out[-1] == 0:
            out.pop()
        return out

    def _coeffs(self, poly):
        """Canonical coefficient list of a Polynomial, a coefficient sequence, or an int."""
        raw = getattr(poly, "coeffs", poly)
        if isinstance(raw, int):
            raw = [raw]
        return self._trim(raw)

    def _make(self, coeffs):
        """Wrap a canonical coefficient list in a Polynomial ([] becomes [0])."""
        return Polynomial(list(coeffs) if coeffs else [0], self.base_ring)

    def _scalar_inv(self, c):
        """Inverse of a nonzero scalar modulo the prime p (Fermat's little theorem)."""
        return pow(c % self.p, self.p - 2, self.p)

    def _coeff_eval(self, coeffs, x):
        """Evaluate a coefficient list at x modulo p using Horner's rule."""
        acc = 0
        for c in reversed(coeffs):
            acc = (acc * x + c) % self.p
        return acc

    def _coeff_add(self, a, b):
        """Sum of two coefficient lists mod p."""
        return self._trim([x + y for x, y in zip_longest(a, b, fillvalue=0)])

    def _coeff_sub(self, a, b):
        """Difference a - b of two coefficient lists mod p."""
        return self._trim([x - y for x, y in zip_longest(a, b, fillvalue=0)])

    def _coeff_mul(self, a, b):
        """Schoolbook product of two coefficient lists mod p."""
        if not a or not b:
            return []
        out = [0] * (len(a) + len(b) - 1)
        for i, x in enumerate(a):
            if x:
                for j, y in enumerate(b):
                    out[i + j] += x * y
        return self._trim(out)

    def _coeff_divmod(self, a, b):
        """
        Long division of coefficient lists over GF(p): returns (quotient, remainder).

        The divisor need not be monic. Raises ValueError if b ≡ 0.
        """
        a = self._trim(a)
        b = self._trim(b)
        if not b:
            raise ValueError("Division by zero polynomial")

        divisor_degree = len(b) - 1
        remainder = list(a)
        if len(remainder) <= divisor_degree:
            return [], remainder

        lead_inv = self._scalar_inv(b[-1])
        quotient = [0] * (len(remainder) - divisor_degree)
        # Eliminate remainder terms from the top down; quotient[shift]
        # is the coefficient of x^shift, so skipped degrees stay 0.
        for shift in range(len(quotient) - 1, -1, -1):
            coeff = remainder[shift + divisor_degree] * lead_inv % self.p
            if coeff:
                quotient[shift] = coeff
                for j, bc in enumerate(b):
                    remainder[shift + j] = (remainder[shift + j] - coeff * bc) % self.p
        return self._trim(quotient), self._trim(remainder)

    def _coeff_gcd(self, a, b):
        """Greatest common divisor of two coefficient lists (not normalised to monic)."""
        a, b = self._trim(a), self._trim(b)
        while b:
            a, b = b, self._coeff_divmod(a, b)[1]
        return a

    def _coeff_powmod(self, base, exponent, modulus):
        """base^exponent mod modulus for coefficient lists (exponent >= 0)."""
        result = [1]
        base = self._coeff_divmod(base, modulus)[1]
        while exponent > 0:
            if exponent & 1:
                result = self._coeff_divmod(self._coeff_mul(result, base), modulus)[1]
            base = self._coeff_divmod(self._coeff_mul(base, base), modulus)[1]
            exponent >>= 1
        return result

    def _coeff_mod(self, coeffs):
        """Reduce a coefficient list modulo ``self.modulus`` (extension fields only)."""
        return self._coeff_divmod(coeffs, self._coeffs(self.modulus))[1]


## Example Usage

# GF(7) - prime field
# F7 = FiniteField(7)
# print(F7.order())  # 7
# print(F7.add(5, 4))  # 2 (5 + 4 = 9 = 2 mod 7)
# print(F7.mul(3, 5))  # 1 (3 * 5 = 15 = 1 mod 7)
# print(F7.inv(3))  # 5 (multiplicative inverse)

# GF(4) = GF(2^2) - extension field
# F4 = FiniteField(2, 2)
# print(f"Order: {F4.order()}")  # 4
# print(f"Elements: {len(F4.elements())}")  # 4 elements

# GF(8) = GF(2^3)
# F8 = FiniteField(2, 3)
# print(f"Order: {F8.order()}")  # 8
