#!python
import subprocess
import argparse
from jinja2 import Template
from os.path import isfile
import yaml
import json
import sys
import re

# Matches device paths that begin with '/dev/nvidia' followed by at least one
# digit (e.g. '/dev/nvidia0'). Used below with .match() to separate GPU
# devices from the other devices listed in cuda_config['Devices'].
GPU_DEVICE_PATTERN = re.compile(r'/dev/nvidia\d+')

# support Python 2 or 3
if sys.version_info[0] == 3:
  import urllib.request as request
else:
  import urllib2 as request

# Only -f (the compose file) is interpreted by this script;
# parse_known_args() gathers every other argument into `extras`, which is
# later appended to the docker-compose command line.
parser = argparse.ArgumentParser()
parser.add_argument('-f')

(v, extras) = parser.parse_known_args()

# Fetch the Docker CLI settings as JSON from the local HTTP endpoint
# (presumably served by the nvidia-docker plugin -- TODO confirm). The reply
# is expected to contain a 'Devices' list and a 'Volumes' list.
resp = request.urlopen('http://localhost:3476/docker/cli/json').read().decode()
cuda_config = json.loads(resp)


def _gpu_index(device):
    """Return the first run of digits in a GPU device path as an int.

    Only called for paths accepted by GPU_DEVICE_PATTERN, which guarantees
    that digits directly follow the '/dev/nvidia' prefix.
    """
    return int(re.search(r'\d+', device).group())


# Split the reported devices into numbered GPU devices and the remaining
# "support" devices, keeping the reported order of the latter.
gpu_devices = [dev for dev in cuda_config['Devices']
               if GPU_DEVICE_PATTERN.match(dev)]
support_devices = [dev for dev in cuda_config['Devices']
                   if not GPU_DEVICE_PATTERN.match(dev)]

# Sort by numeric index: a plain string sort would put '/dev/nvidia10'
# before '/dev/nvidia2' on hosts with ten or more GPUs. The path itself is
# the tie-breaker so the order stays deterministic.
gpu_devices.sort(key=lambda dev: (_gpu_index(dev), dev))
n_gpu = len(gpu_devices)

# Volume name: the text before the first ':' of the first volume entry.
volume = cuda_config['Volumes'][0].split(':')[0]

# Compose file to augment: the -f argument, or docker-compose.yml by default.
dc_file = v.f or 'docker-compose.yml'
jinja_file = dc_file + '.jinja'

# If "<compose file>.jinja" exists it takes precedence: it is rendered with
# N_GPU and GPU_DEVICES and the result is parsed as the compose file.
# yaml.safe_load is used because yaml.load() without an explicit Loader is
# rejected by recent PyYAML releases and can construct arbitrary objects on
# older ones; compose files only need plain YAML types.
if isfile(jinja_file):
    with open(jinja_file, 'r') as f:
        template_source = f.read()
    content = Template(template_source).render(N_GPU=n_gpu, GPU_DEVICES=gpu_devices)
    config = yaml.safe_load(content)
else:
    with open(dc_file, 'r') as f:
        config = yaml.safe_load(f)

# Register the driver volume in the top-level 'volumes' section, marked as
# external.
top_level_volumes = config.setdefault('volumes', {})
top_level_volumes[volume] = {'external': True}

# Give every service the reported volumes and devices.
for svc_conf in config['services'].values():
    svc_volumes = svc_conf.setdefault('volumes', [])
    svc_volumes.extend(cuda_config['Volumes'])

    svc_devices = svc_conf.setdefault('devices', [])
    # GPU devices are only added when the service lists none of them yet;
    # a service that already names a GPU device keeps its own selection.
    # Support devices are always appended.
    additions = list(support_devices)
    if all(gdev not in svc_devices for gdev in gpu_devices):
        additions = list(gpu_devices) + additions
    svc_devices.extend(additions)

# Write the augmented configuration to nvidia-docker-compose.yml in the
# current working directory, in block (non-flow) YAML style.
with open('nvidia-docker-compose.yml', 'w') as f:
    yaml.safe_dump(config, f, default_flow_style=False)

# Run docker-compose on the generated file, forwarding every argument that
# this script did not consume itself.
command = ['docker-compose', '-f', 'nvidia-docker-compose.yml'] + extras

exit_status = 0
try:
    exit_status = subprocess.call(command)
except KeyboardInterrupt:
    # Ctrl-C while docker-compose is running: report it and exit normally,
    # as before.
    print('Terminating')
except OSError as err:
    # docker-compose could not be started (e.g. it is not on PATH); say so
    # instead of silently swallowing the error.
    sys.stderr.write('Could not run docker-compose: {}\n'.format(err))
    exit_status = 1

# Propagate docker-compose's exit status so callers can detect failures.
sys.exit(exit_status)

