import collections
from nltk import NaiveBayesClassifier, DecisionTreeClassifier
from nltk.metrics import precision, recall, f_measure
from nltk.classify import apply_features, accuracy
from nltk.classify.scikitlearn import SklearnClassifier
from prueba_paquete.utils import clean_html_tags, shuffled, tokenize_and_stem
from prueba_paquete.concept_extraction import ConceptExtractor
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction import DictVectorizer


class DocumentClassifier:
    '''
    Train a classifier with labeled documents and classify new documents
    into one of the labeled classes.

    The labeled documents supplied for training are called 'dev docs'. They
    are split into two subsets, 'train docs' and 'test docs', which are used
    to train and to evaluate the machine learning model respectively.

    Parameters
    ----------
    train_p : float, 0.8 by default
        The proportion of the 'dev docs' used as 'train docs'.
        Use values greater than 0 and lower than 1.
        The remaining docs are used as 'test docs'.

    eq_label_num : boolean, True by default
        Read by split_train_and_test. If true, 'train docs' start with an
        equal number of documents for each class; that number is the lowest
        label count.

    complete_p : boolean, True by default
        Read by split_train_and_test when eq_label_num is True and the
        lowest label count is not enough to reach the train_p proportion of
        'train docs'. If true, documents from 'test docs' are moved to
        'train docs' until train_p is reached.

    n_folds : integer, 10 by default
        Number of folds used by cross_validation_train.

    vocab_size : integer, 250 by default
        Size of the vocabulary requested from the concept extractor; the
        vocabulary is used for extracting features out of the docs.

    t_classifier : string, 'NB' by default
        Type of classifier model. Available types are 'NB' (Naive Bayes),
        'DT' (Decision Tree), 'RF' (Random Forest) and 'SVM' (Support
        Vector Machine). Any other value raises ValueError at training time.

    language : string, 'english' by default
        Language in which the documents are written.

    stem : boolean, False by default
        If true, every vocabulary entry is replaced by the first token that
        tokenize_and_stem returns for it.
    '''

    def __init__(self, train_p=0.8, eq_label_num=True,
                 complete_p=True, n_folds=10,
                 vocab_size=250,
                 t_classifier="NB", language="english",
                 stem=False):
        self.train_p = train_p
        self.eq_label_num = eq_label_num
        self.complete_p = complete_p
        self.n_folds = n_folds
        self.vocab_size = vocab_size
        self.t_classifier = t_classifier
        self.language = language
        self.stem = stem
        self._vocab = []
        self._classified_docs = []
        self._classifier = None
        self._accuracy = 0
        self._precision = {}
        self._recall = {}
        self._f_measure = {}
        self._train_docs = []
        self._test_docs = []
        # Initialised here so eval_classifier and the category_count property
        # do not raise AttributeError before any document has been counted.
        self._categories = []
        self._final_cat_count = {}

    @staticmethod
    def _shuffled_by_label(docs):
        '''
        Group the texts of docs by label in a single pass and shuffle each
        group.

        Parameters
        ----------
        docs: sequence of (text, label) pairs

        Returns
        -------
        grouped: dictionary
            Maps each label to the shuffled list of its texts
        '''
        grouped = collections.defaultdict(list)
        for text, label in docs:
            grouped[label].append(text)
        return {label: shuffled(texts) for label, texts in grouped.items()}

    @staticmethod
    def _extra_train_count(n_train, n_docs, n_test, train_p):
        '''
        Return the smallest number of docs to move from 'test docs' to
        'train docs' so that the training proportion reaches train_p,
        capped at n_test.
        '''
        extra = 0
        while extra < n_test and (n_train + extra) / n_docs < train_p:
            extra += 1
        return extra

    def _train_model(self, train_set):
        '''
        Create and train the model selected by t_classifier.

        Parameters
        ----------
        train_set: labeled feature sets as produced by apply_features

        Returns
        -------
        classifier: the trained classifier

        Raises
        ------
        ValueError
            If t_classifier is not one of 'NB', 'DT', 'RF' or 'SVM'
        '''
        kind = self.t_classifier
        if kind == "NB":
            return NaiveBayesClassifier.train(train_set)
        if kind == "DT":
            return DecisionTreeClassifier.train(train_set)
        if kind == "RF":
            return SklearnClassifier(RandomForestClassifier()).train(train_set)
        if kind == "SVM":
            return SklearnClassifier(LinearSVC(), sparse=False).train(train_set)
        raise ValueError(
            "Unknown classifier type {!r}; expected 'NB', 'DT', 'RF' "
            "or 'SVM'".format(kind))

    def split_train_and_test(self, docs):
        '''
        Split the 'dev docs' set into the 'train docs' and 'test docs'
        subsets, storing them on the instance.

        Parameters
        ----------
        docs: sequence of (text, label) pairs
            Must support len() and repeated iteration. An empty sequence
            produces empty 'train docs' and 'test docs'.
        '''
        categories_count = self.count_categories(docs)
        if not categories_count:
            # Nothing to split; avoids min() of an empty sequence and a
            # division by zero below.
            self._train_docs = []
            self._test_docs = []
            return

        if not self.eq_label_num:
            # Plain random split at the requested proportion.
            cut = int(self.train_p * len(docs))
            shuffled_docs = shuffled(docs)
            self._train_docs = shuffled_docs[:cut]
            self._test_docs = shuffled_docs[cut:]
            return

        # Take the same number of docs from every label: the size of the
        # least populated one.
        label_limit = min(categories_count.values())
        train_docs = []
        test_docs = []
        for cat, cat_docs in self._shuffled_by_label(docs).items():
            train_docs += [(doc, cat) for doc in cat_docs[:label_limit]]
            test_docs += [(doc, cat) for doc in cat_docs[label_limit:]]

        if self.complete_p:
            # Top up 'train docs' with just enough 'test docs' to reach
            # train_p.
            extra = self._extra_train_count(len(train_docs), len(docs),
                                            len(test_docs), self.train_p)
            if extra:
                pool = shuffled(test_docs)
                train_docs += pool[:extra]
                test_docs = pool[extra:]

        self._train_docs = train_docs
        self._test_docs = test_docs

    def cross_validation_train(self, dev_docs):
        '''
        Apply k-fold cross validation to split the docs into n_folds pairs
        of training and testing sets. For each pair a classifier is trained
        and evaluated; the classifier with the best accuracy is kept,
        together with its train and test docs.

        Docs left over when len(dev_docs) is not a multiple of n_folds are
        never used for testing; they are part of every training set.

        Parameters
        ----------
        dev_docs: sequence of (text, label) pairs

        Raises
        ------
        ValueError
            If t_classifier is not a supported classifier type
        '''
        dev_docs = shuffled(dev_docs)
        subset_size = len(dev_docs) // self.n_folds
        # Start below any reachable accuracy so that a classifier is kept
        # even when every fold scores 0.
        best_accuracy = -1.0

        for i in range(self.n_folds):
            start = i * subset_size
            stop = start + subset_size
            train_docs = dev_docs[stop:] + dev_docs[:start]
            test_docs = dev_docs[start:stop]

            train_set = apply_features(self.get_doc_features, train_docs)
            classifier = self._train_model(train_set)
            test_set = apply_features(self.get_doc_features, test_docs, True)
            fold_accuracy = accuracy(classifier, test_set) * 100

            if fold_accuracy > best_accuracy:
                best_accuracy = fold_accuracy
                self._classifier = classifier
                self._train_docs = train_docs
                self._test_docs = test_docs

    def equitative_class_train(self, dev_docs):
        '''
        Split every class separately, sending the train_p proportion of its
        docs to 'train docs' and the rest to 'test docs', then train the
        classifier selected by t_classifier on 'train docs'.

        Parameters
        ----------
        dev_docs: sequence of (text, label) pairs

        Raises
        ------
        ValueError
            If t_classifier is not a supported classifier type
        '''
        # Called for its side effect: it refreshes self._categories.
        self.count_categories(dev_docs)

        train_docs = []
        test_docs = []
        for cat, cat_docs in self._shuffled_by_label(dev_docs).items():
            cat_limit = int(self.train_p * len(cat_docs))
            train_docs += [(text, cat) for text in cat_docs[:cat_limit]]
            test_docs += [(text, cat) for text in cat_docs[cat_limit:]]

        self._train_docs = train_docs
        self._test_docs = test_docs

        train_set = apply_features(self.get_doc_features, self._train_docs)
        self._classifier = self._train_model(train_set)

    def count_categories(self, docs):
        '''
        Count how many documents of each class are in docs and remember the
        sorted list of classes on the instance.

        Parameters
        ----------
        docs: iterable of (text, label) pairs

        Returns
        -------
        counters: dictionary
            A dictionary where each item is the number of docs for a class
        '''
        counters = dict(collections.Counter(label for (_text, label) in docs))
        self._categories = sorted(counters)
        return counters

    def get_doc_features(self, doc):
        '''
        Extract the features of a document by checking the presence of each
        vocabulary word in it.

        Parameters
        ----------
        doc: string
            The doc from which features will be extracted. Presence is
            tested with the `in` operator, which is a substring match when
            doc is a string.

        Returns
        -------
        features: dictionary
            Maps 'contains(<word>)' to a boolean for every word in the
            vocabulary
        '''
        return {'contains({})'.format(word): (word in doc)
                for word in self._vocab}

    def train_classifier(self, dev_docs):
        '''
        Create the features vocabulary from 'dev docs', then split them and
        train the classifier through equitative_class_train.

        Parameters
        ----------
        dev_docs: sequence of (text, label) pairs
        '''
        # Create the vocabulary used for feature extraction.
        ce = ConceptExtractor(num_concepts=self.vocab_size,
                              language=self.language)
        ce.extract_concepts([text for (text, _label) in dev_docs])
        self._vocab = sorted([concept for (concept, _freq)
                              in ce.common_concepts], key=str.lower)
        if self.stem:
            self._vocab = [tokenize_and_stem(word, language=self.language)[0]
                           for word in self._vocab]
        # cross_validation_train is an alternative training strategy that
        # is not used here.
        self.equitative_class_train(dev_docs)

    def eval_classifier(self):
        '''
        Test the model against 'test docs' and calculate the metrics of
        accuracy, precision, recall and f-measure. The per-class metrics
        are reported for the classes present in 'train docs'.
        '''
        test_set = apply_features(self.get_doc_features, self._test_docs, True)
        self._accuracy = accuracy(self._classifier, test_set)
        refsets = collections.defaultdict(set)
        testsets = collections.defaultdict(set)

        for i, (feats, label) in enumerate(test_set):
            refsets[label].add(i)
            observed = self._classifier.classify(feats)
            testsets[observed].add(i)

        # Refresh self._categories from the docs the model was trained on.
        self.count_categories(self._train_docs)

        # Drop the results of any previous evaluation; cleared in place so
        # references handed out by the properties stay valid.
        self._precision.clear()
        self._recall.clear()
        self._f_measure.clear()
        for cat in self._categories:
            self._precision[cat] = precision(refsets[cat], testsets[cat])
            self._recall[cat] = recall(refsets[cat], testsets[cat])
            self._f_measure[cat] = f_measure(refsets[cat], testsets[cat])

    def classify_docs(self, docs):
        '''
        Train and evaluate the classifier with the labeled docs, then
        classify the unlabeled ones. The (text, predicted label) pairs are
        stored for the classified_docs property and the per-class count of
        labeled plus classified docs for the category_count property.

        Parameters
        ----------
        docs: sequence of (text, label) pairs
            An empty string label marks a doc as unlabeled
        '''
        dev_docs = [(text, label) for (text, label) in docs if label != ""]
        unlabeled_docs = [text for (text, label) in docs if label == ""]
        self.train_classifier(dev_docs)
        self.eval_classifier()
        results = [
            (doc, self._classifier.classify(self.get_doc_features(doc)))
            for doc in unlabeled_docs
        ]
        self._classified_docs = results
        self._final_cat_count = self.count_categories(dev_docs + results)

    @property
    def classified_docs(self):
        return self._classified_docs

    @property
    def accuracy(self):
        return self._accuracy

    @property
    def precision(self):
        return self._precision

    @property
    def recall(self):
        return self._recall

    @property
    def f_measure(self):
        return self._f_measure

    @property
    def category_count(self):
        return self._final_cat_count
