import numpy as np
from ..data.data import Data
from ..model.model import Model
from .statistic import Statistic



class Pair(object):
    """Bind a ``Data`` object to a ``Model`` and evaluate how well they agree.

    The pair folds the model photon flux through the data's response
    matrices to obtain predicted count rates, derives conversion factors
    between counts, photon spectra and energy flux, and evaluates the fit
    statistic named for each entry of ``data.stats``.
    """

    # Statistic names accepted in ``data.stats`` mapped to their
    # implementations. Several names are aliases of the same function
    # (e.g. 'chi2' -> Gstat, 'cstat' -> PPstat, 'Xcstat' -> PPstat_Xspec).
    _allowed_stats = {'gstat': Statistic.Gstat,
                      'chi2': Statistic.Gstat,
                      'pstat': Statistic.Pstat,
                      'ppstat': Statistic.PPstat,
                      'cstat': Statistic.PPstat,
                      'pgstat': Statistic.PGstat,
                      'Xppstat': Statistic.PPstat_Xspec,
                      'Xcstat': Statistic.PPstat_Xspec,
                      'Xpgstat': Statistic.PGstat_Xspec,
                      'ULppstat': Statistic.PPstat_UL,
                      'ULpgstat': Statistic.PGstat_UL}

    def __init__(self, data, model):
        """Create the pair and link ``model`` to ``data``.

        Parameters
        ----------
        data : Data
            The observed data to fit.
        model : Model
            The model to fit to ``data``.

        Raises
        ------
        ValueError
            If ``data`` is not a ``Data`` or ``model`` is not a ``Model``.
        """
        self._data = data
        self._model = model

        self._pair()

    @property
    def data(self):
        """The paired ``Data`` object."""
        return self._data

    @data.setter
    def data(self, new_data):
        # Validate before assigning so a rejected value does not leave the
        # pair holding an invalid object after the ValueError propagates.
        self._validate(new_data, self._model)
        self._data = new_data

        self._pair()

    @property
    def model(self):
        """The paired ``Model`` object."""
        return self._model

    @model.setter
    def model(self, new_model):
        # Validate before assigning (same reasoning as the ``data`` setter).
        self._validate(self._data, new_model)
        self._model = new_model

        self._pair()

    @staticmethod
    def _validate(data, model):
        """Raise ``ValueError`` unless ``data`` is a Data and ``model`` a Model."""
        if not isinstance(data, Data):
            raise ValueError('data argument should be Data type')

        if not isinstance(model, Model):
            raise ValueError('model argument should be Model type')

    def _pair(self):
        """Validate the current data/model and register the model on the data."""
        self._validate(self._data, self._model)

        # Store the model on the data object's ``fit_with`` attribute; what
        # Data does with it is defined in the Data class.
        self.data.fit_with = self.model

    def _fold(self, drms):
        """Fold the model photon flux through one response matrix per spectrum.

        The model is integrated over ``data.ebin`` / ``data.tarr``; the flat
        result is cut into per-spectrum segments
        ``[bin_start[k]:bin_stop[k]]`` and each segment is multiplied
        (``np.dot``) by the matching matrix in ``drms``.

        Parameters
        ----------
        drms : iterable
            One response matrix per spectrum.

        Returns
        -------
        list
            The folded count rate of each spectrum.
        """
        flat_phtflux = self.model.integ(self.data.ebin, self.data.tarr)
        spans = zip(self.data.bin_start, self.data.bin_stop, drms)

        return [np.dot(flat_phtflux[start:stop], drm) for (start, stop, drm) in spans]

    def _convolve(self):
        """Model count rates folded through ``data.corr_rsp_drm``."""
        return self._fold(self.data.corr_rsp_drm)

    def _re_convolve(self):
        """Model count rates folded through ``data.corr_rsp_re_drm``."""
        return self._fold(self.data.corr_rsp_re_drm)

    @property
    def conv_ctsrate(self):
        """Per-spectrum model count rates (recomputed on every access)."""
        return self._convolve()

    @property
    def conv_re_ctsrate(self):
        """Per-spectrum model count rates through the ``re`` response."""
        return self._re_convolve()

    @property
    def conv_ctsspec(self):
        """``conv_ctsrate`` divided by ``data.rsp_chbin_width``."""
        return [cr / chw for (cr, chw) in zip(self.conv_ctsrate, self.data.rsp_chbin_width)]

    @property
    def conv_re_ctsspec(self):
        """``conv_re_ctsrate`` divided by ``data.rsp_re_chbin_width``."""
        return [cr / chw for (cr, chw) in zip(self.conv_re_ctsrate, self.data.rsp_re_chbin_width)]

    @property
    def phtspec_at_rsp(self):
        """Model photon spectrum at ``data.rsp_chbin_mean`` / ``rsp_chbin_tarr``."""
        return [self.model.phtspec(E, T)
                for (E, T) in zip(self.data.rsp_chbin_mean, self.data.rsp_chbin_tarr)]

    @property
    def re_phtspec_at_rsp(self):
        """Model photon spectrum at ``data.rsp_re_chbin_mean`` / ``rsp_re_chbin_tarr``."""
        return [self.model.phtspec(E, T)
                for (E, T) in zip(self.data.rsp_re_chbin_mean, self.data.rsp_re_chbin_tarr)]

    @property
    def cts_to_pht(self):
        """Ratio ``conv_ctsspec / phtspec_at_rsp`` for each spectrum."""
        return [cts / pht for (cts, pht) in zip(self.conv_ctsspec, self.phtspec_at_rsp)]

    @property
    def re_cts_to_pht(self):
        """Ratio ``conv_re_ctsspec / re_phtspec_at_rsp`` for each spectrum."""
        return [cts / pht for (cts, pht) in zip(self.conv_re_ctsspec, self.re_phtspec_at_rsp)]

    def _notc_ergflux(self):
        """Model energy flux summed over each spectrum's noticed ranges.

        ``data.notcs`` holds, per spectrum, a sequence of ``(emin, emax)``
        pairs; ``model.ergflux`` is evaluated on each and the results summed.
        """
        # 1000 is forwarded as ergflux's third argument (presumably the
        # number of integration points -- confirm against Model.ergflux).
        return [np.sum([self.model.ergflux(emin, emax, 1000) for emin, emax in notc])
                for notc in self.data.notcs]

    @property
    def cts_to_flux(self):
        """Model energy flux per unit of observed net count rate, per spectrum."""
        ctsrate = [np.sum(cr) for cr in self.data.net_ctsrate]

        return [flux / cts for (flux, cts) in zip(self._notc_ergflux(), ctsrate)]

    @property
    def conv_cts_to_flux(self):
        """Model energy flux per unit of model (folded) count rate, per spectrum."""
        ctsrate = [np.sum(cr) for cr in self.conv_ctsrate]

        return [flux / cts for (flux, cts) in zip(self._notc_ergflux(), ctsrate)]

    def _fluxdensity(self, at, unit):
        """Evaluate the model at energy ``at`` in the requested ``unit``.

        ``'NE'`` uses ``model.phtspec``; ``'fv'`` uses ``model.flxspec``;
        ``'Jy'`` scales ``model.flxspec`` by ``1e6 / 2.416``.

        Raises
        ------
        ValueError
            If ``unit`` is not one of the above.
        """
        if unit == 'NE':
            return self.model.phtspec(at)
        if unit == 'fv':
            return self.model.flxspec(at)
        if unit == 'Jy':
            # NOTE(review): the keV->Hz factor is ~2.418e17; confirm whether
            # 2.416 here is intended or a typo for 2.418.
            return self.model.flxspec(at) * 1e6 / 2.416
        raise ValueError(f'unsupported value of unit: {unit}')

    def cts_to_fluxdensity(self, at=1, unit='fv'):
        """Model flux density at ``at`` per unit of observed net count rate.

        Parameters
        ----------
        at : float, optional
            Energy passed to the model spectrum functions (default 1).
        unit : {'NE', 'fv', 'Jy'}, optional
            Output flavour; see ``_fluxdensity`` (default 'fv').

        Returns
        -------
        list
            One ratio per spectrum.

        Raises
        ------
        ValueError
            If ``unit`` is unsupported.
        """
        ctsrate = [np.sum(cr) for cr in self.data.net_ctsrate]
        fluxdensity = self._fluxdensity(at, unit)

        return [fluxdensity / cts for cts in ctsrate]

    def conv_cts_to_fluxdensity(self, at=1, unit='fv'):
        """Model flux density at ``at`` per unit of model (folded) count rate.

        Parameters and errors are as for ``cts_to_fluxdensity``.
        """
        ctsrate = [np.sum(cr) for cr in self.conv_ctsrate]
        fluxdensity = self._fluxdensity(at, unit)

        return [fluxdensity / cts for cts in ctsrate]

    @property
    def stat_func(self):
        """Callable evaluating one statistic term.

        Signature ``(S, B, m, ts, tb, sigma_S, sigma_B, stat)``. Returns
        ``np.inf`` when the model prediction ``m`` contains NaN or inf,
        otherwise the ``_allowed_stats[stat]`` function applied to the
        arguments cast to ``np.float64``. An unknown ``stat`` raises
        ``KeyError``.
        """
        def _evaluate(S, B, m, ts, tb, sigma_S, sigma_B, stat):
            # A non-finite model prediction is scored as infinitely bad.
            if np.isnan(m).any() or np.isinf(m).any():
                return np.inf
            return self._allowed_stats[stat](
                S=np.float64(S), B=np.float64(B), m=np.float64(m),
                ts=np.float64(ts), tb=np.float64(tb),
                sigma_S=np.float64(sigma_S), sigma_B=np.float64(sigma_B))

        return _evaluate

    def _stat_calculate(self):
        """Evaluate ``stat_func`` element-wise over the data entries.

        Returns
        -------
        numpy.ndarray
            Float array with one statistic value per entry.
        """
        # NOTE(review): this reads ``self.model.conv_ctsrate`` rather than
        # ``self.conv_ctsrate``; presumably equivalent once paired -- confirm.
        return np.array(list(map(self.stat_func,
                                 self.data.src_counts,
                                 self.data.bkg_counts,
                                 self.model.conv_ctsrate,
                                 self.data.corr_src_efficiency,
                                 self.data.corr_bkg_efficiency,
                                 self.data.src_errors,
                                 self.data.bkg_errors,
                                 self.data.stats))).astype(float)

    @property
    def stat_list(self):
        """Unweighted per-entry statistic values."""
        return self._stat_calculate()

    @property
    def weight_list(self):
        """Per-entry weights, taken from ``data.weights``."""
        return self.data.weights

    @property
    def stat(self):
        """Weighted total statistic: ``sum(stat_list * weight_list)``."""
        return np.sum(self.stat_list * self.weight_list)

    @property
    def loglike_list(self):
        """Unweighted per-entry log-likelihood, ``-0.5 * stat_list``."""
        return -0.5 * self.stat_list

    @property
    def loglike(self):
        """Total log-likelihood from the weighted statistic, ``-0.5 * stat``."""
        return -0.5 * self.stat

    @property
    def npoint_list(self):
        """Per-entry number of data points, taken from ``data.npoints``."""
        return self.data.npoints

    @property
    def npoint(self):
        """Total number of data points."""
        return np.sum(self.npoint_list)
