from __future__ import unicode_literals, print_function
from gocept.net.configfile import ConfigFile
from gocept.net.directory import Directory, exceptions_screened
from netaddr import ip
import argparse
import collections
import configobj
import gocept.net.utils
import os.path as p
import re
import subprocess
import sys


class RR(collections.namedtuple('RR', ['label', 'rtype', 'value'])):
    """A single resource record: owner label, record type and value.

    The classmethods are shorthand constructors which fill in the
    record type for the RR kinds used in the generated zones.
    """

    @classmethod
    def A(cls, label, addr):
        """Build an IPv4 address record for `label`."""
        return cls(label=label, rtype='A', value=addr)

    @classmethod
    def AAAA(cls, label, addr):
        """Build an IPv6 address record for `label`."""
        return cls(label=label, rtype='AAAA', value=addr)

    @classmethod
    def CNAME(cls, label, cname):
        """Build an alias record which points `label` to `cname`."""
        return cls(label=label, rtype='CNAME', value=cname)

    @classmethod
    def PTR(cls, addr, name):
        """Build a reverse record which maps `addr` to `name`."""
        return cls(label=addr, rtype='PTR', value=name)


class MixedRecordTypesException(RuntimeError):
    """Signals that a name would carry both a CNAME and another RR."""


class Zone(object):
    """A single DNS zone.

    `name` is the zone name, which is used to construct the file name.
    `origin` specifies a differing base in the DNS tree. If missing, it
    is constructed from the zone name. `include` is a list of static
    zone file snippets to include in the output (no includes if
    omitted). `parent_zones` is a back pointer to the Zones object
    containing this zone.
    """

    # Matches the serial line as it is written by `render`.
    r_serial = re.compile(r'^\s*(\d+) ; serial$', re.M)

    def __init__(self, name, origin=None, include=None, parent_zones=None):
        self.name = name.rstrip('.')
        self.origin = (origin or name).rstrip('.')
        self.records = []
        # Use a fresh list per instance; a mutable default argument would
        # be shared between all zones created without explicit includes.
        self.include = [] if include is None else include
        self.parent = parent_zones
        # Labels which already carry a CNAME resp. any other record type;
        # used to detect CNAME/other RR collisions.
        self.cnames = set()
        self.other_rrs = set()

    def fullpath(self):
        """Path to the zone file according to the zones config"""
        return p.join(self.parent.pridir, self.name + '.zone')

    def save(self):
        """Updates zone file on disk. Returns True if anything has changed.

        The existing serial is kept if re-rendering the zone with it
        yields exactly the current file content. Otherwise the serial
        becomes today's date based serial (YYYYMMDD00) or the old serial
        plus one, whichever is larger.
        """
        old_serial = self.parse_serial()
        if old_serial:
            with open(self.fullpath()) as old:
                if old.read() == self.render(old_serial):
                    return False
        today_serial = int(gocept.net.utils.now().strftime('%Y%m%d00'))
        new_serial = max(today_serial, (old_serial or 0) + 1)
        f = ConfigFile(self.fullpath())
        f.write(self.render(new_serial))
        return f.commit()

    def parse_serial(self):
        """Pulls the serial number from an existing zone file.

        Returns None if the zone file does not exist or the serial
        number cannot be found (e.g., zone file is empty).
        """
        try:
            with open(self.fullpath()) as old:
                # only the first 4096 characters are searched
                old_serial = self.r_serial.search(old.read(4096))
        except IOError:
            return None
        if old_serial:
            return int(old_serial.group(1))
        return None

    def add_cname(self, rdn, cname):
        """Adds a CNAME record pointing `rdn` to `cname`.

        Adding a second CNAME for the same `rdn` is silently ignored.
        Raises MixedRecordTypesException if `rdn` already carries
        another record type.
        """
        if rdn in self.other_rrs:
            raise MixedRecordTypesException(rdn)
        if rdn in self.cnames:
            # one cname is enough :)
            return
        self.records.append(RR.CNAME(rdn, cname))
        self.cnames.add(rdn)

    def render(self, serial):
        """String representation of the zone.

        The output consists of the SOA header carrying `serial`, one NS
        record per name server of the parent, all collected records and
        finally the verbatim contents of the include files.
        """
        res = ["""\
; generated by configure-zones
$TTL 86400
$ORIGIN {origin}.
@               86400   IN      SOA {ns0}. hostmaster.{suffix}. (
                                        {serial} ; serial
                                        10800 ; refresh
                                        900 ; retry
                                        2419200 ; expire
                                        1800 ; neg ttl
                                )
""".format(origin=self.origin, ns0=self.parent.nameservers[0],
           suffix=self.parent.suffix, serial=serial)]
        for ns in self.parent.nameservers:
            res.append(32 * ' ' + 'NS      {}.\n'.format(ns))
        # records below use the configured TTL instead of the header TTL
        res.append('$TTL {}\n'.format(self.parent.ttl))
        for rr in self.records:
            res.append('{:<31s} {:<7s} {}\n'.format(
                rr.label, rr.rtype, rr.value))
        for includefile in self.include:
            with open(includefile) as f:
                res.append('\n; included from {}\n'.format(includefile) +
                           f.read() + '\n')
        return ''.join(res)


class ForwardZone(Zone):
    """Zone which holds address (A/AAAA) records besides CNAMEs."""

    def add_a(self, rdn, addr):
        """Adds an A or AAAA record for `rdn`, chosen by `addr.version`.

        Raises MixedRecordTypesException if `rdn` already has a CNAME.
        """
        if rdn in self.cnames:
            raise MixedRecordTypesException(rdn)
        # IP version -> record constructor; unknown versions raise KeyError
        make_record = {4: RR.A, 6: RR.AAAA}[addr.version]
        self.records.append(make_record(rdn, addr))
        self.other_rrs.add(rdn)


class ReverseZone(Zone):
    """Special-cased zone containing reverse entries.

    This is tricky. `name` is the zone name (used to construct the file
    name for example). `prefix` is the IP subnet for which this zone is
    authoritative. `origin` defaults to `name` but may be overridden to
    represent RFC 2317 CIDR reverse zones. `strip_suffix` is cut from
    reverse DNS entries to construct relative names. Normally, the
    suffix equals the origin, but for RFC 2317 these two differ.
    `include` lists static zone snippets (none if omitted) and
    `parent_zones` is a pointer to the containing Zones collection.
    """

    def __init__(self, name, prefix, origin=None, strip_suffix=None,
                 include=None, parent_zones=None):
        # Create a fresh list per instance instead of sharing one mutable
        # default object between all reverse zones.
        if include is None:
            include = []
        super(ReverseZone, self).__init__(name, origin, include, parent_zones)
        self.prefix = prefix
        # Normalise to '.<suffix>.' so that it can be cut off from the
        # absolute names found in `reverse_dns`.
        if strip_suffix:
            self.strip_suffix = '.' + strip_suffix.strip('.') + '.'
        else:
            self.strip_suffix = '.' + self.origin + '.'
        # A zone counts as private only if both the first and the last
        # address of the prefix are private.
        self.private = (self.prefix[0].is_private() and
                        self.prefix[-1].is_private())

    def add_ptr(self, addr, name):
        """Adds a PTR record which maps `addr` to `name`.

        The record label is `addr.reverse_dns` with the first occurrence
        of the zone's strip suffix removed.
        """
        rev = addr.reverse_dns.replace(self.strip_suffix, '', 1)
        self.records.append(RR.PTR(rev, name))


class Zones(object):
    """Collection of all managed zones.

    Holds one forward zone each for the internal and the external view
    plus the reverse zones listed in the `zones` section of `config`.
    General settings (`pridir`, `ttl`, `suffix`, `nameservers`) are read
    from the `settings` section.
    """

    def __init__(self, config):
        self.config = config
        self.pridir = self.config['settings']['pridir']
        self.ttl = int(self.config['settings']['ttl'])
        self.suffix = self.config['settings']['suffix'].rstrip('.')
        self.nameservers = self.parse_nameservers()
        self.include = self.parse_includes()
        self.external_forward = ForwardZone(
            self.suffix + '-external', self.suffix, self.include['external'],
            self)
        self.internal_forward = ForwardZone(
            self.suffix + '-internal', self.suffix, self.include['internal'],
            self)
        self.reverse_zones = self.create_reverse_zones(self.config['zones'])

    def parse_nameservers(self):
        """Returns the list of configured name servers.

        Accepts a single value as well as a list. Trailing dots are
        stripped and empty entries are dropped.
        """
        nameservers = self.config['settings']['nameservers']
        if not isinstance(nameservers, list):
            nameservers = [nameservers]
        # Build a real list: zone rendering indexes into the result,
        # which a lazy filter object does not support.
        stripped = (ns.rstrip('.') for ns in nameservers)
        return [ns for ns in stripped if ns]

    def parse_includes(self):
        """Construct list of file names from "include =" config setting."""
        inc = {}
        for inctype in ('internal', 'external'):
            inc[inctype] = self.config[inctype]['include']
            if not inc[inctype]:
                inc[inctype] = []
            elif not isinstance(inc[inctype], list):
                inc[inctype] = [inc[inctype]]
        return inc

    @staticmethod
    def default_reverse_name(prefix):
        """Compute default reverse zone name for `prefix`.

        Depending on the IP version, we have 4 or 8 bits per dotted DNS
        label. Leverage netaddr's `reverse_dns` function (works for full
        addresses only) by cutting off superfluous labels after
        conversion. Note: this function works only for prefixes with
        dotted label granularity!
        """
        bits_per_dot = {4: 8, 6: 4}[prefix.version]
        keep_labels = prefix.prefixlen // bits_per_dot + 2  # for in-addr, arpa
        labels = prefix[0].reverse_dns.split('.')
        # one additional element is kept for the empty string which
        # results from splitting at the trailing dot
        return '.'.join(labels[-keep_labels - 1:])

    def enumerate_reverse_zones(self, netlist):
        """Generates (prefix, ReverseZone) pairs for all configured zones."""
        for prefix, name in netlist.items():
            prefix = ip.IPNetwork(prefix)
            default_name = self.default_reverse_name(prefix)
            yield (prefix, ReverseZone(
                (name or default_name).rstrip('.'), prefix=prefix,
                strip_suffix=default_name, parent_zones=self))

    @staticmethod
    def weight_prefixlen(key):
        """Sort key for (prefix, zone) pairs.

        IPv4 prefix lengths are multiplied by 4 to make them comparable
        to IPv6 prefix lengths.
        """
        net = key[0]
        return net.prefixlen if net.version > 4 else net.prefixlen * 4

    def create_reverse_zones(self, netlist):
        """Creates ordered dict of all reverse zones keyed by prefix.

        The keys are sorted by prefix length (tiny zones last) to
        make an early lookup hit in `add_reverse` more likely.
        """
        return collections.OrderedDict(
            sorted(self.enumerate_reverse_zones(netlist),
                   key=self.weight_prefixlen))

    def add_addr(self, rdn, addr, aliases=None):
        """Puts A/AAAA and CNAME records into the appropriate zones.

        All records go into the internal view. Records for non-private
        addresses are additionally put into the external view.
        """
        # Materialize once: the aliases are iterated once per view.
        aliases = tuple(aliases or ())
        self.internal_forward.add_a(rdn, addr)
        for alias in aliases:
            self.internal_forward.add_cname(alias, rdn)
        if not addr.is_private():
            self.external_forward.add_a(rdn, addr)
            for alias in aliases:
                self.external_forward.add_cname(alias, rdn)

    def add_reverse(self, addr, name):
        """Puts PTR records into the appropriate zones.

        `name` is either a FQDN ending with a dot or a relative name to
        *suffix*. Raises KeyError if no reverse zone contains `addr`.
        """
        if not name.endswith('.'):
            name += '.' + self.suffix + '.'
        # the first zone (in sort order) containing the address wins
        for prefix in self.reverse_zones:
            if addr in prefix:
                self.reverse_zones[prefix].add_ptr(addr, name)
                return
        raise KeyError('no reverse zone found for address', addr)

    def update_zones(self):
        """Updates all zone files in pridir.

        Returns True if anything has changed.
        """
        changed = False
        for zone in ([self.external_forward, self.internal_forward] +
                     list(self.reverse_zones.values())):
            # "|=" instead of "or": every zone must be saved
            changed |= zone.save()
        return changed

    def all_internal_zones(self):
        """Collects all Zone objects for the internal view."""
        return [self.internal_forward] + list(self.reverse_zones.values())

    def all_external_zones(self):
        """Collects all Zone objects for the external view.

        Private reverse zones are left out.
        """
        return [self.external_forward] + [
            z for z in self.reverse_zones.values() if not z.private]

    def update_bind_config(self):
        """Updates BIND zones lists.

        Returns True if anything has changed.
        """
        changed = False
        for ztype in ('internal', 'external'):
            f = ConfigFile(self.config[ztype]['zonelist'])
            f.write('// Managed by localconfig-zones: do not edit this file!')
            for zone in getattr(self, 'all_{}_zones'.format(ztype))():
                f.write("""
zone "{origin}" IN {{
    type master;
    file "{filename}";
}};
""".format(origin=zone.origin, filename=zone.fullpath()))
            changed |= f.commit()
        return changed

    def update(self):
        """Updates zone files and BIND zone lists.

        Returns True if anything has changed.
        """
        # don't use "or" instead of "|" -> this would short-circuit
        return self.update_zones() | self.update_bind_config()


def join_dn(*labels):
    """Builds a dotted domain name from `labels`, skipping empty ones."""
    return '.'.join(label for label in labels if label)


class NodeAddr(object):
    """Policy: controls which RRs are created for a node's addresses.

    A NodeAddr object represents a single address on a single interface
    on a single node. This means that we get a lot of NodeAddr objects
    for each node.

    `name` is the node's short name, `vlan` and `loc` are the labels
    used to qualify it, and `addr` is the address object (its `version`
    attribute selects the naming variants). `reverse` optionally
    overrides the name the PTR record points to. `production` is only
    stored.
    """

    # Addresses on this VLAN additionally get the unqualified node name.
    canonical_vlan = 'srv'

    def __init__(self, name, vlan, loc, addr, production=True, reverse=None):
        self.name = name
        self.vlan = vlan
        self.loc = loc
        self.addr = addr
        self.production = production
        self.reverse = reverse

    @property
    def variants(self):
        """List applicable naming variants for this node's address.

        The variant gets inserted between the qualified relative domain
        name and the suffix. For example, (None, 'ipv4') results in
        'vm00.srv.whq.gocept.net' and 'vm00.srv.whq.ipv4.gocept.net'.
        Order matters: The reverse (PTR) name for the address points to
        the first variant.

        Raises KeyError for IP versions other than 4 and 6.
        """
        if self.addr.version == 4:
            return (None, 'ipv4')
        elif self.addr.version == 6:
            return (None, 'ipv6')
        # Raise instead of returning the exception object: callers
        # iterate over the result.
        raise KeyError('unsupported IP protocol version', self.addr.version)

    def inject_records(self, zones):
        """Implement naming policy.

        Adds forward records for every naming variant to `zones` and a
        reverse record for the first variant.
        """
        for index, variant in enumerate(self.variants):
            if self.vlan == self.canonical_vlan:
                # e.g., vm00.{ipv4.,ipv6.,}gocept.net
                # with alias vm00.srv.whq.{ipv4.,ipv6.,}gocept.net
                default_name = join_dn(self.name, variant)
                zones.add_addr(
                    default_name, self.addr,
                    [join_dn(self.name, self.vlan, self.loc, variant)])
            else:
                # e.g., vm00.fe.whq.{ipv4.,ipv6.,}gocept.net
                default_name = join_dn(self.name, self.vlan, self.loc, variant)
                # self.addr is passed as is, like in the canonical branch
                zones.add_addr(default_name, self.addr)
            if index == 0:
                # the first variant determines the reverse address
                zones.add_reverse(self.addr, self.reverse or default_name)


def walk(directory):
    """Yields a NodeAddr for each address of each node in `directory`.

    Nodes, interfaces, networks and addresses are visited in sorted
    order to keep the output stable between runs.
    """
    for node in sorted(directory.list_nodes()):
        node_name = node['name']
        parameters = node['parameters']
        loc = parameters['location']
        is_production = parameters['production']
        interfaces = sorted(parameters['interfaces'].items())
        for vlan_name, iface in interfaces:
            for net_addresses in sorted(iface['networks'].values()):
                for address in sorted(net_addresses):
                    # optional per-address override of the PTR target
                    reverse_name = parameters['reverses'].get(address)
                    yield NodeAddr(node_name, vlan_name, loc,
                                   ip.IPAddress(address), is_production,
                                   reverse_name)


def update():
    """Command line entry point: regenerate zones and reload if needed.

    Reads the configuration file given with -c/--config, collects the
    records for all nodes found in the directory and writes zone files
    and zone lists. If anything has changed and a `reload` command is
    configured in the settings, that command is run through the shell.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c', '--config', default='/etc/local/configure-zones.cfg',
        help='path to configuration file (default: %(default)s)')
    options = parser.parse_args()
    config = configobj.ConfigObj(options.config)
    zones = Zones(config)
    directory = Directory()
    with exceptions_screened():
        for node_addr in walk(directory):
            node_addr.inject_records(zones)
    if not zones.update():
        return
    if not config['settings'].get('reload'):
        return
    # flush our own output before the reload command writes its output
    sys.stdout.flush()
    subprocess.check_call([config['settings']['reload']], shell=True)
