# Author: Cameron F. Abrams, <cfa22@drexel.edu>

"""
This module contains classes for validating PSF/PDB files generated by Pestifer.  The user specifies VMD atom selections, the attributes or measures to extract from those selections, and the values they are expected to have.
"""

import logging

from pathlib import Path
from typing import ClassVar
from pestifer.core.artifacts import StateArtifacts, VMDLogFileArtifact, VMDScriptArtifact, DataArtifact
from pestifer.logparsers.logparser import VMDLogParser
from pestifer.scripters.vmdscripter import VMDScripter
from pestifer.tasks.basetask import VMDTask

logger = logging.getLogger(__name__)

class AttributeTest:
    """
    This class represents a test for expected values of specific
    atom attributes in an atom selection.

    Attributes
    ----------
    name : str
        The name of the test.
    selection : str
        The atom selection string.
    attribute : str
        The atom attribute to test.
    value : str
        The expected value of the attribute.
    value_count : int
        The number of expected occurrences of the value.
    """
    def __init__(self, name: str, selection: str, attribute: str, value: str, value_count: int = 1):
        self.name = name
        self.selection = selection
        self.attribute = attribute
        self.value = value
        self.value_count = value_count

    def write(self, vt: 'VMDScripter'):
        """
        Write the commands to the VMD script to execute this test.

        The emitted Tcl selects ``selection``, gets ``attribute`` for every
        selected atom, counts the entries that match ``value`` exactly
        (case-insensitively), and prints a ``PASS`` or ``FAIL`` line with
        ``vmdcon`` depending on whether that count equals ``value_count``.

        Parameters
        ----------
        vt : VMDScripter
            The scripter receiving the lines; only its ``addline`` method is used.
        """
        vt.addline(f'set test_selection [atomselect top "{self.selection}"]')
        vt.addline(f'set result [$test_selection get {self.attribute}]')
        target = self.value
        target_count = self.value_count
        # Brace-quote the search pattern so it is always passed to lsearch as
        # exactly one Tcl word, even when it is empty or contains whitespace.
        vt.addline(f'set count [llength [lsearch -nocase -exact -all $result {{{target}}}]]')
        vt.addline(f'if {{$count != {target_count}}} {{')
        vt.addline(f'    vmdcon "FAIL {self.attribute} has unexpected count $count of value {target} (expected {target_count}) in selection {self.selection}"')
        vt.addline('} else {')
        vt.addline(f'    vmdcon "PASS {self.attribute} has expected count $count of value {target} (expected {target_count}) in selection {self.selection}"')
        vt.addline('}')

class ConnectionTest:
    """
    This class represents a test for expected number of bonds of a specific type among an atomselection.

    Attributes
    ----------
    name : str
        The name of the test.
    selection : str
        The atom selection string.
    connection_type : str
        The type of connection to test (e.g., "interresidue").
    connection_count : int
        The expected number of connections.
    """
    connection_type_supported = {'interresidue', 'disulfide', 'glycosylation'}

    def __init__(self, name: str, selection: str, connection_type: str, connection_count: int = 1):
        self.name = name
        self.selection = selection
        self.connection_type = connection_type
        if not self.connection_type in self.connection_type_supported:
            raise NotImplementedError(f'Unsupported connection type: {self.connection_type}')
        self.connection_count = connection_count

    def write(self, vt: VMDScripter):
        vt.addline(f'set test_selection [atomselect top "{self.selection}"]')
        match self.connection_type:
            case 'interresidue':
                vt.addline(f'set indexes [$test_selection get index]')
                vt.addline(f'set atomnames [$test_selection get name]')
                vt.addline(f'set resids [$test_selection get residue]')
                vt.addline(f'set bondlists [$test_selection getbonds]')
                vt.addline(f'set count 0')
                vt.addline(f'foreach index $indexes atomname $atomnames resid $resids bondlist $bondlists {{')
                vt.addline(f'   foreach partner $bondlist {{')
                vt.addline(f'      if {{$partner < $index}} {{ continue }}') # no double-counting!
                vt.addline(f'      set idx [lsearch $indexes $partner]')
                vt.addline(f'      set partner_resid [lindex $resids $idx]')
                vt.addline(f'      if {{$partner_resid != "" && $partner_resid != $resid}} {{')
                vt.addline(f'         set partnersel [atomselect top "index $partner"]')
                vt.addline(f'         set partner_name [$partnersel get name]')
                vt.addline(f'         vmdcon "BOND $index $atomname $resid <--> $partner $partner_name $partner_resid"')
                vt.addline(f'         incr count')
                vt.addline(f'      }}')
                vt.addline(f'   }}')
                vt.addline(f'}}')
                vt.addline(f'if {{$count != {self.connection_count}}} {{')
                vt.addline(f'   vmdcon "FAIL {self.name} has unexpected count $count (expected {self.connection_count}) in selection {self.selection} for interresidue bonds"')
                vt.addline(f'}} else {{')
                vt.addline(f'   vmdcon "PASS {self.name} has expected count $count (expected {self.connection_count}) in selection {self.selection}"')
                vt.addline(f'}}')
            case 'disulfide':
                vt.addline(f'set indexes [$test_selection get index]')
                vt.addline(f'set atomnames [$test_selection get name]')
                vt.addline(f'set resids [$test_selection get residue]')
                vt.addline(f'set bondlists [$test_selection getbonds]')
                vt.addline(f'set count 0')
                vt.addline(f'foreach index $indexes atomname $atomnames resid $resids bondlist $bondlists {{')
                vt.addline(f'   foreach partner $bondlist {{')
                vt.addline(f'      if {{$partner < $index}} {{ continue }}') # no double-counting!
                vt.addline(f'      set idx [lsearch $indexes $partner]')
                vt.addline(f'      set partner_resid [lindex $resids $idx]')
                vt.addline(f'      if {{$partner_resid != "" && $partner_resid != $resid}} {{')
                vt.addline(f'         set partnersel [atomselect top "index $partner"]')
                vt.addline(f'         set partner_name [$partnersel get name]')
                vt.addline(f'         vmdcon "BOND $index $atomname $resid <--> $partner $partner_name $partner_resid"')
                vt.addline(f'         if {{$atomname == "SG" && $partner_name == "SG"}} {{')
                vt.addline(f'             incr count')
                vt.addline(f'         }}')
                vt.addline(f'      }}')
                vt.addline(f'   }}')
                vt.addline(f'}}')
                vt.addline(f'if {{$count != {self.connection_count}}} {{')
                vt.addline(f'   vmdcon "FAIL {self.name} has unexpected count $count (expected {self.connection_count}) in selection {self.selection} for interresidue bonds"')
                vt.addline(f'}} else {{')
                vt.addline(f'   vmdcon "PASS {self.name} has expected count $count (expected {self.connection_count}) in selection {self.selection}"')
                vt.addline(f'}}')
            case 'glycosylation':
                vt.addline(f'set indexes [$test_selection get index]')
                vt.addline(f'set resids [$test_selection get residue]')
                vt.addline(f'set bondlists [$test_selection getbonds]')
                vt.addline(f'set external_partners [list]')
                vt.addline(f'foreach index $indexes resid $resids bondlist $bondlists {{')
                vt.addline(f'   foreach partner $bondlist {{')
                vt.addline(f'      set idx [lsearch $indexes $partner]')
                vt.addline(f'      if {{$idx == -1}} {{')
                vt.addline(f'         lappend external_partners $partner')
                vt.addline(f'      }}')
                vt.addline(f'   }}')
                vt.addline(f'}}')
                vt.addline(f'set extsel [atomselect top "index $external_partners"]')
                vt.addline(f'set names [$extsel get name]')
                vt.addline(f'set count [llength [lsearch -nocase -exact -all $names C1]]')
                vt.addline(f'if {{$count != {self.connection_count}}} {{')
                vt.addline(f'    vmdcon "FAIL Selection {self.selection} has no atoms bound to an external C1 atom"')
                vt.addline(f'}} else {{')
                vt.addline(f'    vmdcon "PASS Selection {self.selection} has $count atoms bound to an external C1 atom"')
                vt.addline(f'}}')
            case '_':
                logger.debug(f'Unsupported connection type: {self.connection_type}')

class ResidueTest:
    """
    This class represents a test for expected properties of a specific residue in an atom selection.
    """
    def __init__(self, name: str, selection: str, measure: str, value: int, relation: str = '=='):
        self.name = name
        self.selection = selection
        self.measure = measure
        self.relation = relation
        self.value = value

    def write(self, vt: VMDScripter):
        vt.addline(f'set test_selection [atomselect top "{self.selection}"]')
        pass_msg = f'PASS {self.name} has expected relation {self.relation} to count $count (expected {self.value}) in selection {self.selection}'
        fail_msg = f'FAIL {self.name} does not have expected relation {self.relation} to count $count (expected {self.value}) in selection {self.selection}'
        match self.measure:
            case 'atom_count':
                vt.addline(f'set count [$test_selection num]')
                vt.addline(f'if {{$count {self.relation} {self.value}}} {{')
                vt.addline(f'   vmdcon "{pass_msg}"')
                vt.addline(f'}} else {{')
                vt.addline(f'   vmdcon "{fail_msg}"')
                vt.addline(f'}}')
            case 'residue_count':
                vt.addline(f'set resids [$test_selection get residue]')
                vt.addline(f'set unique_resids [lsort -unique $resids]')
                vt.addline(f'set count [llength $unique_resids]')
                vt.addline(f'if {{$count {self.relation} {self.value}}} {{')
                vt.addline(f'   vmdcon "{pass_msg}"')
                vt.addline(f'}} else {{')
                vt.addline(f'   vmdcon "{fail_msg}"')
                vt.addline(f'}}')
            case '_':
                logger.debug(f'Unsupported measure type: {self.measure}')

class ValidateTask(VMDTask):
    """
    This class represents a validation task for PSF/PDB files.  A validate
    task can be inserted anywhere in the workflow to perform validation checks.
    """

    _yaml_header: ClassVar[str] = 'validate'

    def provision(self, packet: dict = {}):
        super().provision(packet)
        self.test_specs = self.specs.get('tests', [])
        self.attribute_tests: list[AttributeTest] = []
        self.connection_tests: list[ConnectionTest] = []
        self.residue_tests: list[ResidueTest] = []
        for test in self.test_specs:
            logger.debug(f'Processing test specification: {test}')
            assert isinstance(test, dict), f"Test specification {test} must be a dictionary"
            assert len(test) == 1, f"Test specification {test} must have a single key-value pair"
            test_type = list(test.keys())[0]
            match test_type:
                case 'attribute_test':
                    specs = test['attribute_test']
                    assert isinstance(specs, dict), f"Attribute test specification {specs} must be a dictionary"
                    self.attribute_tests.append(AttributeTest(**specs))
                case 'connection_test':
                    specs = test['connection_test']
                    assert isinstance(specs, dict), f"Connection test specification {specs} must be a dictionary"
                    self.connection_tests.append(ConnectionTest(**specs))
                case 'residue_test':
                    specs = test['residue_test']
                    assert isinstance(specs, dict), f"Residue test specification {specs} must be a dictionary"
                    self.residue_tests.append(ResidueTest(**specs))
        logger.debug(f'Provisioned {len(self.attribute_tests)} Attribute tests: {self.attribute_tests}')
        logger.debug(f'Provisioned {len(self.connection_tests)} Connection tests: {self.connection_tests}')
        logger.debug(f'Provisioned {len(self.residue_tests)} Residue tests: {self.residue_tests}')

    def do(self):
        """
        Execute the validation task.
        """
        state: StateArtifacts = self.get_current_artifact('state')
        psf: Path = state.psf
        pdb: Path = state.pdb
        vt: VMDScripter = self.get_scripter('vmd')
        self.next_basename()
        vt.newscript(basename=self.basename)
        vt.load_psf_pdb(psf.name, pdb.name)
        for attribute_test in self.attribute_tests:
            attribute_test.write(vt)
        for connection_test in self.connection_tests:
            connection_test.write(vt)
        for residue_test in self.residue_tests:
            residue_test.write(vt)
        # here we insert the commands to extract values of desired variables
        vt.writescript()
        self.register(self.basename, key = 'tcl', artifact_type=VMDScriptArtifact)
        vt.runscript()
        log_artifact = self.register(self.basename, key='log', artifact_type=VMDLogFileArtifact)
        self.log = VMDLogParser.from_file(log_artifact.name)
        results = self.log.collect_validation_results()
        npass = 0
        nfail = 0
        if not results:
            logger.debug(f'Empty validation results')
            self.extra_message = "No validation results found."
        else:
            npass = sum(1 for r in results if 'PASS' in r)
            nfail = sum(1 for r in results if 'FAIL' in r)
            self.register(dict(npass=npass, nfail=nfail), key='validation_results')
            logger.debug(f'Validation results: \x1b[32m\x1b[1m{npass} passing\x1b[0m, \x1b[31m\x1b[1m{nfail} failing\x1b[0m>')
            self.extra_message = f"\x1b[32m\x1b[1mpass: {npass}\x1b[0m, \x1b[31m\x1b[1mfail: {nfail}\x1b[0m"
        # here we would parse the resulting log file
        return nfail
