from logging import Logger
from typing import List

import dns.query
from dns.message import make_query

from kuhl_haus.canary.models.dns_resolver import DnsResolver
from kuhl_haus.canary.models.endpoint_model import EndpointModel
from kuhl_haus.metrics.data.metrics import Metrics


def query_dns(resolvers: List[DnsResolver], ep: EndpointModel, metrics: Metrics, logger: Logger):
    """Look up the A record of ``ep.hostname`` over UDP, trying resolvers until one answers.

    Resolvers are tried in the order given. Before each attempt,
    ``metrics.set_counter('requests', 1)`` is called. The first attempt that yields
    a DNS message records ``responses``, ``truncated`` (1 if the TC flag is set,
    else 0), ``rcode`` and the full ``result`` text. It then returns without
    contacting the remaining resolvers.

    A failed attempt stores ``repr`` of the exception under ``exception``. It then
    sets a counter and an ``rcode`` marker, logs the traceback and moves on to the
    next resolver:

    * ``dns.query.BadResponse`` / ``dns.query.UnexpectedSource`` -> ``errors``, ``ERROR``
    * ``dns.exception.Timeout`` -> ``errors``, ``TIMEOUT``
    * any other ``Exception`` -> ``exceptions``, ``FATAL``

    The ``exception`` attribute written by a failed attempt is left in place if a
    later resolver succeeds. If every resolver fails, or ``resolvers`` is empty,
    the function just returns ``None``.

    :param resolvers: resolvers to try, in priority order.
    :param ep: endpoint whose ``hostname`` is resolved.
    :param metrics: metrics object updated in place.
    :param logger: logger that receives one ``exception`` record per failed attempt.
    """
    for resolver in resolvers:
        try:
            metrics.set_counter('requests', 1)
            reply: dns.message.Message = dns_query(resolver.ip_address, ep.hostname, "A", use_tcp=False)
            metrics.set_counter('responses', 1)
            # TC bit set means the UDP answer was truncated; stored as 0/1.
            metrics.attributes['truncated'] = 1 if reply.flags & dns.flags.TC else 0
            metrics.attributes['rcode'] = dns.rcode.to_text(reply.rcode())
            metrics.attributes['result'] = reply.to_text()
            return  # first resolver to answer wins; the rest are not queried
        except Exception as err:
            metrics.attributes['exception'] = repr(err)
            # Classify the failure: (counter to set, rcode marker, log message template).
            if isinstance(err, (dns.query.BadResponse, dns.query.UnexpectedSource)):
                counter, marker, template = 'errors', 'ERROR', "Invalid DNS response for {} from {}"
            elif isinstance(err, dns.exception.Timeout):
                counter, marker, template = 'errors', 'TIMEOUT', "DNS timeout querying {} using {}"
            else:
                counter, marker, template = 'exceptions', 'FATAL', "unhandled exception querying DNS for {} using {}"
            metrics.set_counter(counter, 1)
            metrics.attributes['rcode'] = marker
            logger.exception(msg=template.format(ep.hostname, resolver.ip_address), exc_info=err)


def dns_query(ip_address, query_name, rr_type, use_tcp=True, timeout=1) -> dns.message.Message:
    """Send a single DNS query for ``query_name``/``rr_type`` to the server at ``ip_address``.

    :param ip_address: address of the DNS server to query.
    :param query_name: name to look up, in text form; parsed with ``dns.name.from_text``.
    :param rr_type: record type in text form (e.g. ``"A"``); parsed with ``dns.rdatatype.from_text``.
    :param use_tcp: when true, send a plain (non-EDNS) query over TCP. Otherwise send
        it over UDP with EDNS enabled, advertising a 4096-byte payload.
    :param timeout: seconds to wait for the answer. Defaults to 1, the value this
        function previously hard-coded, so existing callers are unaffected.
    :return: the response message returned by ``dns.query.tcp`` / ``dns.query.udp``.
    :raises dns.exception.Timeout: (raised by dnspython) if no answer arrives within ``timeout``.
    """
    qname = dns.name.from_text(query_name)
    rd_type = dns.rdatatype.from_text(rr_type)
    if use_tcp:
        dns_message = make_query(qname=qname, rdtype=rd_type)
        return dns.query.tcp(dns_message, ip_address, timeout=timeout)
    # UDP: advertise a large EDNS payload so bigger answers are less likely to be truncated.
    dns_message = make_query(qname=qname, rdtype=rd_type, use_edns=True, ednsflags=0, payload=4096)
    return dns.query.udp(dns_message, ip_address, timeout=timeout)
