#!/usr/bin/env python

import sys
import os
import getopt
import shutil
import stat
import re
import datetime

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from esgcet.publish import multiDirectoryIterator, compareFilesByPath, readDatasetMap
from esgcet.config import getHandlerByName, registerHandlers, loadConfig, initLogging
try:
    import cdat_info
    cdat_info.ping = False
except:
    pass
from cdms2 import Cdunif

# Help text printed by main() for --help and after command-line errors.
# Backslashes are doubled so the displayed regular expression is unchanged
# without relying on an invalid string escape.
usage = """Usage:
    esgcopy_files [options] project directory [directory ...]

    -or-

    esgcopy_files [options] --filelist input_list project

    -or-

    esgcopy_files [options] --map input_mapfile project

    Copy files from scratch space to the ESG archive.

    *Important:* This script is not an officially supported part of the ESGCET module. It supports
    construction of a disk-based archive based on the directory_format_for_copy pattern defined in esg.ini.
    If the files are for replicated datasets (--replica) the directory_format_for_replica pattern is used instead.

Arguments:
    project: Project identifier

    directory: A directory containing files to be copied to the archive.

Options:

    --batch batch_number
        Set the integer batch number. The default is YYYYMMDD.

    --checksum-list file:
        If --output is selected, include the checksum field in the generated mapfile.
        Each line of file has the form 'path|checksum'

    --compare, -c:
        Compare file contents to determine if they are equal. By default, the file size
        and tracking_id attributes are used. Using --compare is slower but more accurate.

    --dir-format:
        esg.ini option that has the directory template. Defaults to 'directory_format_for_copy'.

    --dry-run, -n:
        Dry run. Just echo what the command would do.

    --filelist path
        Process files in a list. path is a file, each line of which
        has the form file[|checksum]. Dataset identifiers are determined from the file metadata.

    --filter regular_expression:
        Filter files matching the regular expression. Default is '.*\\.nc$'
        Regular expression syntax is defined by the Python re module.

    --help, -h:
        Help message.

    --ignore-zero-length:
        Ignore files with length zero. By default, a warning is printed and the file is skipped.

    --map mapfile
        Process the list of files in mapfile. Each line has the form
        dataset_id | path | size [| mod_time=... | checksum=... | checksum_type=MD5]

    --move, -m:
        Move the file(s). Default is to copy.

    --output, -o mapfile:
        Output a mapfile that can be input to esgpublish. If mapfile is '-', write to standard output.

    --overwrite, -w:
        Overwrite files that compare as equal. By default, 'equal' files are not copied.
        Files that are unequal cannot be overwritten.

    --prefix string
        Prefix input files.

    --replica, -r:
        Files belong to replicated datasets. Uses 'directory_format_for_replica' configuration option.
        By default, set the product to 'output1'

    --sync-db
        Synchronize file copying with the replica database:
        - If the tracking_id is not in table utd_files, the file is skipped;
        - When the file is copied, the path is updated in the corresponding entry of utd_files.

    --verbose, -v:
        Verbose.

    --version-list file:
        Check the version list for dataset names. If --replica and --output are selected, a warning is issued
        if the generated dataset ID is not contained in the version list. Each line of file
        has the form: 'dataset_id|version'

"""

def canonicalPath(dirname, basename):
    """Join dirname and basename, rewriting a leading '/css02-cmip5' root.

    dirname starts with    result starts with
    -----------------------------------------
    /css02-cmip5           /cmip5_css02
    anything else          (dirname unchanged)
    """
    css02_prefix = "/css02-cmip5"
    if dirname.startswith(css02_prefix):
        # Swap only the leading prefix; the remainder of the directory is kept verbatim.
        dirname = "/cmip5_css02" + dirname[len(css02_prefix):]
    return os.path.join(dirname, basename)

def getInput(filelist):
    """Yield (stripped_line, line_number) for each meaningful line of filelist.

    filelist is read with readline() until it returns an empty value.
    Blank lines and lines whose first non-blank character is '#' are skipped,
    but they still count toward the 1-based line number.
    """
    lineno = 0
    while True:
        lineno += 1
        raw = filelist.readline()
        if not raw:
            # readline() returns an empty string at end of input.
            return
        text = raw.strip()
        if len(text)==0 or text[0]=='#':
            continue
        yield text, lineno
def getIdentInfo(filelist, prefix=None):
    for line, fileCount in getInput(filelist):

        fields = line.split('|')
        if len(fields)==1:
            path = line
            checksum = None
        elif len(fields)==2:
            path = fields[0].strip()
            checksum = fields[1].strip()
        else:
            print 'Invalid input line: %s'%line
            sys.exit(0)

        if prefix is not None:
            path = os.path.join(prefix, path)

        # size
        try:
            stats = os.stat(path)
        except Exception as e:
            print '[%6d] Error opening %s: %s'%(fileCount, path, e)
            continue
        size = stats.st_size
        mtime = stats.st_mtime
        if size==0:
            print '[%6d] Error: size 0 file: %s, skipping'%(fileCount, path)
            continue

        # Tracking ID and institute
        try:
            f = Cdunif.CdunifFile(path)
        except Exception as e:
            print '[%6d] Cdunif error opening %s: %s'%(fileCount, path, e)
            continue
        tracking_id = f.tracking_id
        institute_id = f.institute_id
        f.close()

        yield tracking_id, (size,mtime), path, checksum, fileCount

def mapfileIterator(datasetMap, filefilt=None, followSymLinks=True, useDiscoveredSize=True):
    """Iterate over the regular files listed in a dataset map.

    datasetMap: mapping whose values are iterables of (path, size) pairs.
    filefilt: regular expression matched (with re.match semantics) against the
        basename of each path. If None, every file is accepted.
    followSymLinks: if True use os.stat, otherwise os.lstat.
    useDiscoveredSize: if True report the size found on disk instead of the
        size recorded in the map.

    Yields (path, (size, mtime)) for each path that exists, is a regular file
    and passes the filter. Paths that cannot be stat'ed are silently skipped.
    """
    # Compile once rather than per file; None means "no filtering" instead of
    # passing None to re.match, which raises TypeError.
    if filefilt is not None:
        matcher = re.compile(filefilt)
    else:
        matcher = None
    if followSymLinks:
        statfunc = os.stat
    else:
        statfunc = os.lstat

    for dataset_id in datasetMap.keys():
        for path, size in datasetMap[dataset_id]:
            try:
                st = statfunc(path)
            except os.error:
                continue
            if not stat.S_ISREG(st.st_mode):
                continue
            if useDiscoveredSize:
                size = st.st_size
            if matcher is None or matcher.match(os.path.basename(path)) is not None:
                yield (path, (size, st.st_mtime))
            
def readDict(path):
    """Read a 'key|value' text file into a dictionary.

    Blank lines and lines starting with '#' are ignored. Keys and values are
    stripped of surrounding whitespace; fields after the second are ignored.
    If a key occurs more than once the last value wins.

    Raises ValueError if a line has no '|' separator.
    """
    dictmap = {}
    # 'with' guarantees the file is closed even if a line is malformed.
    with open(path) as f:
        for line in f:
            line = line.strip()
            if len(line)==0 or line[0]=='#':
                continue
            fields = line.split('|')
            if len(fields)<2:
                raise ValueError("Invalid line in %s, expected 'key|value': %s"%(path, line))
            dictmap[fields[0].strip()] = fields[1].strip()
    return dictmap

def getSubDir(path, sink, handler, size, compare=False):
    """Return (n, exists) where subdirectory 'n' is such that sink/n/path is unique,
    and 'exists' is True iff the file sink/n/path exists and compares to path.

    path: source file; only its basename is looked up under sink.
    sink: directory containing integer-named subdirectories ('1', '2', ...).
    handler, size, compare: passed through to compareFilesByPath.

    The integer-named subdirectories are scanned from the highest number down.
    The first one that already holds a file with the same basename decides the
    result: that subdirectory if the files compare equal, otherwise the next
    higher number. If sink does not exist, or no subdirectory holds the file,
    the result is ('1', False).
    """
    if not os.path.exists(sink):
        return '1', False

    # Use the actual numeric directory names, not just their count, so that
    # gaps in the numbering (e.g. '1' and '3') are handled correctly.
    intdirs = sorted((item for item in os.listdir(sink) if item.isdigit()), key=int, reverse=True)

    base = os.path.basename(path)
    for nstr in intdirs:
        path2 = os.path.join(sink, nstr, base)
        if os.path.exists(path2):
            if compareFilesByPath(path, path2, handler, size1=size, compare=compare):
                return nstr, True
            return str(int(nstr)+1), False

    return '1', False

def inputIterator(filelist, inputmap, directories=None, datasetMap=None, filefilt=None, prefix=None):
    """Yield (path, (size, mtime), tracking_id, checksum) from the selected input source.

    The source is chosen in this order:
    - filelist is not None: entries come from getIdentInfo(filelist, prefix=prefix);
    - inputmap is not None: entries come from mapfileIterator over datasetMap;
    - otherwise: entries come from multiDirectoryIterator over directories.

    For the last two sources tracking_id and checksum are always None.
    """
    if filelist is not None:
        for tracking_id, sizet, path, checksum, fileCount in getIdentInfo(filelist, prefix=prefix):
            yield path, sizet, tracking_id, checksum
        return

    # Both remaining sources produce (path, sizet) pairs and carry no identity info.
    if inputmap is not None:
        source = mapfileIterator(datasetMap, filefilt=filefilt)
    else:
        source = multiDirectoryIterator(directories, filefilt=filefilt)
    for path, sizet in source:
        yield path, sizet, None, None

def main(argv):

    try:
        args, lastargs = getopt.getopt(argv, "chmno:rvw", ['batch=', 'checksum-list=', 'compare', 'dir-format=', 'dry-run', 'filelist=', 'filter=', 'help', 'ignore-zero-length', 'map', 'move', 'output=', 'overwrite', 'prefix=', 'replica', 'sync-db', 'verbose', 'version-list='])
    except getopt.error:
        print sys.exc_value
        print usage
        sys.exit(0)

    batch = None
    checksumList = None
    compare = False
    dbname = "esgcet"
    dirformat = None
    dryrun = False
    filefilt = '.*\.nc$'
    filelist = None
    inputmap = None
    ignoreZeroLength = False
    mapfile = None
    movefiles = False
    overwrite = False
    prefix = None
    replica = False
    syncdb = False
    usedirs = True
    verbose = False
    versionList = None
    for flag, arg in args:
        if flag=='--batch':
            batch = int(arg)
        if flag in ['--checksum-list']:
            checksumList = arg
        elif flag in ['-c', '--compare']:
            compare = True
        elif flag=='--dir-format':
            dirformat = arg
        elif flag=='--filter':
            filefilt = arg
        elif flag=='--filelist':
            filelist = open(arg)
            usedirs = False
        elif flag in ['-h', '--help']:
            print usage
            sys.exit(0)
        elif flag in ['--ignore-zero-length']:
            ignoreZeroLength = True
        elif flag=='--map':
            inputmap = arg
            usedirs = False
        elif flag in ['-m', '--move']:
            movefiles = True
        elif flag in ['-n', '--dry-run']:
            dryrun = True
        elif flag in ['-o', '--output']:
            mapfile = arg
        elif flag=='--prefix':
            prefix = arg
        elif flag in ['-r', '--replica']:
            replica = True
        elif flag=='--sync-db':
            syncdb = True
        elif flag in ['-v', '--verbose']:
            verbose = True
        elif flag in ['-w', '--overwrite']:
            overwrite = True
        elif flag in ['--version-list']:
            versionList = arg

    if (usedirs and len(lastargs)<2) or len(lastargs)<1:
        print 'No project or directory specified'
        print usage
        sys.exit(0)

    project = lastargs[0]
    if usedirs:
        directories = lastargs[1:]
    else:
        directories = None

    if dirformat is None:
        if replica:
            dirformat = 'directory_format_for_replica'
        else:
            dirformat = 'directory_format_for_copy'

    if batch is None:
        now = datetime.datetime.now()
        batch = int('%4d%02d%02d'%(now.year, now.month, now.day))

    config = loadConfig(None)
    engine = create_engine(config.get('extract', 'replica_dburl'), echo=False, pool_recycle=3600)
    initLogging('DEFAULT', override_sa=engine)
    Session = sessionmaker(bind=engine, autoflush=True, autocommit=False)
    registerHandlers()
    handler = getHandlerByName(project, None, Session)
    if mapfile is not None:
        if mapfile != '-':
            fmap = open(mapfile, 'w')
        else:
            fmap = sys.stdout

    if versionList is not None:
        versionMap = readDict(versionList)

    if checksumList is not None:
        checksumMap = readDict(checksumList)

    if syncdb:
        from esgcet.query import RDBDataStore
        replicaDb = config.get('extract', 'replica_dburl')
        store = RDBDataStore(replicaDb, dbname=dbname)

    if inputmap is not None:
        datasetMap = readDatasetMap(inputmap)
        path2dset = {}
        for key in datasetMap.keys():
            for path, size in datasetMap[key]:
                if path in path2dset:
                    raise RuntimeError("Duplicate basename in mapfile: %s"%path)
                path2dset[path] = key[0]
    else:
        datasetMap = None

    # Determine the categories
    handler = getHandlerByName(project, None, Session)

    # For each file:
    for path, sizet, tracking_id, checksum in inputIterator(filelist, inputmap, directories=directories, datasetMap=datasetMap, filefilt=filefilt, prefix=prefix):

        # Determine the category values
        handler.resetContext()
        handler.path = path
        try:
            context = handler.getContext()
            if replica:
                context['product'] = 'output1'
        except IOError:
            if ignoreZeroLength:
                continue
            else:
                print 'Error getting context from path=%s, sizet=%s'%(path, sizet)
                raise

        # Build the correct output path
        sink = handler.generateNameFromContext(dirformat)

        # If writing a mapfile, generate the dataset identifier
        if mapfile is not None:
            if usedirs or (filelist is not None):
                datasetId = handler.generateNameFromContext('dataset_id')
            else:
                datasetId = path2dset[path]
            if versionList is not None:
                if datasetId not in versionMap:
                    print 'Dataset is missing in versionlist: id=%s, file=%s, skipping'%(datasetId, path)
                    continue

        # Create the subdirectory if necessary
        subsink, exists = getSubDir(path, sink, handler, sizet[0], compare=compare)
        fullsink = os.path.join(sink, subsink)

        fullpath = canonicalPath(fullsink, os.path.basename(path))
        size, modtime = sizet
        if syncdb:
            inCache, cachePath = store.isCached(tracking_id, size)
            if not inCache:
                if verbose:
                    print 'File not up-to-date, skipping %s'%path
                continue

        # Create the directory if needed
        if not os.path.exists(fullsink):
            if verbose:
                print 'mkdir -p %s'%fullsink
            if not dryrun:
                os.makedirs(fullsink)

        # Copy/move the file
        if (not exists) or overwrite:
            if not movefiles:
                if verbose or dryrun:
                    print 'cp %s %s'%(path, fullsink)
                    if syncdb:
                        print "UPDATE esgf_replica.utd_files SET archived=t, path=%s where tracking_id=%s"%(fullpath, tracking_id)
                if not dryrun:
                    shutil.copy(path, fullsink)
                    if syncdb:
                        store.setArchived(tracking_id, fullpath, modtime=modtime, batch=batch)
            else:
                if verbose or dryrun:
                    print 'mv %s %s'%(path, fullsink)
                    if syncdb:
                        print "UPDATE esgf_replica.utd_files SET archived=t, path=%s where tracking_id=%s"%(fullpath, tracking_id)
                if not dryrun:
                    shutil.move(path, fullsink)
                    if syncdb:
                        store.setArchived(tracking_id, fullpath, modtime=modtime, batch=batch)
            if mapfile is not None:
                size, modtime = sizet
                print >>fmap, "%s | %s | %d | mod_time=%f"%(datasetId, fullpath, size, float(modtime)),
                if checksum is not None:
                    print >>fmap, " | checksum=%s | checksum_type=MD5"%checksum
                elif checksumList is not None:
                    if path in checksumMap:
                        print >>fmap, " | checksum=%s | checksum_type=MD5"%checksumMap[path]
                    else:
                        print >>fmap
                else:
                    print >>fmap

    if filelist is not None:
        filelist.close()
    elif mapfile is not None and mapfile!='-':
        fmap.close()

# When run as a script, pass the command-line arguments (without the program name) to main().
if __name__=='__main__':
    main(sys.argv[1:])
