from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
import logging
from queue import Queue
from time import time
from typing import Callable, Generator, Optional
from uuid import uuid4

import jwt
from pydantic import ValidationError
import requests

from kfinance.models.business_relationship_models import BusinessRelationshipType
from kfinance.models.capitalization_models import Capitalizations
from kfinance.models.competitor_models import CompetitorSource
from kfinance.models.date_and_period_models import Periodicity, PeriodType
from kfinance.models.id_models import IdentificationTriple
from kfinance.models.industry_models import IndustryClassification
from kfinance.models.permission_models import Permission
from kfinance.models.price_models import PriceHistory
from kfinance.models.segment_models import SegmentType
from kfinance.pydantic_models import RelationshipResponse, RelationshipResponseNoName


logger = logging.getLogger(__name__)

# kfinance/version.py is generated by setuptools-scm at build time, so it is missing from a
# plain development checkout. In that case fall back to a "dev" placeholder; the value ends
# up in the User-Agent header that the client sends with API requests.
try:
    from kfinance.version import __version__ as kfinance_version
except ImportError:
    kfinance_version = "dev"

# Connection defaults for KFinanceApiClient; each one can be overridden via its constructor.
DEFAULT_API_HOST: str = "https://kfinance.kensho.com"
DEFAULT_API_VERSION: int = 1
DEFAULT_OKTA_HOST: str = "https://kensho.okta.com"
DEFAULT_OKTA_AUTH_SERVER: str = "default"

# Worker count of the thread pool that is created lazily for batch requests when the caller
# does not supply one.
DEFAULT_MAX_WORKERS: int = 10


class KFinanceApiClient:
    """HTTP client for the KFinance API.

    Handles authentication (via a refresh token or an Okta client-id / private-key pair),
    builds the request headers, and exposes one ``fetch_*`` method per API endpoint.
    """

    def __init__(
        self,
        refresh_token: Optional[str] = None,
        client_id: Optional[str] = None,
        private_key: Optional[str] = None,
        thread_pool: Optional[ThreadPoolExecutor] = None,
        api_host: str = DEFAULT_API_HOST,
        api_version: int = DEFAULT_API_VERSION,
        okta_host: str = DEFAULT_OKTA_HOST,
        okta_auth_server: str = DEFAULT_OKTA_AUTH_SERVER,
    ):
        """Configuration of KFinance Client.

        If a refresh token is given it is used for authentication; otherwise both client_id and
        private_key must be provided.

        :param refresh_token: users refresh token
        :type refresh_token: str, Optional
        :param client_id: users client id will be provided by support@kensho.com
        :type client_id: str, Optional
        :param private_key: users private key that corresponds to the registered public sent to support@kensho.com
        :type private_key: str, Optional
        :param thread_pool: the thread pool used to execute batch requests. The number of concurrent requests is
        capped at 10. If no thread pool is provided, a thread pool with 10 max workers will be created when batch
        requests are made.
        :type thread_pool: ThreadPoolExecutor, Optional
        :param api_host: the api host URL
        :type api_host: str
        :param api_version: the api version number
        :type api_version: int
        :param okta_host: the okta host URL
        :type okta_host: str
        :param okta_auth_server: the okta route for authentication
        :type okta_auth_server: str
        :raises RuntimeError: if neither a refresh token nor a client_id/private_key pair is provided.
        """
        if refresh_token is not None:
            self.refresh_token = refresh_token
            self._access_token_refresh_func: Callable[..., str] = (
                self._get_access_token_via_refresh_token
            )
        elif client_id is not None and private_key is not None:
            self.client_id = client_id
            self.private_key = private_key
            self._access_token_refresh_func = self._get_access_token_via_keypair
        else:
            raise RuntimeError("No credentials for any authentication strategy were provided")
        self.api_host = api_host
        self.api_version = api_version
        self.okta_host = okta_host
        self.okta_auth_server = okta_auth_server
        self._thread_pool = thread_pool
        self.url_base = f"{self.api_host}/api/v{self.api_version}/"
        # Expiry (the JWT "exp" claim) of the cached access token; 0 forces a fetch on first use.
        self._access_token_expiry = 0
        self._access_token: str | None = None
        # Appended to the User-Agent header to identify which client surface sent the request.
        self.user_agent_source = "object_oriented"
        # Only set while inside batch_request_header.
        self._batch_id: str | None = None
        self._batch_size: str | None = None
        self._user_permissions: set[Permission] | None = None
        # Only set while inside endpoint_tracker.
        self._endpoint_tracker_queue: Queue[str] | None = None

    @contextmanager
    def batch_request_header(self, batch_size: int) -> Generator[None, None, None]:
        """Set batch id and batch size for batch request request headers.

        While the context is active, every request sent by this client carries a
        ``Kfinance-Batch-Id`` header (a fresh random UUID4) and a ``Kfinance-Batch-Size``
        header. On exit, the previously active values are restored, so a nested batch does not
        clear the headers of an enclosing batch.

        :param batch_size: the number of requests in the batch.
        :type batch_size: int
        """
        previous_batch_id, previous_batch_size = self._batch_id, self._batch_size
        self._batch_id = str(uuid4())
        self._batch_size = str(batch_size)

        try:
            yield
        finally:
            self._batch_id = previous_batch_id
            self._batch_size = previous_batch_size

    @property
    def thread_pool(self) -> ThreadPoolExecutor:
        """Returns the thread pool used to execute batch requests.

        If the thread pool is not set, a thread pool with 10 max workers will be created
        and returned.
        """

        if self._thread_pool is None:
            self._thread_pool = ThreadPoolExecutor(max_workers=DEFAULT_MAX_WORKERS)

        return self._thread_pool

    @property
    def access_token(self) -> str:
        """Returns the client access token.

        If the token is not set or expires within the next 60 seconds, a new token gets fetched
        and returned. Refreshing the token also refreshes the cached user permissions.
        """
        if self._access_token is None or time() + 60 > self._access_token_expiry:
            self._access_token = self._access_token_refresh_func()
            # The signature is not verified here: the token is only inspected locally to read
            # its expiry; the API server is the party that validates it.
            self._access_token_expiry = jwt.decode(
                self._access_token,
                # nosemgrep:  python.jwt.security.unverified-jwt-decode.unverified-jwt-decode
                options={"verify_signature": False},
            ).get("exp")
            # When the access token gets refreshed, also refresh user permissions in case they
            # have been updated. This issues a request through fetch(), which re-reads
            # access_token; the token and expiry set above make that read return immediately.
            self._refresh_user_permissions()
        return self._access_token

    def _get_access_token_via_refresh_token(self) -> str:
        """Get an access token via oauth by submitting a refresh token."""
        response = requests.get(
            f"{self.api_host}/oauth2/refresh",
            # Let requests URL-encode the token. Interpolating it into the URL by hand would
            # corrupt tokens that contain reserved characters such as "+", "&", "=" or "%".
            params={"refresh_token": self.refresh_token},
            timeout=60,
        )
        response.raise_for_status()
        return response.json().get("access_token")

    def _get_access_token_via_keypair(self) -> str:
        """Get an access token via okta by submitting a registered public key.

        A short-lived JWT client assertion is signed with the private key (RS256) and exchanged
        at the Okta token endpoint using the client_credentials grant.
        """
        token_url = f"{self.okta_host}/oauth2/{self.okta_auth_server}/v1/token"
        iat = int(time())
        encoded = jwt.encode(
            {
                "aud": token_url,
                "exp": iat + (60 * 60),  # expire in 60 minutes
                "iat": iat,
                "sub": self.client_id,
                "iss": self.client_id,
            },
            self.private_key,
            algorithm="RS256",
        )
        response = requests.post(
            token_url,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
            },
            data={
                "scope": "kensho:app:kfinance",
                "grant_type": "client_credentials",
                "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
                "client_assertion": encoded,
            },
            timeout=60,
        )
        response.raise_for_status()
        return response.json().get("access_token")

    @property
    def user_permissions(self) -> set[Permission]:
        """Return the permissions that the current user holds (fetched lazily on first access)."""

        if self._user_permissions is None:
            self._refresh_user_permissions()
            # _refresh_user_permissions updates self._user_permissions in place
            assert self._user_permissions is not None
        return self._user_permissions

    def _refresh_user_permissions(self) -> None:
        """Fetches user permissions and stores them as KfinanceApiClient._user_permissions.

        Permission names that this client version does not know about are skipped with a
        warning instead of raising.
        """

        user_permission_dict = self.fetch_permissions()
        self._user_permissions = set()
        for permission_str in user_permission_dict["permissions"]:
            try:
                self._user_permissions.add(Permission[permission_str])
            except KeyError:
                logger.warning(
                    "You have access to functions using %s. However, functions using "
                    "%s have not yet been added in this version of the client. To access "
                    "all functions that you have access to, you may need to update the client.",
                    permission_str,
                    permission_str,
                )

    @contextmanager
    def endpoint_tracker(self) -> Generator[Queue[str], None, None]:
        """Context manager to track and return endpoint URLs in our thread-safe queue during execution.

        endpoint_tracker yields a queue into which all endpoint URLs are written until the context manager gets exited.
        It is up to the callers to dequeue the queue before the context manager gets exited and the queue gets wiped.
        On exit, any previously active tracker queue is restored so that nested use does not disable an
        enclosing tracker.
        This functionality is currently used by `run_with_grounding` to collect and forward endpoint URLs.
        """
        previous_queue = self._endpoint_tracker_queue
        self._endpoint_tracker_queue = Queue[str]()

        try:
            yield self._endpoint_tracker_queue
        finally:
            self._endpoint_tracker_queue = previous_queue

    def _record_endpoint(self, url: str) -> None:
        """Put url on the endpoint tracker queue when inside the endpoint_tracker context manager."""
        tracker_queue = self._endpoint_tracker_queue
        if tracker_queue is not None:
            tracker_queue.put(url)

    def _build_headers(self, content_type: str = "application/json") -> dict[str, str]:
        """Build the request headers shared by every API call.

        Includes the bearer token (refreshed if needed), the User-Agent, and, when inside
        batch_request_header, the batch id and batch size headers.

        :param content_type: value for the Content-Type header.
        :type content_type: str
        :return: the header dict to pass to requests.
        :rtype: dict[str, str]
        """
        headers = {
            "Content-Type": content_type,
            "Authorization": f"Bearer {self.access_token}",
            "User-Agent": f"kfinance/{kfinance_version} {self.user_agent_source}",
        }
        if self._batch_id is not None:
            assert self._batch_size is not None
            headers["Kfinance-Batch-Id"] = self._batch_id
            headers["Kfinance-Batch-Size"] = self._batch_size
        return headers

    def fetch(self, url: str) -> dict:
        """Does the request and auth.

        Records the URL with the endpoint tracker (if active), sends an authenticated GET and
        returns the decoded JSON body.

        :raises requests.HTTPError: if the response has an error status code.
        """
        self._record_endpoint(url)
        response = requests.get(
            url,
            headers=self._build_headers(),
            timeout=60,
        )
        response.raise_for_status()
        return response.json()

    def fetch_permissions(self) -> dict[str, list[str]]:
        """Return the permissions of the user."""
        url = f"{self.url_base}users/permissions"
        return self.fetch(url)

    def fetch_id_triple(self, identifier: str, exchange_code: Optional[str] = None) -> dict:
        """Get the ID triple from [identifier], optionally restricted to an exchange."""
        url = f"{self.url_base}id/{identifier}"
        if exchange_code is not None:
            url = url + f"/exchange_code/{exchange_code}"
        return self.fetch(url)

    def fetch_isin(self, security_id: int) -> dict:
        """Get the ISIN."""
        url = f"{self.url_base}isin/{security_id}"
        return self.fetch(url)

    def fetch_cusip(self, security_id: int) -> dict:
        """Get the CUSIP."""
        url = f"{self.url_base}cusip/{security_id}"
        return self.fetch(url)

    def fetch_primary_security(self, company_id: int) -> dict:
        """Get the primary security of a company."""
        url = f"{self.url_base}securities/{company_id}/primary"
        return self.fetch(url)

    def fetch_securities(self, company_id: int) -> dict:
        """Get the list of securities of a company."""
        url = f"{self.url_base}securities/{company_id}"
        return self.fetch(url)

    def fetch_primary_trading_item(self, security_id: int) -> dict:
        """Get the primary trading item of a security."""
        url = f"{self.url_base}trading_items/{security_id}/primary"
        return self.fetch(url)

    def fetch_trading_items(self, security_id: int) -> dict:
        """Get the list of trading items of a security."""
        url = f"{self.url_base}trading_items/{security_id}"
        return self.fetch(url)

    def fetch_history(
        self,
        trading_item_id: int,
        is_adjusted: bool = True,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        periodicity: Optional[Periodicity] = None,
    ) -> PriceHistory:
        """Get the pricing history.

        Omitted optional arguments are sent as the literal path segment "none".
        """
        url = (
            f"{self.url_base}pricing/{trading_item_id}/"
            f"{start_date if start_date is not None else 'none'}/"
            f"{end_date if end_date is not None else 'none'}/"
            f"{periodicity if periodicity else 'none'}/"
            f"{'adjusted' if is_adjusted else 'unadjusted'}"
        )
        return PriceHistory.model_validate(self.fetch(url))

    def fetch_history_metadata(self, trading_item_id: int) -> dict[str, str]:
        """Get the pricing history metadata."""
        url = f"{self.url_base}pricing/{trading_item_id}/metadata"
        return self.fetch(url)

    def fetch_market_caps_tevs_and_shares_outstanding(
        self,
        company_id: int,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> Capitalizations:
        """Get the market cap, TEV, and shares outstanding for a company.

        Omitted dates are sent as the literal path segment "none".
        """
        url = (
            f"{self.url_base}market_cap/{company_id}/"
            f"{start_date if start_date is not None else 'none'}/"
            f"{end_date if end_date is not None else 'none'}"
        )
        return Capitalizations.model_validate(self.fetch(url))

    def fetch_segments(
        self,
        company_id: int,
        segment_type: SegmentType,
        period_type: Optional[PeriodType] = None,
        start_year: Optional[int] = None,
        end_year: Optional[int] = None,
        start_quarter: Optional[int] = None,
        end_quarter: Optional[int] = None,
    ) -> dict:
        """Get a specified segment type for a specified duration.

        Omitted optional arguments are sent as the literal path segment "none".
        """
        url = (
            f"{self.url_base}segments/{company_id}/{segment_type}/"
            f"{period_type if period_type else 'none'}/"
            f"{start_year if start_year is not None else 'none'}/"
            f"{end_year if end_year is not None else 'none'}/"
            f"{start_quarter if start_quarter is not None else 'none'}/"
            f"{end_quarter if end_quarter is not None else 'none'}"
        )
        return self.fetch(url)

    def fetch_price_chart(
        self,
        trading_item_id: int,
        is_adjusted: bool = True,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
        periodicity: Optional[Periodicity] = None,
    ) -> bytes:
        """Get the price chart as raw image bytes.

        Like fetch(), the URL is recorded with the endpoint tracker and the request carries the
        User-Agent and (when batching) batch headers; the body is returned as-is rather than
        decoded as JSON.

        :raises requests.HTTPError: if the response has an error status code.
        """
        url = (
            f"{self.url_base}price_chart/{trading_item_id}/"
            f"{start_date if start_date is not None else 'none'}/"
            f"{end_date if end_date is not None else 'none'}/"
            f"{periodicity if periodicity else 'none'}/"
            f"{'adjusted' if is_adjusted else 'unadjusted'}"
        )

        self._record_endpoint(url)
        response = requests.get(
            url,
            headers=self._build_headers(content_type="image/png"),
            timeout=60,
        )
        response.raise_for_status()
        return response.content

    def fetch_statement(
        self,
        company_id: int,
        statement_type: str,
        period_type: Optional[PeriodType] = None,
        start_year: Optional[int] = None,
        end_year: Optional[int] = None,
        start_quarter: Optional[int] = None,
        end_quarter: Optional[int] = None,
    ) -> dict:
        """Get a specified financial statement for a specified duration.

        Omitted optional arguments are sent as the literal path segment "none".
        """
        url = (
            f"{self.url_base}statements/{company_id}/{statement_type}/"
            f"{period_type if period_type else 'none'}/"
            f"{start_year if start_year is not None else 'none'}/"
            f"{end_year if end_year is not None else 'none'}/"
            f"{start_quarter if start_quarter is not None else 'none'}/"
            f"{end_quarter if end_quarter is not None else 'none'}"
        )
        return self.fetch(url)

    def fetch_line_item(
        self,
        company_id: int,
        line_item: str,
        period_type: Optional[PeriodType] = None,
        start_year: Optional[int] = None,
        end_year: Optional[int] = None,
        start_quarter: Optional[int] = None,
        end_quarter: Optional[int] = None,
    ) -> dict:
        """Get a specified financial line item for a specified duration.

        Omitted optional arguments are sent as the literal path segment "none".
        """
        url = (
            f"{self.url_base}line_item/{company_id}/{line_item}/"
            f"{period_type if period_type else 'none'}/"
            f"{start_year if start_year is not None else 'none'}/"
            f"{end_year if end_year is not None else 'none'}/"
            f"{start_quarter if start_quarter is not None else 'none'}/"
            f"{end_quarter if end_quarter is not None else 'none'}"
        )
        return self.fetch(url)

    def fetch_info(self, company_id: int) -> dict:
        """Get the company info."""
        url = f"{self.url_base}info/{company_id}"
        return self.fetch(url)

    def fetch_earnings_dates(self, company_id: int) -> dict:
        """Get the earnings dates."""
        url = f"{self.url_base}earnings/{company_id}/dates"
        return self.fetch(url)

    def fetch_geography_groups(
        self, country_iso_code: str, state_iso_code: Optional[str] = None, fetch_ticker: bool = True
    ) -> dict[str, list]:
        """Fetch geography groups (ticker groups if fetch_ticker, else company groups)."""
        group_kind = "ticker_groups" if fetch_ticker else "company_groups"
        url = f"{self.url_base}{group_kind}/geo/country/{country_iso_code}"
        if state_iso_code:
            url = url + f"/{state_iso_code}"
        return self.fetch(url)

    def fetch_ticker_geography_groups(
        self,
        country_iso_code: str,
        state_iso_code: Optional[str] = None,
    ) -> list[IdentificationTriple]:
        """Fetch ticker geography groups"""
        return self._tickers_response_to_id_triple(
            self.fetch_geography_groups(
                country_iso_code=country_iso_code, state_iso_code=state_iso_code, fetch_ticker=True
            )
        )

    def fetch_company_geography_groups(
        self,
        country_iso_code: str,
        state_iso_code: Optional[str] = None,
    ) -> dict[str, list[int]]:
        """Fetch company geography groups"""
        return self.fetch_geography_groups(
            country_iso_code=country_iso_code, state_iso_code=state_iso_code, fetch_ticker=False
        )

    def fetch_exchange_groups(
        self, exchange_code: str, fetch_ticker: bool = True
    ) -> dict[str, list]:
        """Fetch exchange groups (ticker groups if fetch_ticker, else trading item groups)."""
        group_kind = "ticker_groups" if fetch_ticker else "trading_item_groups"
        url = f"{self.url_base}{group_kind}/exchange/{exchange_code}"
        return self.fetch(url)

    def fetch_ticker_exchange_groups(self, exchange_code: str) -> list[IdentificationTriple]:
        """Fetch ticker exchange groups"""
        return self._tickers_response_to_id_triple(
            self.fetch_exchange_groups(
                exchange_code=exchange_code,
                fetch_ticker=True,
            )
        )

    def fetch_trading_item_exchange_groups(self, exchange_code: str) -> dict[str, list[int]]:
        """Fetch trading item exchange groups"""
        return self.fetch_exchange_groups(
            exchange_code=exchange_code,
            fetch_ticker=False,
        )

    @staticmethod
    def _tickers_response_to_id_triple(
        tickers_response: dict[str, list[dict]],
    ) -> list[IdentificationTriple]:
        """For fetch ticker cases with a dict[str, list[dict]] response, return a list[IdentificationTriple].

        For example, with a given fetch tickers response:
        {"tickers" : [{"trading_item_id": 1, "security_id": 1, "company_id": 1}, {"trading_item_id": 2,"security_id": 2,"company_id": 2}]},
        return [IdentificationTriple(trading_item_id=1, security_id=1, company_id=1),
        IdentificationTriple(trading_item_id=2, security_id=2, company_id=2)].
        """
        return [
            IdentificationTriple(
                trading_item_id=ticker["trading_item_id"],
                security_id=ticker["security_id"],
                company_id=ticker["company_id"],
            )
            for ticker in tickers_response["tickers"]
        ]

    def fetch_ticker_combined(
        self,
        country_iso_code: Optional[str] = None,
        state_iso_code: Optional[str] = None,
        simple_industry: Optional[str] = None,
        exchange_code: Optional[str] = None,
    ) -> list[IdentificationTriple]:
        """Fetch tickers using combined filters route.

        Unset filters are sent as the path segment "none" (str(None).lower()).

        :raises RuntimeError: if no filter is provided, or if state_iso_code is given without
            country_iso_code.
        """
        if (
            country_iso_code is None
            and state_iso_code is None
            and simple_industry is None
            and exchange_code is None
        ):
            raise RuntimeError("Invalid parameters: No parameters provided or all set to none")
        if country_iso_code is None and state_iso_code is not None:
            raise RuntimeError(
                "Invalid parameters: country_iso_code must be provided with a state_iso_code value"
            )
        url = (
            f"{self.url_base}ticker_groups/filters/"
            f"geo/{str(country_iso_code).lower()}/{str(state_iso_code).lower()}/"
            f"simple/{str(simple_industry).lower()}/"
            f"exchange/{str(exchange_code).lower()}"
        )
        return self._tickers_response_to_id_triple(self.fetch(url))

    def fetch_companies_from_business_relationship(
        self, company_id: int, relationship_type: BusinessRelationshipType
    ) -> RelationshipResponse | RelationshipResponseNoName:
        """Fetches a dictionary of current and previous company IDs and names associated with a given company ID based on the specified relationship type.

        Example: fetch_companies_from_business_relationship(company_id=1234, relationship_type="distributor") returns a dictionary of company 1234's current and previous distributors.

        As of 2024-05-28, we are changing the response on the backend from
        RelationshipResponseNoName to RelationshipResponse. This function can handle both response
        types.

        :param company_id: The ID of the company for which associated companies are being fetched.
        :type company_id: int
        :param relationship_type: The type of relationship to filter by. Valid relationship types are defined in the BusinessRelationshipType class.
        :type relationship_type: BusinessRelationshipType
        :return: A dictionary containing lists of current and previous company IDs that have the specified relationship with the given company_id.
        :rtype: RelationshipResponse | RelationshipResponseNoName
        """
        url = f"{self.url_base}relationship/{company_id}/{relationship_type}"
        result = self.fetch(url)
        # Try to parse as the newer RelationshipResponse and fall back to
        # RelationshipResponseNoName if that fails.
        try:
            return RelationshipResponse.model_validate(result)
        except ValidationError:
            return RelationshipResponseNoName.model_validate(result)

    def fetch_ticker_from_industry_code(
        self,
        industry_code: str,
        industry_classification: IndustryClassification,
    ) -> list[IdentificationTriple]:
        """Fetches a list of identification triples that are classified in the given industry_code and industry_classification.

        The endpoint responds with {"tickers": List[{"company_id": <company_id>, "security_id": <security_id>, "trading_item_id": <trading_item_id>}]},
        which is converted into a list of IdentificationTriple.
        :param industry_code: The industry_code to filter on. The industry_code is a string corresponding to the Industry classifications ontology.
        :type industry_code: str
        :param industry_classification: The type of industry_classification to filter on.
        :type industry_classification: IndustryClassification
        :return: A list of identification triples [company_id, security_id, trading_item_id] that are classified in the given industry_code and industry_classification.
        :rtype: list[IdentificationTriple]
        """
        url = f"{self.url_base}ticker_groups/industry/{industry_classification}/{industry_code}"
        return self._tickers_response_to_id_triple(self.fetch(url))

    def fetch_company_from_industry_code(
        self,
        industry_code: str,
        industry_classification: IndustryClassification,
    ) -> dict[str, list[int]]:
        """Fetches a list of companies that are classified in the given industry_code and industry_classification.

        Returns a dictionary of shape {"companies": List[<company_id>]}.
        :param industry_code: The industry_code to filter on. The industry_code is a string corresponding to the Industry classifications ontology.
        :type industry_code: str
        :param industry_classification: The type of industry_classification to filter on.
        :type industry_classification: IndustryClassification
        :return: A dictionary containing the list of companies that are classified in the given industry_code and industry_classification.
        :rtype: dict[str, list[int]]
        """
        url = f"{self.url_base}company_groups/industry/{industry_classification}/{industry_code}"
        return self.fetch(url)

    def fetch_earnings(self, company_id: int) -> dict:
        """Get the earnings for a company."""
        url = f"{self.url_base}earnings/{company_id}"
        return self.fetch(url)

    def fetch_transcript(self, key_dev_id: int) -> dict:
        """Get the transcript for an earnings item."""
        url = f"{self.url_base}transcript/{key_dev_id}"
        return self.fetch(url)

    def fetch_competitors(self, company_id: int, competitor_source: CompetitorSource) -> dict:
        """Get the competitors for a company.

        CompetitorSource.all omits the source path segment; any other source is appended to the URL.
        """
        url = f"{self.url_base}competitors/{company_id}"
        if competitor_source is not CompetitorSource.all:
            url = url + f"/{competitor_source}"
        return self.fetch(url)

    def fetch_mergers_for_company(
        self,
        company_id: int,
    ) -> dict[str, list[dict[str, int | str]]]:
        """Fetches the mergers and acquisitions the given company was involved in.

        Returns a dictionary of shape {"target": [{"transaction_id": <transaction_id>, "merger_title": <merger short title>}], "buyer": [...], "seller": [...]}
        :param company_id: The company ID to filter on.
        :type company_id: int
        :return: A dictionary containing transaction IDs and 'merger titles' for each of the three kinds of roles the given company could be party to.
        :rtype: dict[str, list[dict[str, int | str]]]
        """
        url = f"{self.url_base}mergers/{company_id}"
        return self.fetch(url)

    def fetch_merger_info(
        self,
        transaction_id: int,
    ) -> dict:
        """Fetches information about the given merger or acquisition, including the timeline, the participants, and the considerations.

        Returns a complex dictionary.
        :param transaction_id: The transaction ID to filter on.
        :type transaction_id: int
        :return: A dictionary containing the timeline, the participants, and the considerations (with their details) of the transaction.
        :rtype: dict
        """
        url = f"{self.url_base}merger/info/{transaction_id}"
        return self.fetch(url)

    def fetch_advisors_for_company_in_merger(
        self,
        transaction_id: int,
        advised_company_id: int,
    ) -> dict[str, list[dict[str, int | str]]]:
        """Fetch information about the advisors of a given company involved in a given merger or acquisition.

        Returns a dictionary of shape {"advisors": [{"advisor_company_id": <advisor_company_id>, "advisor_company_name": <advisor_company_name>, "advisor_type_name": <advisor_type_name>},...]}
        :param transaction_id: The transaction ID to filter on.
        :type transaction_id: int
        :param advised_company_id: The company ID involved with the transaction.
        :type advised_company_id: int
        :return: A dictionary containing the list of companies advising a company involved with a merger or acquisition, along with their advisor type.
        :rtype: dict[str, list[dict[str, int | str]]]
        """
        url = f"{self.url_base}merger/info/{transaction_id}/advisors/{advised_company_id}"
        return self.fetch(url)
