#!/usr/bin/python
# Module metadata, written as a YAML document inside a string.
# NOTE(review): presumably consumed by Ansible's documentation tooling --
# confirm before changing the format or the variable name.
DOCUMENTATION = '''
---
module: zpool
author: Florian Schulze
short_description: Manage ZFS pools
description:
    - Manage ZFS pools
'''

import os


class Zpool(object):
    """Create a ZFS pool, optionally on top of geli providers.

    The instance is built from an Ansible module object, pops its settings
    out of ``module.params`` and is then called once to apply the requested
    state.  Calling it returns the result dictionary for the module.
    """

    platform = 'FreeBSD'

    def __init__(self, module):
        """Read all settings from ``module.params`` and locate ``zpool``.

        :param module: module object providing ``params``, ``check_mode``,
            ``get_bin_path``, ``run_command`` and ``fail_json``.
        """
        self.module = module
        self.changed = False
        self.state = self.module.params.pop('state')
        self.name = self.module.params.pop('name')
        # set to empty string to create latest supported version on host
        self.version = self.module.params.pop('version')
        self.devices = self.module.params.pop('devices')
        self.raid_mode = self.module.params.pop('raid_mode')
        self.geli = self.module.params.pop('geli')
        self.geli_passphrase_location = self.module.params.pop('geli_passphrase_location')
        self.cmd = self.module.get_bin_path('zpool', required=True)

    def _check(self, result, failure_msg):
        """Fail the module unless a command result has return code 0.

        :param result: ``(rc, out, err)`` tuple as returned by ``run_command``.
        :param failure_msg: message reported through ``fail_json``; the
            command's stderr is appended on a new line.
        :returns: the command's stdout.
        """
        rc, out, err = result
        if rc != 0:
            self.module.fail_json(msg="%s\n%s" % (failure_msg, err))
        return out

    def zpool(self, *params):
        """Run the ``zpool`` binary with ``params``; return ``(rc, out, err)``."""
        return self.module.run_command([self.cmd] + list(params))

    def exists(self):
        """Return True if ``zpool list -H <name>`` succeeds."""
        rc, _out, _err = self.zpool('list', '-H', self.name)
        return rc == 0

    def get_devices(self):
        """Return the list of devices the pool should be built from.

        If the ``devices`` setting is non-empty it is split on whitespace.
        Otherwise every entry of ``/dev/gpt`` whose name starts with
        ``<pool name>_`` and does not end in ``.eli`` is used, returned as
        ``gpt/<entry>`` in sorted order.  A missing ``/dev/gpt`` yields an
        empty list.
        """
        if self.devices:
            return self.devices.split()
        if not os.path.exists('/dev/gpt'):
            return []
        prefix = "%s_" % self.name
        # os.listdir() returns entries in arbitrary order; sort them so the
        # resulting zpool arguments and rc_lines are reproducible.
        return [
            'gpt/%s' % entry
            for entry in sorted(os.listdir('/dev/gpt'))
            if entry.startswith(prefix) and not entry.endswith('.eli')]

    def get_raid_mode(self, devices):
        """Return the vdev type keyword to pass to ``zpool create``.

        Any configured value other than ``'detect'`` is returned unchanged.
        With ``'detect'``, one device gives ``''`` (no keyword), two devices
        give ``'mirror'`` and any other count fails the module.
        """
        if self.raid_mode != 'detect':
            return self.raid_mode
        count = len(devices)
        if count == 1:
            return ''
        if count == 2:
            return 'mirror'
        self.module.fail_json(
            msg="Don't know how to handle %s number of devices (%s)." % (
                count, ' '.join(devices)))
        # Only reached if fail_json() returns; hand back the setting as-is.
        return self.raid_mode

    def prepare_geli_passphrase(self):
        """Create the geli passphrase file unless it already exists.

        The content is the output of ``openssl rand -base64 32``.  The file
        is created with mode 0600 from the start so the secret is never
        readable by other users, not even briefly.
        """
        location = self.geli_passphrase_location
        if os.path.exists(location):
            return
        passphrase = self._check(
            self.module.run_command(['openssl', 'rand', '-base64', '32']),
            "Failed to generate '%s'." % location)
        fd = os.open(location, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as f:
            f.write(passphrase)
        # The creation mode is subject to the umask; enforce the exact mode.
        os.chmod(location, 0o600)

    def _prepare_geli_devices(self, devices):
        """Initialise and attach a geli provider on each device.

        :returns: ``(data_devices, rc_lines)`` where ``data_devices`` are the
            ``<device>.eli`` names and ``rc_lines`` are ``regexp``/``line``
            dictionaries for ``geli_devices`` and per-device
            ``geli_<device>_flags`` settings.
        """
        self.prepare_geli_passphrase()
        passfile = self.geli_passphrase_location
        data_devices = []
        for device in devices:
            data_device = '{device}.eli'.format(device=device)
            self._check(
                self.module.run_command([
                    'geli', 'init', '-s', '4096', '-l', '256',
                    '-J', passfile,
                    device]),
                "Failed to init geli device '%s'." % data_device)
            self._check(
                self.module.run_command([
                    'geli', 'attach',
                    '-j', passfile,
                    device]),
                "Failed to attach geli device '%s'." % data_device)
            data_devices.append(data_device)
        rc_lines = [dict(
            regexp='^geli_devices=',
            line='geli_devices=\\"%s\\"' % ' '.join(devices))]
        for device in devices:
            # Slashes in the device name are replaced by underscores to form
            # the variable name.
            rc_device = device.replace('/', '_')
            rc_lines.append(dict(
                regexp='^geli_{rc_device}_flags='.format(rc_device=rc_device),
                line='geli_{rc_device}_flags=\\"-j {geli_passphrase_location}\\"'.format(
                    rc_device=rc_device,
                    geli_passphrase_location=passfile)))
        return data_devices, rc_lines

    def _prepare_nop_devices(self, devices):
        """Create a gnop provider (``-S 4096``) on each device.

        :returns: the list of ``<device>.nop`` names.
        """
        data_devices = []
        for device in devices:
            data_device = '{device}.nop'.format(device=device)
            self._check(
                self.module.run_command([
                    'gnop', 'create', '-S', '4096', device]),
                "Failed to create nop device '%s'." % data_device)
            data_devices.append(data_device)
        return data_devices

    def prepare_data_devices(self, devices):
        """Set up the providers the pool is actually created on.

        With geli enabled the devices are wrapped in ``.eli`` providers,
        otherwise in temporary ``.nop`` providers.

        :returns: ``(data_devices, rc_lines)``; ``rc_lines`` is empty unless
            geli is enabled.
        """
        if self.geli:
            return self._prepare_geli_devices(devices)
        return self._prepare_nop_devices(devices), []

    def cleanup_data_devices(self, data_devices):
        """Remove the temporary gnop providers after pool creation.

        Does nothing when geli is enabled.  Otherwise the pool is exported,
        every nop device is destroyed and the pool is imported again.
        """
        if self.geli:
            return
        self._check(
            self.zpool('export', self.name),
            "Failed to export zpool '%s'." % self.name)
        for data_device in data_devices:
            self._check(
                self.module.run_command(['gnop', 'destroy', data_device]),
                "Failed to destroy nop device '%s'." % data_device)
        self._check(
            self.zpool('import', self.name),
            "Failed to import zpool '%s'." % self.name)

    def create(self):
        """Create the pool and return a dict with the resulting ``rc_lines``.

        In check mode nothing is executed: ``changed`` is set and an empty
        dict is returned.
        """
        result = dict()
        if self.module.check_mode:
            self.changed = True
            return result
        devices = self.get_devices()
        raid_mode = self.get_raid_mode(devices)
        data_devices, rc_lines = self.prepare_data_devices(devices)
        result['rc_lines'] = rc_lines
        zpool_args = ['create']
        if self.version:
            zpool_args.extend(['-o', 'version=%s' % self.version])
        zpool_args.extend(['-m', 'none', self.name])
        if raid_mode:
            zpool_args.append(raid_mode)
        zpool_args.extend(data_devices)
        self._check(
            self.zpool(*zpool_args),
            "Failed to create zpool with the following arguments:\n%s" % ' '.join(zpool_args))
        self.changed = True
        self.cleanup_data_devices(data_devices)
        return result

    def destroy(self):
        """Destroy the pool.  Not implemented.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError("Destroying a zpool is not implemented.")

    def __call__(self):
        """Apply the requested state and return the module result dict.

        The result always contains ``name``, ``state``, ``rc_lines`` and
        ``changed``.
        """
        result = dict(name=self.name, state=self.state)

        if self.state in ['present', 'running']:
            if not self.exists():
                result.update(self.create())
        elif self.state == 'absent':
            if self.exists():
                self.destroy()

        result.setdefault('rc_lines', [])
        result['changed'] = self.changed
        return result


# Keyword arguments handed to AnsibleModule in main(): the accepted module
# parameters with their defaults, plus check-mode support.
MODULE_SPECS = {
    'argument_spec': {
        'name': {'required': True, 'type': 'str'},
        'state': {
            'default': 'present',
            'choices': ['present', 'absent'],
            'type': 'str',
        },
        'version': {'default': '28', 'type': 'str'},
        'devices': {'default': None, 'type': 'str'},
        'raid_mode': {'default': 'detect', 'type': 'str'},
        'geli': {'default': False, 'type': 'bool'},
        'geli_passphrase_location': {
            'default': '/root/geli-passphrase',
            'type': 'str',
        },
    },
    'supports_check_mode': True,
}


def main():
    """Build the module, run Zpool once and report its result.

    The result is passed to ``fail_json`` when it contains a ``failed`` key
    and to ``exit_json`` otherwise.
    """
    module = AnsibleModule(**MODULE_SPECS)
    outcome = Zpool(module)()
    report = module.fail_json if 'failed' in outcome else module.exit_json
    report(**outcome)


# This wildcard import supplies AnsibleModule, which main() uses.
# NOTE(review): its position at the bottom of the file presumably follows the
# legacy Ansible module convention -- confirm before moving it.
from ansible.module_utils.basic import *
# Run the module only when executed as a script.
if __name__ == "__main__":
    main()
