import json
import urllib.parse
from enum import Enum

from pydantic import (
    BaseModel,
    PositiveInt,
)

from fli.models.airline import Airline
from fli.models.airport import Airport
from fli.models.google_flights.base import (
    DisplayMode,
    FlightSegment,
    LayoverRestrictions,
    MaxStops,
    PassengerInfo,
    PriceLimit,
    SeatType,
    SortBy,
    TripType,
)


class FlightSearchFilters(BaseModel):
    """Complete set of filters for a flight search.

    ``format`` converts the model into the nested list payload used for Google Flights
    requests; ``encode`` and ``encode_with_state_token`` turn that payload into the
    URL-encoded string that is sent with the request.
    """

    trip_type: TripType = TripType.ONE_WAY
    passenger_info: PassengerInfo
    flight_segments: list[FlightSegment]
    stops: MaxStops = MaxStops.ANY
    seat_type: SeatType = SeatType.ECONOMY
    price_limit: PriceLimit | None = None
    airlines: list[Airline] | None = None
    max_duration: PositiveInt | None = None
    layover_restrictions: LayoverRestrictions | None = None
    # Display mode: best or cheapest. Not read by format(); kept on the model for callers.
    display_mode: DisplayMode = DisplayMode.BEST
    # Sort order: price, time, duration, etc. Emitted right after the main filter block.
    sort_by: SortBy = SortBy.NONE

    def format(self) -> list:
        """Format filters into the Google Flights API structure.

        Builds one positional list per flight segment and embeds those lists in the
        outer positional payload. Positions matter: ``None`` entries are placeholders
        that keep every value at the index the API expects.

        Returns:
            list: ``[[], main_filters, sort_by, 0, 0, 2]`` where ``main_filters`` holds the
            trip type, seat type, passenger counts, price limit and formatted segments.

        """

        def serialize(obj):
            """Recursively convert enums, models and containers to plain values."""
            # Airports and airlines are sent by member name (their code), so this check
            # must run before the generic Enum branch, which would return the value.
            if isinstance(obj, (Airport, Airline)):
                return obj.name
            if isinstance(obj, Enum):
                return obj.value
            if isinstance(obj, list):
                return [serialize(item) for item in obj]
            if isinstance(obj, dict):
                return {key: serialize(value) for key, value in obj.items()}
            if isinstance(obj, BaseModel):
                # pydantic v2 deprecates ``dict()`` in favour of ``model_dump()``;
                # fall back to ``dict()`` so pydantic v1 keeps working.
                dump = getattr(obj, "model_dump", None) or obj.dict
                return serialize(dump(exclude_none=True))
            return obj

        # These values are the same for every segment, so resolve them once.
        sorted_airlines = (
            sorted(self.airlines, key=lambda x: x.value) if self.airlines else None
        )
        layover = self.layover_restrictions
        layover_duration = layover.max_duration if layover else None
        is_round_trip = self.trip_type == TripType.ROUND_TRIP

        formatted_segments = []
        for segment in self.flight_segments:
            # Airports are wrapped in an extra list level: [[[airport, value], ...]]
            departure_airports = [
                [
                    [serialize(airport[0]), serialize(airport[1])]
                    for airport in segment.departure_airport
                ]
            ]
            arrival_airports = [
                [
                    [serialize(airport[0]), serialize(airport[1])]
                    for airport in segment.arrival_airport
                ]
            ]

            # Time restrictions
            if segment.time_restrictions:
                time_filters = [
                    segment.time_restrictions.earliest_departure,
                    segment.time_restrictions.latest_departure,
                    segment.time_restrictions.earliest_arrival,
                    segment.time_restrictions.latest_arrival,
                ]
            else:
                time_filters = None

            # Airlines: a fresh list per segment so segments never share a list object.
            airlines_filters = (
                [serialize(airline) for airline in sorted_airlines] if sorted_airlines else None
            )

            # Layover airports: likewise rebuilt per segment.
            layover_airports = (
                [serialize(a) for a in layover.airports]
                if layover and layover.airports
                else None
            )

            # Selected flight (to fetch return flights); only sent for round trips.
            selected_flights = None
            if is_round_trip and segment.selected_flight is not None:
                selected_flights = [
                    [
                        serialize(leg.departure_airport.name),
                        serialize(leg.departure_datetime.strftime("%Y-%m-%d")),
                        serialize(leg.arrival_airport.name),
                        None,
                        serialize(leg.airline.name),
                        serialize(leg.flight_number),
                    ]
                    for leg in segment.selected_flight.legs
                ]

            formatted_segments.append(
                [
                    departure_airports,  # departure airport
                    arrival_airports,  # arrival airport
                    time_filters,  # time restrictions
                    serialize(self.stops.value),  # stops
                    airlines_filters,  # airlines
                    None,  # placeholder
                    segment.travel_date,  # travel date
                    [self.max_duration] if self.max_duration else None,  # max duration
                    selected_flights,  # selected flight (to fetch return flights)
                    layover_airports,  # layover airports
                    None,  # placeholder
                    None,  # placeholder
                    layover_duration,  # layover duration
                    None,  # emissions
                    3,  # constant value
                ]
            )

        # Create the main filters structure
        filters = [
            [],  # empty array at start
            [
                None,  # placeholder
                None,  # placeholder
                serialize(self.trip_type.value),
                None,  # placeholder
                [],  # empty array
                serialize(self.seat_type.value),
                [
                    self.passenger_info.adults,
                    self.passenger_info.children,
                    self.passenger_info.infants_on_lap,
                    self.passenger_info.infants_in_seat,
                ],
                [None, self.price_limit.max_price] if self.price_limit else None,
                None,  # placeholder
                None,  # placeholder
                None,  # placeholder
                None,  # placeholder
                None,  # placeholder
                formatted_segments,
                None,  # placeholder
                None,  # placeholder
                None,  # placeholder
                1,  # placeholder (hardcoded to 1)
            ],
            serialize(self.sort_by.value),
            0,  # constant
            0,  # constant
            2,  # constant
        ]

        return filters

    def _encode_request(self, enhanced_search: bool, state_data: list | None = None) -> str:
        """Build, optionally patch, and URL-encode the request payload.

        Shared implementation behind ``encode`` and ``encode_with_state_token``.

        Args:
            enhanced_search: If True, set the element at index -3 of the ``format()``
                output to 1.
            state_data: Optional extra element appended to the end of the payload.

        Returns:
            URL-encoded string of ``[null, "<payload as compact JSON>"]``.
        """
        formatted_filters = self.format()

        if enhanced_search:
            # Index -3 of the format() output is the first of the two ``0`` constants
            # that follow the sort value. This must happen before anything is appended,
            # otherwise the negative index would point at a different element.
            # NOTE(review): earlier comments called this the "second constant"; confirm
            # against a captured request which of the two zeros is intended.
            formatted_filters[-3] = 1

        if state_data is not None:
            formatted_filters.append(state_data)

        # The payload is JSON-encoded, wrapped as the second element of a list whose
        # first element is null, JSON-encoded again, and finally percent-encoded.
        formatted_json = json.dumps(formatted_filters, separators=(",", ":"))
        wrapped_filters = [None, formatted_json]
        return urllib.parse.quote(json.dumps(wrapped_filters, separators=(",", ":")))

    def encode(self, enhanced_search: bool = False) -> str:
        """URL encode the formatted filters for API request.

        Args:
            enhanced_search: If True, use extended search mode (noted by the original
                author as returning 135+ flights); if False, use basic search mode
                (noted as returning 12 flights).

        Returns:
            URL-encoded filter string for API request
        """
        return self._encode_request(enhanced_search)

    def encode_with_state_token(
        self,
        enhanced_search: bool = False,
        price_anchor: int | None = None,
        state_token: str | None = None
    ) -> str:
        """URL encode the formatted filters with state token for sorted requests.

        When ``state_token`` is None nothing is appended and the result is identical to
        ``encode(enhanced_search)``; ``price_anchor`` is ignored in that case.

        Args:
            enhanced_search: If True, use extended search mode
            price_anchor: Price anchor point for sorting (e.g., 4179). Only used
                together with ``state_token``.
            state_token: State token from initial response for pagination

        Returns:
            URL-encoded filter string with state token for sorted API request
        """
        # The state block mirrors the structure of an observed f.req payload.
        state_data = None
        if state_token is not None:
            if price_anchor is not None:
                # State block with a price anchor: [[null, price_anchor], "state_token"]
                state_data = [[None, price_anchor], state_token]
            else:
                # State token only, no price anchor: [null, "state_token"]
                state_data = [None, state_token]

        return self._encode_request(enhanced_search, state_data)
