# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_filters.ipynb.

# %% auto 0
# Public names of this module: exactly what `from <this module> import *` exports.
__all__ = ['FilterResult', 'Filter', 'ValidityFilter', 'SingleCompoundFilter', 'AttachmentCountFilter', 'BinaryFunctionFilter',
           'DataFunctionFilter', 'RangeFunctionFilter', 'SmartsFilter', 'CatalogFilter']

# %% ../nbs/02_filters.ipynb 3
from .imports import *
from .utils import *
from .chem import Molecule, Catalog, mol_func_wrapper
from rdkit.Chem.FilterCatalog import SmartsMatcher

# %% ../nbs/02_filters.ipynb 4
class FilterResult:
    """
    Outcome of applying one filter to one molecule.

    Bundles the pass/fail verdict, the name of the filter that produced it,
    and whatever auxiliary data that filter chose to report. Values are
    stored exactly as given (no copying, no coercion).
    """

    def __init__(self,
                 filter_result: bool, # overall filter result (True or False)
                 filter_name:   str,  # name of filter
                 filter_data:   dict  # filter data dict
                ):
        self.filter_result, self.filter_name, self.filter_data = (
            filter_result, filter_name, filter_data)

    def __repr__(self):
        label, verdict = self.filter_name, self.filter_result
        return f'{label} result: {verdict}'

class Filter:
    """
    Base class for molecule filters.

    Subclasses override `__call__` to inspect a `Molecule` and return a
    `FilterResult`. This default implementation passes every molecule and
    reports an empty data dict.
    """

    def __init__(self, name='filter' # filter name
                ):
        self.name = name

    def __call__(self, molecule: Molecule) -> FilterResult:
        # Unconditional pass; the molecule is not inspected at all.
        passed = True
        return FilterResult(passed, self.name, {})

    def __repr__(self):
        # The filter is represented purely by its name.
        label = self.name
        return label

# %% ../nbs/02_filters.ipynb 5
class ValidityFilter(Filter):
    """Pass molecules whose `valid` flag is set; the flag is reported unchanged as the verdict."""

    def __init__(self):
        super().__init__('validity_filter')

    def __call__(self, molecule: Molecule) -> FilterResult:
        is_valid = molecule.valid
        return FilterResult(is_valid, self.name, {})
    
class SingleCompoundFilter(Filter):
    """
    Pass molecules whose `smile` string contains no '.' character.

    Assumes `molecule.smile` is a SMILES string, where '.' separates
    disconnected fragments (e.g. salts or mixtures).
    """

    def __init__(self):
        super().__init__('single_compound')

    def __call__(self, molecule: Molecule) -> FilterResult:
        is_single = '.' not in molecule.smile
        return FilterResult(is_single, self.name, {})

# %% ../nbs/02_filters.ipynb 7
class AttachmentCountFilter(Filter):
    """
    Pass molecules with exactly `num_attachments` dummy attachment atoms.

    Attachment atoms are counted as '*' characters in `molecule.smile`. The
    count that was found is reported under the 'num_attachments' data key.
    """

    def __init__(self,
                 num_attachments: int): # exact number of '*' attachment points required to pass
        super().__init__(f'attachment_count_{num_attachments}')
        self.num_attachments = num_attachments

    def __call__(self, molecule: Molecule) -> FilterResult:
        found = molecule.smile.count('*')
        return FilterResult(found == self.num_attachments,
                            self.name,
                            {'num_attachments' : found})


# %% ../nbs/02_filters.ipynb 9
class BinaryFunctionFilter(Filter):
    def __init__(self, 
            func: Callable[[Molecule], bool], # callable function that takes a Molecule as input and returns a bool
            name: str # filter name
                ):
        'Filters using the verdict returned by `func`; no extra data is reported'
        self.func, self.name = func, name

    def __call__(self, molecule: Molecule) -> FilterResult:
        # Whatever `func` returns becomes the verdict as-is (not coerced to bool).
        return FilterResult(self.func(molecule), self.name, {})
    
class DataFunctionFilter(Filter):
    def __init__(self, 
            func: Callable[[Molecule], Tuple[bool, dict]], # callable that takes a Molecule and returns (bool, dict)
            name: str # filter name
                ):
        "Filters using `func`, which returns a (verdict, data) pair; the data dict is attached to the result"
        self.func = func
        self.name = name

    def __call__(self, molecule: Molecule) -> FilterResult:
        outcome = self.func(molecule)
        # `func` must return exactly two items: (verdict, data dict).
        passed, payload = outcome
        return FilterResult(passed, self.name, payload)

# %% ../nbs/02_filters.ipynb 11
class RangeFunctionFilter(Filter):
    def __init__(self, 
                 func:    Callable[[Molecule], Union[int, float]], # callable function, takes a Molecule as input, returns a numeric value
                 name:    str, # filter name
                 min_val: Union[int, float, None]=None, # min acceptable range value (if None, defaults to -inf)
                 max_val: Union[int, float, None]=None  # max acceptable range value (if None, defaults to inf)
                ):
        '''
        `RangeFunctionFilter` computes `func(molecule)` and passes the molecule
        when that value lies within the inclusive range [`min_val`, `max_val`].
        The computed value and both bounds are reported in the filter data.
        '''
        # `validate_range` is defined elsewhere; presumably it substitutes the
        # -inf/inf defaults for None bounds (and may check ordering) — confirm.
        lower, upper = validate_range(min_val, max_val, float('-inf'), float('inf'))

        self.func = func
        self.name = name
        self.min_val, self.max_val = lower, upper

    def __call__(self, molecule: Molecule) -> FilterResult:
        value = self.func(molecule)
        lo, hi = self.min_val, self.max_val
        report = {'computed_value' : value, 'min_val' : lo, 'max_val' : hi}
        # Both bounds are inclusive.
        passed = lo <= value <= hi
        return FilterResult(passed, self.name, report)

# %% ../nbs/02_filters.ipynb 13
class SmartsFilter(Filter):
    def __init__(self, 
                 smarts:  str, # SMARTS string 
                 name:    str, # filter name
                 exclude: bool=True, # if filter should be exclusion or inclusion
                 min_val: Union[int, float, None]=None, # min number of occurrences 
                 max_val: Union[int, float, None]=None # max number of occurrences 
                ): 
        
        '''
        `SmartsFilter` checks to see if `smarts` is present in a Molecule. If 
        `min_val` and `max_val` are passed, the filter will check to see if the number 
        of occurrences is between those values. If `exclude=True`, the filter will 
        fail molecules that match the filter. Otherwise, the filter will fail molecules 
        that don't match the filter.

        Raises `ValueError` if the RDKit matcher built from `smarts` reports itself
        as invalid (e.g. the SMARTS string cannot be parsed).
        '''
        
        # `validate_range` is defined elsewhere; here it is given a default lower
        # bound of 1 and a default upper bound of int(1e8) for the match count.
        min_val, max_val = validate_range(min_val, max_val, 1, int(1e8))
        
        self.smarts = smarts
        self.name = name
        self.exclude = exclude
        self.min_val = min_val
        self.max_val = max_val
        self.smarts_matcher = SmartsMatcher(self.name, self.smarts, self.min_val, self.max_val)
        # Fail fast. An unusable pattern is reported by RDKit via `IsValid()`.
        # Without this check the problem only surfaces (or silently skews
        # results) when the first molecule is screened.
        if not self.smarts_matcher.IsValid():
            raise ValueError(f'Invalid SMARTS pattern for filter {name!r}: {smarts!r}')
        
    def has_match(self, molecule: Molecule) -> bool:
        'Return the RDKit `SmartsMatcher.HasMatch` result for `molecule.mol` (matcher built with `min_val`/`max_val`)'
        return self.smarts_matcher.HasMatch(molecule.mol)
        
    def __call__(self, molecule: Molecule) -> FilterResult:
        
        has_match = self.has_match(molecule)
        # Exclusion filters pass non-matching molecules; inclusion filters pass matching ones.
        result = not has_match if self.exclude else has_match
        # NOTE: the 'filter_result' key holds the raw match flag, not the pass/fail
        # verdict; the key name is kept as-is for compatibility with existing callers.
        data = {'filter_result' : has_match}
        
        return FilterResult(result, self.name, data)

# %% ../nbs/02_filters.ipynb 15
class CatalogFilter(Filter):
    def __init__(self, 
                 catalog: Catalog, # SMARTS catalog
                 name:    str, # filter name
                 exclude: bool=True # if filter should be exclusion or inclusion
                ):
        
        '''
        `CatalogFilter` asks the provided `Catalog` whether a molecule matches any of its entries.
        With `exclude=True`, matching molecules fail the filter. Otherwise, only matching molecules pass.
        '''
        
        self.name = name
        self.catalog = catalog
        self.exclude = exclude
        
    def has_match(self, molecule: Molecule) -> bool:
        # The matching logic is entirely delegated to the catalog.
        return self.catalog.has_match(molecule)
        
    def __call__(self, molecule: Molecule) -> FilterResult:
        matched = self.has_match(molecule)
        if self.exclude:
            passed = not matched
        else:
            passed = matched
        # 'filter_result' stores the raw match flag, not the pass/fail verdict.
        return FilterResult(passed, self.name, {'filter_result' : matched})
