#!/Users/mathewj2/repos/SPT/venv/bin/python
import sys
import os
from os.path import basename
from os.path import exists
from os.path import join
import re
import glob
import json
import datetime
import collections
from collections import OrderedDict
import importlib.resources
import base64

import pandas as pd
import jinja2
from jinja2 import Environment
from jinja2 import BaseLoader
# Shared jinja environment used for rendering the packaged report templates.
jinja_environment = Environment(loader=BaseLoader)

def quote_hash(input):
    """Jinja filter: escape each '#' as '\\#' (LaTeX-safe hash characters)."""
    # Raw strings avoid the invalid-escape-sequence DeprecationWarning that
    # the original '\#' / '\\#' literals triggered; output is unchanged.
    return re.sub(r'#', r'\\#', input)
jinja_environment.filters['quote_hash'] = quote_hash

# Matches ANSI terminal escape sequences (7-bit C1 controls and full CSI
# sequences) so they can be stripped from log lines before pattern matching.
ansi_escape = re.compile(r'''
    \x1B  # ESC
    (?:   # 7-bit C1 Fe (except CSI)
        [@-Z\\-_]
    |     # or [ for CSI, followed by a control sequence
        \[
        [0-?]*  # Parameter bytes
        [ -/]*  # Intermediate bytes
        [@-~]   # Final byte
    )
''', re.VERBOSE)


class LogParsingError(Exception):
    """Raised when an essential log or config file is missing or unparsable."""

    def __init__(self, message):
        # Chain to Exception so ``args`` is populated (needed for pickling
        # and generic exception handling); keep the ``message`` attribute
        # for existing callers.
        super().__init__(message)
        self.message = message

    def __str__(self):
        return repr(self.message)


class LSFPreambleSkipper:
    """Context manager yielding a file handle positioned past any LSF preamble.

    LSF batch logs begin with a 'Sender: LSF System <...>' header followed by
    scheduler chatter; the captured program output starts after the line
    'The output (if any) follows:' plus one separator line. Leading
    'WARNING: ' lines (checked after ANSI-escape removal) are also skipped.
    Files without the LSF header are yielded from the very beginning.
    """

    def __init__(self, filename):
        # Peek at the first line only, closing the probe handle promptly
        # (the original leaked this handle).
        with open(filename, 'rt') as probe:
            header = probe.readline().rstrip('\n')
        seek_to_stdout_capture = bool(
            re.match(r'^Sender: LSF System <[\w\d\_\.\@\-]+>$', header)
        )
        self.f = open(filename, 'rt')
        if seek_to_stdout_capture:
            line = self.f.readline()
            # Also stop at EOF ('' sentinel) so a malformed preamble that
            # never contains the marker cannot loop forever.
            while line and line != 'The output (if any) follows:\n':
                line = self.f.readline()
            self.f.readline()  # Skip the separator line after the marker.
            position = self.f.tell()
            line = ansi_escape.sub('', self.f.readline())
            while re.search('^WARNING: ', line):
                position = self.f.tell()
                line = ansi_escape.sub('', self.f.readline())
            self.f.seek(position)

    def __enter__(self):
        return self.f

    def __exit__(self, exc_type, exc_value, traceback):
        self.f.close()

class LogParser:
    """Extract summary statistics about one SPT pipeline run from its logs.

    Reads the Nextflow log, the pipeline JSON configuration, the performance
    report, and the per-job ``.command.log`` files found under a run
    directory, accumulating named values in ``self.extractions``.
    """

    def __init__(self, path):
        self.path = path
        self.extractions = {}
        self.performance_report_base64 = ''
        self.performance_report_contents = ''

    def get_path(self):
        return self.path

    def get_inputs(self):
        """Locate the essential files for this run.

        Returns ``[log_files, config_file, nextflow_log, performance_report]``.
        Raises LogParsingError if any essential file is missing or no job
        logs are found.
        """
        path = self.get_path()
        filenames = {
            'config' : join(path, '.spt_pipeline.json'),
            'nextflow log' : join(path, '.nextflow.log'),
            'performance report' : join(path, 'results', 'performance_report.md')
        }
        config_file = self.check_file(filenames['config'])
        nextflow_log = self.check_file(filenames['nextflow log'])
        performance_report = self.check_file(filenames['performance report'])
        log_files = glob.glob(join(path, 'work/*/*/.command.log'))
        if len(log_files) == 0:
            raise LogParsingError('No log files found.')
        return [log_files, config_file, nextflow_log, performance_report]

    def check_file(self, target):
        """Return ``target`` if it exists, else raise LogParsingError."""
        if exists(target):
            return target
        raise LogParsingError('Essential log or config file not found: %s' % target)

    def parse(self):
        """Run all extractions, populating ``self.extractions`` and the report blobs.

        Raises LogParsingError (via validate_all_extractions_found) when any
        expected value could not be extracted.
        """
        self.log_files, self.config_file, self.nextflow_log, self.performance_report = self.get_inputs()

        self.extractions['Dataset'] = self.extract_exact(
            r'Dataset/project: ([\w \d\,\./\?\(\)\-]+)$'
        )[0]

        # The cell-manifest log begins with this exact CSV header; the count
        # is the number of remaining lines minus one (presumably a trailing
        # non-data line — TODO confirm against an actual manifest log).
        for log in self.log_files:
            count = None
            with LSFPreambleSkipper(log) as f:
                line = f.readline().rstrip('\n')
                if line == 'input_file_identifier,input_filename,job_index':
                    count = len(f.readlines()) - 1
            if count:
                self.extractions['# cell files'] = str(count)
                break

        # The SPT version banner, when present, is the first line of a log.
        for log in self.log_files:
            with LSFPreambleSkipper(log) as f:
                match = re.match(r'^SPT v(\d+\.\d+\.\d+)$', f.readline())
                if match:
                    self.extractions['SPT'] = 'v' + match.group(1)
                    break

        with open(self.config_file, 'rt') as config_handle:
            config = json.loads(config_handle.read())
        # Remove the literal prefix. The original used str.lstrip, which
        # strips a *character set* and would mangle workflow names whose
        # first characters appear in "Multiplexed IF ".
        self.extractions['Workflow'] = re.sub(r'^Multiplexed IF ', '', config['workflow'])

        # extract_exact raises when the pattern is absent, so the 'YYYY'
        # fallback below is defensive only.
        year = self.extract_exact(r'Run date year: (\d+)$', lines_limit=2)
        year = year[0] if year else 'YYYY'

        with open(self.nextflow_log, 'rt') as nextflow_handle:
            nf_header = nextflow_handle.readline().rstrip('\n')
        search = re.search(r'^(\w+)\-(\d+) \d+:\d+:\d+\.\d+', nf_header)
        if search:
            month = search.group(1)
            day = search.group(2)
            self.extractions['Run date'] = ' '.join([month, day, year])

        job_reports = self.extract_job_reports()

        # Individual job reports may lack some keys; filter rather than
        # risking KeyError (validate_all_extractions_found reports gaps).
        sized = [report for report in job_reports if 'source file bytes' in report]
        if sized:
            largest = max(report['source file bytes'] for report in sized)
            self.extractions['Largest file size'] = str(int(largest / 1000000)) + 'MB'

        runtime = self.get_total_runtime()
        self.extractions['Total runtime'] = self.format_duration(runtime)

        number_cells = sum(report.get('number of cells', 0) for report in job_reports)
        self.extractions['# cells'] = number_cells

        # Guard against division by zero when no cell counts were logged.
        if number_cells > 0:
            self.extractions['Time per 1M cells'] = self.format_duration(
                runtime / (number_cells / 1000000)
            )

        timed = [report for report in job_reports if 'duration' in report]
        if timed:
            self.extractions['Longest job time'] = max(
                timed, key=lambda report: report['duration']
            )['duration minutes']

        with open(self.performance_report, 'rt') as report_handle:
            self.performance_report_contents = report_handle.read()

        message_bytes = self.performance_report_contents.encode('ascii')
        self.performance_report_base64 = base64.b64encode(message_bytes).decode('ascii')

        self.validate_all_extractions_found()

    def validate_all_extractions_found(self):
        """Raise LogParsingError naming every expected extraction that is absent or empty.

        (The original also tested ``self.extractions[key] in self.extractions``,
        a value-as-key membership check that was effectively always False.)
        """
        failed = [
            key for key in LogParser.get_order()
            if key not in self.extractions or self.extractions[key] in (None, '')
        ]
        if len(failed) > 0:
            raise LogParsingError('Some extractions not made: %s' % str(sorted(failed)))

    def get_total_runtime(self):
        """Wall-clock runtime: difference between first and last Nextflow log timestamps."""
        with open(self.nextflow_log, 'rt') as nextflow_handle:
            nf_header = nextflow_handle.readline().rstrip('\n')
        timestamp1 = self.parse_nextflow_timestamp(nf_header)
        timestamp2 = self.parse_nextflow_timestamp(self.get_last_line(self.nextflow_log))
        return self.get_timedelta(timestamp1, timestamp2)

    def parse_nextflow_timestamp(self, line):
        """Parse 'Mon-DD HH:MM:SS.mmm ...' into (month, day, hour, minute, second).

        Month is always 0 because the abbreviation is not decoded, so the
        downstream arithmetic is only exact for within-month durations.
        Raises LogParsingError when the line does not match.
        """
        search = re.search(r'^(\w+)\-0?(\d+) 0?(\d+):0?(\d+):0?(\d+\.\d+)', line)
        if search:
            month = 0  # Not parsing month abbreviation
            day = int(search.group(2))
            hour = int(search.group(3))
            minute = int(search.group(4))
            second = float(search.group(5))
            return (month, day, hour, minute, second)
        raise LogParsingError('Could not parse Nextflow log timestamp.')

    def get_last_line(self, filename):
        """Return the final line of ``filename``, seeking backwards from the end."""
        with open(filename, 'rb') as f:
            try:
                f.seek(-2, os.SEEK_END)
                while f.read(1) != b'\n':
                    f.seek(-2, os.SEEK_CUR)
            except OSError:
                # File shorter than two bytes, or no newline before the start.
                f.seek(0)
            last_line = f.readline().decode()
        return last_line

    def format_duration(self, duration):
        """Format a timedelta as minutes with one decimal, e.g. '1.5m' or '5m'."""
        minutes = str(
            int(10 * duration.total_seconds() / 60) / 10
        )
        return re.sub(r'\.0', '', minutes) + 'm'

    def extract_exact(self, pattern, lines_limit=1):
        """Search the first ``lines_limit`` lines of each job log for ``pattern``.

        Returns the match's groups tuple (with 1 substituted for absent
        groups, as ``match.groups(1)`` does). Raises LogParsingError when no
        log matches.
        """
        for log in self.log_files:
            with LSFPreambleSkipper(log) as f:
                match = None  # Defined even if the loop body never runs.
                line_count = 0
                line = None
                while line_count < lines_limit and line != '':
                    line = f.readline()
                    line_count = line_count + 1
                    match = re.search(pattern, line)
                    if match:
                        break
                if match:
                    return match.groups(1)
        raise LogParsingError('Essential pattern not found in log files: %s' % pattern)

    def extract_job_reports(self):
        """Collect per-job statistics (cell count, file size, duration) from each log.

        Returns a non-empty list of dicts; raises LogParsingError when no
        job information at all could be extracted.
        """
        job_reports = []
        for log in self.log_files:
            with open(log, 'rt') as f:
                job_report = {}
                start_time = None
                stop_time = None
                for line in f:
                    parts = self.parse_log_line(line.rstrip('\n'))
                    if len(parts) == 0:
                        continue
                    if parts['Message'] == 'Started core calculator job.':
                        start_time = (parts['Month numeric'], parts['Day numeric'], parts['Hour'], parts['Minute'], parts['Second'])
                    if parts['Message'] == 'Completed core calculator job.':
                        stop_time = (parts['Month numeric'], parts['Day numeric'], parts['Hour'], parts['Minute'], parts['Second'])
                    match = re.match(r'^(\d+) cells to be parsed from source file.$', parts['Message'])
                    if match:
                        job_report['number of cells'] = int(match.group(1))
                    match = re.match(r'^Cells source file has size (\d+) bytes.$', parts['Message'])
                    if match:
                        job_report['source file bytes'] = int(match.group(1))
                if (start_time is not None) and (stop_time is not None):
                    # Will produce misleading result at year boundary, or leap year Feb boundary
                    duration = self.get_timedelta(start_time, stop_time)
                    job_report['duration'] = duration
                    job_report['duration minutes'] = self.format_duration(duration)
                if len(job_report) > 0:
                    job_reports.append(job_report)
        if len(job_reports) == 0:
            raise LogParsingError('No job info could be extracted.')
        return job_reports

    def days_diff(self, month1, month2, day1, day2):
        """Day count from (month1, day1) to (month2, day2), assumed same year.

        NOTE(review): February is always credited 29 days, and spans of more
        than one month boundary only add a single month's days — adequate for
        typical run durations, misleading otherwise.
        """
        if month1 == month2:
            return day2 - day1
        if month2 > month1:
            days_in_month = [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
            return day2 + days_in_month[month1 - 1] - day1
        raise LogParsingError('Month duration unreasonable, from %s to %s.' % (month1, month2))

    def get_timedelta(self, t1, t2):
        """Timedelta from timestamp tuple t1 to t2, each (month, day, hour, minute, second)."""
        duration = datetime.timedelta(
            days = self.days_diff(t1[0], t2[0], t1[1], t2[1]),
            hours = t2[2] - t1[2],
            minutes = t2[3] - t1[3],
            seconds = t2[4] - t1[4],
        )
        return duration

    def parse_log_line(self, line):
        """Parse one timestamped job-log line into named tokens.

        Returns {} when the (ANSI-stripped) line does not match the expected
        'MM-DD HH:MM:SS [ LEVEL ] module: message' layout.
        """
        ansi_cleaned = ansi_escape.sub('', line)
        pattern = r'^0?(\d+)-0?(\d+) 0?(\d+):0?(\d+):0?(\d+) \[ *(\w+) *\] [\w\d\.\_]+: (.*)$'
        match = re.match(pattern, ansi_cleaned)
        parsed_tokens = {}
        if match:
            parsed_tokens['Month numeric'] = int(match.group(1))
            parsed_tokens['Day numeric'] = int(match.group(2))
            parsed_tokens['Hour'] = int(match.group(3))
            parsed_tokens['Minute'] = int(match.group(4))
            parsed_tokens['Second'] = float(match.group(5))
            parsed_tokens['Message class'] = match.group(6)
            parsed_tokens['Message'] = match.group(7)
        return parsed_tokens

    def get_extractions(self):
        return self.extractions

    @staticmethod
    def get_order():
        """Canonical column order for reports; keys expected in self.extractions."""
        order = [
            'Run date',
            'Dataset',
            '# cell files',
            '# cells',
            'Largest file size',
            'Workflow',
            'SPT',
            'Total runtime',
            'Time per 1M cells',
            'Longest job time',
        ]
        return order

    def get_extractions_ordered(self):
        """Extraction values in get_order() order, '*' standing in for missing keys."""
        return [
            str(self.extractions[key]) if key in self.extractions else '*'
            for key in LogParser.get_order()
        ]


class LogReportAggregator:
    """Discover SPT run directories under the CWD and render aggregated reports."""

    def __init__(self, format_handle=None):
        # A run directory is recognized by the presence of a 'work' directory
        # plus the Nextflow log and pipeline config files.
        working_directories = []
        for root, dirs, files in os.walk('.'):
            if ('work' in dirs) and ('.nextflow.log' in files) and ('.spt_pipeline.json' in files):
                working_directories.append(root)
        self.parsers = [LogParser(path) for path in working_directories]
        self.format_handle = format_handle

    @staticmethod
    def get_formats():
        """Supported output format names (also the accepted CLI argument values)."""
        return [
            'tex',
            'HTML',
            'TSV',
            'JSON',
            'markdown',
        ]

    def retrieve_reports(self):
        """Parse every discovered run; parsing errors go to stderr and are non-fatal."""
        for parser in self.parsers:
            try:
                parser.parse()
            except LogParsingError as e:
                print('Warning: Parsing error for run located at: %s' % parser.get_path(), file=sys.stderr)
                print(e, file=sys.stderr)

    def aggregate_reports_dataframe(self):
        """One row per run, columns in LogParser.get_order() order."""
        rows = [parser.get_extractions_ordered() for parser in self.parsers]
        column_names = LogParser.get_order()
        return pd.DataFrame(rows, columns=column_names)

    def _load_template(self, resource_name):
        """Load a packaged jinja template, closing the file handle promptly."""
        with importlib.resources.path('spatialprofilingtoolbox', resource_name) as path:
            with open(path, 'rt') as template_file:
                return jinja_environment.from_string(template_file.read())

    def textual_render(self, format_description):
        """Render the aggregate table in the requested format ('' if unrecognized)."""
        table = self.aggregate_reports_dataframe()
        rendered = ''
        if format_description == 'tex':
            template = self._load_template('log_table.tex.jinja')
            rows = [LogParser.get_order()] + [parser.get_extractions_ordered() for parser in self.parsers]
            rendered = template.render(rows=rows)
        elif format_description == 'HTML':
            template = self._load_template('log_table.html.jinja')
            rendered = template.render(
                header=LogParser.get_order() + ['CPU usage'],
                rows=[parser.get_extractions_ordered() + [i] for i, parser in enumerate(self.parsers)],
                base64_contents=[parser.performance_report_base64 for parser in self.parsers],
            )
        elif format_description == 'TSV':
            rendered = table.to_csv(index=False, sep='\t')
        elif format_description == 'JSON':
            rendered = table.to_json(orient='records', indent=4)
        elif format_description == 'markdown':
            rendered = table.to_markdown(index=False)
        return rendered.rstrip('\n')

    def report_on_all(self):
        """Print the selected format, or every supported format when none was chosen."""
        if self.format_handle:
            print(self.textual_render(self.format_handle))
        else:
            for format_description in LogReportAggregator.get_formats():
                print('')
                print('[ ' + format_description + ' ]')
                print(self.textual_render(format_description))

if __name__ == '__main__':
    # Optional single CLI argument selects one output format; any other
    # argument (or extra arguments) prints the usage text and exits.
    args = sys.argv
    format_handle = None
    if len(args) > 1 and args[1] in LogReportAggregator.get_formats():
        format_handle = args[1]
        args = args[1:]
    if len(args) > 1:
        print('Optional argument one of:\n%s' % '\n'.join(
            ['  ' + f for f in LogReportAggregator.get_formats()]
        ))
        # sys.exit rather than the site-injected exit() builtin, which is
        # unavailable under `python -S` or in frozen executables.
        sys.exit()
    else:
        aggregator = LogReportAggregator(format_handle=format_handle)
        aggregator.retrieve_reports()
        aggregator.report_on_all()

