    
import math
import random
import parmap
import itertools
from itertools import repeat
from scipy.spatial import distance
import operator
from tqdm import tqdm
from functools import reduce
from sklearn import mixture
import statsmodels.stats.multitest as multi
import networkx as nx
import multiprocessing as mp
import numpy as np
from scipy.stats import poisson
import pandas as pd
import pygco as pygco # cut_from_graph # pip install git+git://github.com/amueller/gco_python
import scipy.stats.mstats as ms
from itertools import repeat
from scipy.stats import norm
from scipy.stats import binom
from scipy.stats import skewnorm
from scipy.sparse import issparse
import matplotlib as mpl
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.path as mplPath
import shapely.geometry
import shapely.ops
from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay, KDTree, ConvexHull
from matplotlib.patches import Polygon
from matplotlib.collections import LineCollection, PatchCollection
from PIL import Image
from matplotlib.backends.backend_pdf import PdfPages
import sklearn.manifold as manifold
import sklearn.decomposition as decomposition 
from sklearn.preprocessing import StandardScaler
from scipy.stats import gaussian_kde
from sklearn.cluster import DBSCAN,KMeans
from scipy.stats import ttest_ind
from sklearn.utils import shuffle
import hdbscan
import seaborn as sns
import sklearn.cluster as cluster
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
from scipy import stats
import sys
import os
import time
from scipy.stats import skewnorm
from scipy import stats
from ast import literal_eval
import pysal
from pysal.explore.esda.moran import Moran
from skimage.filters import threshold_otsu

from .GMM import *


def _graph_rows(edge_keys, points, normCount, var):
    '''Return an (m, 4) array ``[node0, node1, energy, distance]`` for the (node0, node1) pairs in ``edge_keys``.'''
    rows = np.zeros((len(edge_keys), 4))
    for i, (node0, node1) in enumerate(edge_keys):
        edgeDiff = normCount[node0] - normCount[node1]
        energy = np.exp((0 - edgeDiff**2) / (2 * var))
        dist = distance.euclidean(points[node0, :], points[node1, :])
        rows[i] = [node0, node1, energy, dist]
    return rows


def create_graph_with_weight(points, normCount):
    '''
    Returns a graph created from cell coordinates, with edge weights set by normalized counts.

    A Delaunay triangulation of ``points`` is pruned in three passes:

    1. every unique triangle edge is collected;
    2. GMMs with 1-4 components are fitted to the edge lengths and the lowest-BIC
       model is kept. For each triangle whose longest edge exceeds mean + 2*sd of
       that model's heaviest component, the longest edge is dropped (only while
       all three of the triangle's edges are still present);
    3. one Gaussian is fitted to the remaining edge lengths, and every edge at or
       above mean + 2*sd is dropped.

    If step 3 leaves more than three nodes with degree < 2, the step-2 graph is
    returned instead.

    :param points: cell coordinates, shape (n, 2).
    :param normCount: normalized counts, shape (n,). The edge energy is
        ``exp(-(c_i - c_j)**2 / (2 * normCount.var()))``.
    :rtype: ndarray of shape (m, 4), columns ``[node0, node1, energy, distance]``.
    '''
    var = normCount.var()
    delauny = Delaunay(points)
    # Sort vertex ids inside each triangle so every edge has a canonical (low, high) key.
    simplices = np.sort(delauny.simplices, axis=1)

    ## 1. all unique triangle edges (dict keeps first-seen order)
    edges = {}
    for s0, s1, s2 in simplices:
        for u, v in ((s0, s1), (s0, s2), (s1, s2)):
            edges[(int(u), int(v))] = 1
    tempGraph = _graph_rows(edges, points, normCount, var)

    ## 2. remove the longest edge of triangles whose longest edge is unusually long
    temp_data = tempGraph[:, 3].reshape(-1, 1)  ## GMM of dist
    best_gmm = None
    lowest_bic = np.inf  # np.infty was removed in NumPy 2.0
    for n_components in range(1, 5):
        gmm = mixture.GaussianMixture(n_components=n_components, random_state=2020)
        gmm.fit(temp_data)
        gmm_bic = gmm.bic(temp_data)
        if gmm_bic < lowest_bic:
            best_gmm = gmm
            lowest_bic = gmm_bic

    # Heaviest component (the first one on ties), reduced to scalars so the comparison below is unambiguous.
    major = int(np.argmax(best_gmm.weights_))
    cutoff = (np.ravel(best_gmm.means_[major])[0]
              + 2 * np.sqrt(np.ravel(best_gmm.covariances_[major])[0]))

    for s0, s1, s2 in simplices:
        pairs = ((int(s0), int(s1)), (int(s0), int(s2)), (int(s1), int(s2)))
        lengths = [distance.euclidean(points[u, :], points[v, :]) for u, v in pairs]
        longest = int(np.argmax(lengths))
        if lengths[longest] > cutoff and all(pair in edges for pair in pairs):
            del edges[pairs[longest]]
    tempGraph = _graph_rows(edges, points, normCount, var)

    ## 3. drop edges that are long relative to all remaining edges
    temp_data = tempGraph[:, 3].reshape(-1, 1)
    gmm = mixture.GaussianMixture(n_components=1, random_state=2020)
    gmm.fit(temp_data)
    cutoff = np.ravel(gmm.means_[0])[0] + 2 * np.sqrt(np.ravel(gmm.covariances_[0])[0])
    prunedGraph = tempGraph[tempGraph[:, 3] < cutoff]

    ## 4. count poorly connected nodes (degree < 2) after step 3
    G_full = nx.Graph()
    G_full.add_nodes_from(range(points.shape[0]))
    G_full.add_edges_from(prunedGraph[:, 0:2].astype(np.int32))
    Nbroken = sum(1 for _, deg in G_full.degree() if deg < 2)
    if Nbroken > 3:
        # Step 3 would disconnect too many cells; keep the step-2 graph.
        return tempGraph
    # Return only the retained rows. The original returned an unsliced copy whose tail still held dropped edges.
    return prunedGraph


def first_neg_index(a):
    '''
    deprecated

    Return the position of the first negative entry of the 1-D array ``a``.
    When no entry is negative, return ``a.shape[0] - 1`` (``-1`` for an empty array).
    '''
    negative_positions = np.flatnonzero(a < 0)
    if negative_positions.size == 0:
        return a.shape[0] - 1
    return negative_positions[0]

def calc_u_cost(a, mid_points):
    '''
    deprecated

    Build one cell's unary-cost row from ``a = (neg_index, x)``.

    The first ``neg_index + 1`` entries are ``x - mid_points[k]`` for
    ``k = 0..neg_index``. The remaining entries are ``mid_points[k] - x`` for
    ``k = neg_index..len(mid_points) - 1``.
    '''
    split = int(a[0])
    value = a[1]
    n_tail = mid_points.shape[0] - split
    head = np.repeat(value, split + 1) - mid_points[0:split + 1]
    tail = mid_points[split:] - np.repeat(value, n_tail)
    return np.concatenate((head, tail), axis=0)


def compute_pairwise_cost(size, smooth_factor):
    '''
    Returns the pairwise (smoothness) energy matrix.

    The result is a symmetric ``size x size`` matrix:

    - diagonal entries are ``-smooth_factor * size``;
    - entry (i, j) is ``-smooth_factor * (size - 1 - |i - j|)``.

    It is int32 when ``smooth_factor`` is an integer.

    :param size: number of labels (scalar).
    :param smooth_factor: smoothness weight (scalar).
    :rtype: pairwise energy matrix.
    '''
    cost = -smooth_factor * np.eye(size, dtype=np.int32)
    ramp = -smooth_factor * np.arange(size)[::-1]
    for shift in range(size):
        # Row `shift` gets the ramp rotated so its strongest weight lands on the diagonal.
        cost[shift] += np.roll(ramp, shift)
    upper = np.triu(cost)
    symmetric = upper + upper.T
    # The diagonal was counted twice when mirroring; halve it back.
    np.fill_diagonal(symmetric, symmetric.diagonal() / 2)
    return symmetric

def cut_graph_general_profile(cellGraph, count,gmm, unary_scale_factor=100, 
                      smooth_factor=50, label_cost=10, algorithm='expansion'):
    '''
    Returns new labels for the graph cut using the profile unary cost.

    :param cellGraph: edge table; columns 0 and 1 hold node indices.
    :param count: expression values, shape (n,).
    :param gmm: fitted mixture model. Its number of components sets the
        pairwise-cost size.
    :param unary_scale_factor: scalar multiplier for the unary costs.
    :param smooth_factor: scalar weight for the pairwise costs.
    :param label_cost: scalar, forwarded to ``pygco.cut_from_graph``.
    :param algorithm: accepted for interface compatibility; not forwarded.
    :rtype: labels returned by ``pygco.cut_from_graph``, shape (n,).
    '''
    # compute_unary_cost_simple_profile returns (unary_cost, label_pred).
    # Previously the whole tuple was passed to the cut as the unary cost.
    unary_cost, _ = compute_unary_cost_simple_profile(count, gmm, unary_scale_factor)
    # NOTE(review): the unary matrix has one column per *present* label (at least 2),
    # while the pairwise matrix is sized by gmm components; these must agree for the cut — confirm.
    pairwise_cost = compute_pairwise_cost(gmm.means_.shape[0], smooth_factor)
    edges = cellGraph[:, 0:2].astype(np.int32)
    # Same dtypes and argument order as cut_graph_general.
    # NOTE(review): in gco_python the 4th positional argument of cut_from_graph is n_iter, not a label cost — confirm.
    labels = pygco.cut_from_graph(edges, unary_cost,
                                  pairwise_cost.astype(np.int32), np.int32(label_cost))
    return labels

def compute_unary_cost_simple_profile(count, gmm, unary_scale_factor):
    '''
    Returns the unary cost energy for the profile cut.

    Each cell is labelled with ``gmm.predict``. Zero-valued cells take the label
    predicted for the smallest positive value, when one exists. The cost matrix
    has one column per distinct predicted label; a single label is duplicated so
    there are always at least two columns. Entries are ``-unary_scale_factor``
    where the cell's label equals the column's label and ``+unary_scale_factor``
    otherwise.

    :param count: expression values, shape (n,).
    :param gmm: fitted 1-D mixture model exposing ``predict``.
    :param unary_scale_factor: scalar multiplier.
    :rtype: tuple (int32 unary matrix of shape (n, k), predicted labels of shape (n,)).
    '''
    exp = np.asarray(count)
    label_pred = gmm.predict(exp.reshape(-1, 1))
    zero_mask = exp == 0
    positive = exp[exp > 0]
    # Zero counts inherit the label of the smallest positive value (skipped when there is none).
    if zero_mask.any() and positive.size > 0:
        zero_label = gmm.predict(np.min(positive).reshape(-1, 1))[0]
        np.place(label_pred, zero_mask, zero_label)

    present = np.unique(label_pred)
    if len(present) < 2:
        # A single label would give a one-column matrix that mismatches the pairwise cost dimension.
        present = np.append(present, present)
    same_label = label_pred[:, np.newaxis] == present[np.newaxis, :]
    unary_mat = np.where(same_label, -1, 1)
    return (unary_scale_factor * unary_mat).astype(np.int32), label_pred


def cut_graph_general(cellGraph, count,gmm, unary_scale_factor=100, 
                      smooth_factor=30, label_cost=10, algorithm='expansion',
                      profile=False):
    '''
    Run ``pygco.cut_from_graph`` over ``cellGraph`` for one gene.

    :param cellGraph: edge table; columns 0 and 1 hold node indices.
    :param count: expression values, shape (n,).
    :param gmm: fitted 1-D mixture model. Its number of components sets the
        size of the pairwise cost matrix.
    :param unary_scale_factor: scalar multiplier for the unary costs.
    :param smooth_factor: scalar weight for the pairwise costs.
    :param label_cost: scalar, forwarded to ``pygco.cut_from_graph`` as int32.
    :param algorithm: accepted for interface compatibility; not forwarded.
    :param profile: when ``== False``, use ``compute_unary_cost_simple``;
        otherwise use ``compute_unary_cost_simple_profile``.
    :rtype: tuple (labels from the cut, per-cell GMM labels).
    '''
    # `== False` (not `not profile`) keeps None routed to the profile cost.
    build_unary = (compute_unary_cost_simple if profile == False  # noqa: E712
                   else compute_unary_cost_simple_profile)
    unary_cost, label_pred = build_unary(count, gmm, unary_scale_factor)

    n_components = gmm.means_.shape[0]
    pairwise_cost = compute_pairwise_cost(n_components, smooth_factor).astype(np.int32)
    edge_array = cellGraph[:, 0:2].astype(np.int32)
    labels = pygco.cut_from_graph(edge_array, unary_cost, pairwise_cost,
                                  np.int32(label_cost))
    return labels, label_pred



def compute_unary_cost_simple(count, gmm, scale_factor):
    '''
    Returns the unary cost energy and per-cell GMM labels.

    Costs are built relative to the sorted GMM means:

    - ``mid_points[i] = (mu_i*sd_{i+1} + mu_{i+1}*sd_i) / (sd_i + sd_{i+1})``
      lies between consecutive sorted means ``mu_i``, ``mu_{i+1}``.
    - Each cell's row comes from ``first_neg_index`` on
      ``count - mu[1:]`` followed by ``calc_u_cost``.
    - The result is scaled by ``scale_factor`` and cast to int32.

    For sparse genes (at most 10% of cells, and at most 30 cells, with count > 1),
    the most frequent cost value and its negation are adjusted: if their magnitude
    is below 10, they are replaced by +/- the median absolute unique cost, keeping
    the sign.

    :param count: expression values, shape (n,).
    :param gmm: fitted 1-D mixture model (``means_``, ``covariances_``, ``predict``).
    :param scale_factor: scalar multiplier.
    :rtype: tuple (int32 unary matrix, per-cell GMM labels).
    '''
    # Zero counts take the label of the smallest positive count (shared helper).
    label_pred = gmm_predict(count, gmm)

    # Means in ascending order, with standard deviations in the same order.
    order = np.argsort(gmm.means_, axis=0).ravel()
    temp_means = np.sort(gmm.means_, axis=None)
    # ravel avoids the deprecated assignment of a (1, 1) array into a scalar slot.
    sds = np.sqrt(np.ravel(gmm.covariances_[order]))
    mid_points = ((temp_means[:-1] * sds[1:] + temp_means[1:] * sds[:-1])
                  / (sds[:-1] + sds[1:]))

    temp = count[:, np.newaxis] - temp_means[1:]
    neg_indices = np.apply_along_axis(first_neg_index, 1, temp)
    ind_count_arr = np.vstack((neg_indices, count)).T
    unary_cost = (scale_factor * np.apply_along_axis(calc_u_cost, 1,
                                                     ind_count_arr, mid_points)).astype(np.int32)

    n_high = np.sum(count > 1)
    if n_high / len(count) <= 0.1 and n_high <= 30:
        unique_cost, counts_cost = np.unique(unary_cost, return_counts=True)
        most_common = unique_cost[np.argmax(counts_cost)]
        median_abs = np.median(abs(unique_cost))
        # Small-magnitude dominant costs are lifted to the median magnitude, keeping the sign.
        for value in (most_common, 0 - most_common):
            if 0 < value < 10:
                np.place(unary_cost, unary_cost == value, median_abs)
            elif -10 < value < 0:
                np.place(unary_cost, unary_cost == value, 0 - median_abs)
    return unary_cost, label_pred


def noise_inside(a, b, cellGraph):
    '''
    Return True when every neighbour of component ``a`` lies in ``a`` or ``b``.

    :param a: iterable of node indices (candidate noise component).
    :param b: iterable of node indices (surrounding component).
    :param cellGraph: edge table with node indices in columns 0 and 1.
    :rtype: bool
    '''
    members = set(a)
    member_ids = np.array(list(members))
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    from_col0 = np.isin(cellGraph[:, 0], member_ids)
    from_col1 = np.isin(cellGraph[:, 1], member_ids)
    neighbors = set(cellGraph[from_col0, 1].tolist()) | set(cellGraph[from_col1, 0].tolist())
    outside = neighbors - members - set(b)
    return len(outside) == 0


# added noise
def _component_poisson_p(gmm_pred, counts, unique, node_list, com_size, n_cells):
    '''Poisson upper-tail p-value P(X >= k) for the majority GMM label of one component, floored at 10e-30.'''
    labels_com, counts_com = np.unique(gmm_pred[list(node_list)], return_counts=True)
    count_in_com = counts_com.max()                        ## counts by graph cuts, k
    major_label = labels_com[np.argmax(counts_com)]        # first label with the max count
    label_count = counts[np.where(unique == major_label)[0]]  ## real counts of that label, Ci
    mu = com_size * (label_count / n_cells)
    p0 = poisson.sf(count_in_com, mu)[0]
    p1 = poisson.pmf(count_in_com, mu)[0]
    prob = min((p0 + p1), 1)
    if prob < 10e-30:
        prob = 10e-30
    return prob


def compute_p_CSR(locs, newLabels, gmm, exp, cellGraph): 
    '''
    Returns p-values of the connected components produced by a cut.

    Components smaller than a size cutoff get p = 1. The cutoff is 3 cells when
    ``len(exp) <= 1000`` and 9 otherwise. Other components are scored with a
    Poisson upper-tail test on their majority GMM label.

    For ``len(exp) <= 1000``, components with p >= 0.1 and fewer than 10 nodes
    are treated as noise. A significant component (p < 0.1) that fully encloses
    such noise, as judged by ``noise_inside``, is re-scored with the noise size
    added to its own size; each noise component is used once.

    :param locs: spatial coordinates, forwarded to ``count_component``.
    :param newLabels: labels from the cut, shape (n,).
    :param gmm: fitted mixture model.
    :param exp: expression values, shape (n,).
    :param cellGraph: edge table with node indices in columns 0 and 1.
    :rtype: tuple (p_values list, node index arrays per component, components).
    '''
    p_values = list()
    node_lists = list()
    n_cells = exp.shape[0]
    gmm_pred = gmm_predict(exp, gmm)
    unique, counts = np.unique(gmm_pred, return_counts=True)
    con_components = count_component(locs, cellGraph, newLabels)  ## node sets of subgraphs

    # Depends only on the cell count. Deciding it before the loop also keeps
    # addNoise bound when there are no components (previously UnboundLocalError).
    if len(exp) <= 1000:
        com_cutoff = 3
        addNoise = True
    else:
        com_cutoff = 9
        addNoise = False

    # First pass: score every component without considering noise.
    for node_list in con_components:
        com_size = len(node_list)
        if com_size >= com_cutoff:
            p_values.append(_component_poisson_p(gmm_pred, counts, unique,
                                                 node_list, com_size, n_cells))
        else:  # small components are not scored
            p_values.append(1)
        node_lists.append(np.array(list(node_list)))

    if addNoise:
        noise = {j: comp for j, comp in enumerate(con_components)
                 if p_values[j] >= 0.1 and len(comp) < 10}
        # Second pass: re-score significant components that enclose noise.
        for j, comp in enumerate(con_components):
            if p_values[j] < 0.1:
                noise_size = 0
                used_com = list()
                for jj, cc in noise.items():
                    if noise_inside(cc, comp, cellGraph):
                        noise_size = noise_size + len(cc)
                        used_com.append(jj)
                if noise_size > 0:
                    for jj in used_com:
                        noise.pop(jj)
                    p_values[j] = _component_poisson_p(gmm_pred, counts, unique, comp,
                                                       len(comp) + noise_size, n_cells)
    return p_values, node_lists, con_components


def count_component(locs, cellGraph, newLabels):
    '''
    Returns the connected components left after cutting the graph at label boundaries.

    Every node that appears in ``cellGraph`` becomes a graph node. An edge is kept
    only when both endpoints have the same label in ``newLabels``.

    :param locs: unused; kept for interface compatibility.
    :param cellGraph: edge table with node indices in columns 0 and 1.
    :param newLabels: per-node labels, indexable by node index.
    :rtype: list of node-index sets, largest component first.
    '''
    src = cellGraph[:, 0].astype(np.int32)
    dst = cellGraph[:, 1].astype(np.int32)
    labels = np.asarray(newLabels)
    # Vectorised replacement for the per-row remove_egdes/apply_along_axis pass.
    same_label = labels[src] == labels[dst]

    G_cut = nx.Graph()
    G_cut.add_nodes_from(list(set(list(src) + list(dst))))
    G_cut.add_edges_from(np.column_stack((src, dst))[same_label])  ## connect same label to one subgraph

    return sorted(nx.connected_components(G_cut), key=len, reverse=True)

def _count_low_degree(graph, nodes, cutoff):
    '''Number of ``nodes`` whose degree in ``graph`` is at most ``cutoff``.'''
    return sum(1 for _, deg in graph.degree(list(nodes)) if deg <= cutoff)


def count_isolate(locs, cellGraph, newLabels):
    '''
    Count small fragments left by a cut, split by node connectivity in the uncut graph.

    ``degree_cutoff`` is ``int(min(average_degree_connectivity(full graph)))``.
    A node is "low degree" when its uncut-graph degree is at most that cutoff.
    The returned vector has 18 slots:

    - ``[0]`` / ``[1]``: isolated nodes of the cut graph that are low degree / not.
    - ``[2*(s-1)]`` / ``[2*(s-1)+1]`` for component size ``s`` in 2..9:
      components with at least ``needed`` low-degree nodes / the rest.
      ``needed`` is 1 for ``s == 2`` and 2 otherwise.

    :param locs: unused; kept for interface compatibility.
    :param cellGraph: edge table with node indices in columns 0 and 1.
    :param newLabels: per-node labels, indexable by node index.
    :rtype: ndarray of shape (18,).
    '''
    src = cellGraph[:, 0].astype(np.int32)
    dst = cellGraph[:, 1].astype(np.int32)
    labels = np.asarray(newLabels)
    same_label = labels[src] == labels[dst]
    node_ids = list(set(list(src) + list(dst)))
    edge_pairs = np.column_stack((src, dst))

    G_full = nx.Graph()
    G_full.add_nodes_from(node_ids)
    G_full.add_edges_from(edge_pairs)
    G_cut = nx.Graph()
    G_cut.add_nodes_from(node_ids)
    G_cut.add_edges_from(edge_pairs[same_label])

    degree_cutoff = int(min(nx.average_degree_connectivity(G_full).values()))
    seg_size = np.zeros(18)

    isolated_nodes = list(nx.isolates(G_cut))
    if len(isolated_nodes) > 0:
        n_low = _count_low_degree(G_full, isolated_nodes, degree_cutoff)
        seg_size[0] = n_low
        seg_size[1] = len(isolated_nodes) - n_low

    for cc in nx.connected_components(G_cut):
        size = len(cc)
        if 2 <= size <= 9:
            needed = 1 if size == 2 else 2
            slot = 2 * (size - 1)
            if _count_low_degree(G_full, cc, degree_cutoff) >= needed:
                seg_size[slot] += 1
            else:
                seg_size[slot + 1] += 1
    return seg_size


def remove_egdes(edges, newLabels):
    '''
    Flag whether one edge crosses a label boundary.

    Writes 0 into ``edges[2]`` when the endpoints ``int(edges[0])`` and
    ``int(edges[1])`` have different labels in ``newLabels``, and 1 otherwise.
    The row is modified in place and also returned.

    :param edges: one edge row; entries 0 and 1 are node indices.
    :param newLabels: per-node labels, indexable by node index.
    :rtype: the same ``edges`` object.
    '''
    head, tail = int(edges[0]), int(edges[1])
    crosses_boundary = newLabels[head] != newLabels[tail]
    edges[2] = 0 if crosses_boundary else 1
    return edges



def gmm_predict(exp,gmm):
    """Predict component labels for 1-D expression values.

    Zero values are given the label of the smallest positive value.

    Parameters
    ----------
    exp : ndarray, shape (n_samples,)
        Expression values. They are reshaped to a single-feature column
        before prediction.

    gmm : fitted Gaussian mixture model exposing ``predict``.

    Returns
    -------
    labels : ndarray, shape (n_samples,)
        Component labels. When ``exp`` has no positive value, zero entries
        keep the label ``gmm`` predicts for them directly.
    """
    label_pred = gmm.predict(exp.reshape(-1, 1))
    zero_mask = exp == 0
    positive = exp[exp > 0]
    # Guard on `positive` so all-zero input no longer crashes on np.min([]).
    if zero_mask.any() and positive.size > 0:
        zero_label = gmm.predict(np.min(positive).reshape(-1, 1))[0]
        np.place(label_pred, zero_mask, zero_label)
    return label_pred



## 20200629 # new 2021112
def _cut_objective(locs, cellGraph, exp, newLabels, p):
    '''Model-selection score of one cut: -log10(min p) minus small-fragment count scaled by 200 / n_cells.'''
    num_isolate = count_isolate(locs, cellGraph, newLabels)
    noise_size_norm = sum(num_isolate) * (200 / len(exp))
    return (0 - np.log10(min(p))) - noise_size_norm


def compute_spatial_genomewise_optimize(locs, data_norm, cellGraph, gmmDict, w, n, smooth_factor=60, 
                                         unary_scale_factor=100, label_cost=10, algorithm='expansion'):
    '''
    Run the per-gene graph-cut search for every column of ``data_norm``.

    The starting smooth factor comes from Moran's I:

    - I < 0.05: ``min(2*smooth_factor, 100)``;
    - I < 0.1: ``smooth_factor``;
    - otherwise: 10.

    In the last case, for genes with more than 1/5 zeros, an Otsu cut and a GMM
    cut are both scored with ``_cut_objective``; Otsu is chosen when it wins by
    more than 8. Otsu is also used for small (< 1000 cells), zero-inflated
    (> 1/3 zeros) genes with ``calc_zero_part_p < 1e-4``.

    ``unary_scale_factor``, ``label_cost`` and ``algorithm`` are now forwarded to
    the final per-gene search. Previously they were hard-coded to their defaults.

    :rtype: tuple of lists (nodes, p_values, genes, smooth_factors, pred_labels,
        model_list, model_labels), one entry per gene with a non-empty result.
    '''
    genes = list()
    nodes = list()
    p_values = list()
    smooth_factors = list()
    pred_labels = list()
    model_list = list()
    model_labels = list()

    for geneID in data_norm.columns:
        use_otsu = False
        exp = data_norm.loc[:, geneID].values
        gmm = gmmDict[geneID]
        mi = Moran(exp[n], w, permutations=0)
        zero_part_p = calc_zero_part_p(locs, cellGraph, exp)
        zero_inflated = np.sum(exp == 0)

        if mi.I < 0.05:
            start_sf = min(smooth_factor * 2, 100)
        elif mi.I < 0.1:
            start_sf = smooth_factor
        else:
            start_sf = 10
            ## decide whether to use otsu or gmm
            if zero_inflated > len(exp) / 5:
                newLabels, thresholds, label_pred = cut_graph_general_otsu(
                    cellGraph, exp, unary_scale_factor, start_sf, label_cost, algorithm)
                p, node, com = compute_p_CSR_otsu(locs, newLabels, label_pred, exp, cellGraph)
                obj_otsu = _cut_objective(locs, cellGraph, exp, newLabels, p)

                newLabels, _ = cut_graph_general(cellGraph, exp, gmm, unary_scale_factor,
                                                 start_sf, label_cost, algorithm)
                p, node, com = compute_p_CSR(locs, newLabels, gmm, exp, cellGraph)
                obj_gmm = _cut_objective(locs, cellGraph, exp, newLabels, p)

                if obj_otsu - obj_gmm > 8:
                    use_otsu = True

        if (zero_part_p < 1e-4 and zero_inflated > len(exp) / 3 and len(exp) < 1000) or use_otsu:
            model = 'otsu'
            result = compute_single_gene_otsu(
                locs, exp, cellGraph, smooth_factor=start_sf,
                unary_scale_factor=unary_scale_factor, label_cost=label_cost,
                algorithm=algorithm)
        else:
            model = 'gmm'
            result = compute_single_gene_gmm(
                locs, exp, cellGraph, gmm, smooth_factor=start_sf,
                unary_scale_factor=unary_scale_factor, label_cost=label_cost,
                algorithm=algorithm)
        p_best, node_best, newLabels_best, temp_factor_best, com_best, label_pred_best = result

        if len(p_best) > 0:
            p_values.append(p_best)
            nodes.append(node_best)
            genes.append(geneID)
            smooth_factors.append(temp_factor_best)
            pred_labels.append(newLabels_best)
            model_list.append(model)
            model_labels.append(list(label_pred_best))

    return nodes, p_values, genes, smooth_factors, pred_labels, model_list, model_labels


def identify_spatial_genes(locs, data_norm, cellGraph, gmmDict, smooth_factor=60,
                      unary_scale_factor=100, label_cost=10, algorithm='expansion',
                      ncores = None):
    '''
    Main function to identify spatially variable genes.

    Genes (columns of ``data_norm``) are split across worker processes, each
    running ``compute_spatial_genomewise_optimize``. The results are merged, and
    BH-FDR is computed on each gene's minimum p-value.

    :param locs: spatial coordinates (n, 2).
    :param data_norm: normalized expression DataFrame, cells x genes.
    :param cellGraph: edge table with node indices in columns 0 and 1.
    :param gmmDict: fitted mixture model per gene.
    :param smooth_factor, unary_scale_factor, label_cost, algorithm: graph-cut settings.
    :param ncores: worker count. Defaults to half the CPUs, is capped at half the
        gene count, and is never below 1.
    :rtype: pandas DataFrame indexed by gene.
    '''
    # Moran's I weights from the uncut neighbour graph (node order as before).
    G_full = nx.Graph()
    G_full.add_nodes_from(list(set(list(cellGraph[:, 0].astype(np.int32))
                                   + list(cellGraph[:, 1].astype(np.int32)))))
    G_full.add_edges_from(cellGraph[:, 0:2].astype(np.int32))
    w = pysal.lib.weights.W.from_networkx(G_full)
    n = list(w.weights.keys())

    all_cores = mp.cpu_count()
    num_cores = ncores if ncores is not None else int(all_cores * 0.5)
    num_cores = min(num_cores, int(math.floor(data_norm.shape[1] / 2)))
    # At least one worker. Previously a single gene or a single-core host gave 0 and crashed the split.
    num_cores = max(1, num_cores)
    print(f'scGCO used {num_cores} out of {all_cores} cores')

    # Split column positions, not the DataFrame: np.array_split on a DataFrame relies on the deprecated DataFrame.swapaxes.
    column_chunks = np.array_split(np.arange(data_norm.shape[1]), num_cores)
    ttt = [data_norm.iloc[:, idx] for idx in column_chunks]
    tuples = [(locs, chunk, cellGraph, gmmDict, w, n, smooth_factor,
               unary_scale_factor, label_cost, algorithm) for chunk in ttt]

    results = parmap.starmap(compute_spatial_genomewise_optimize, tuples,
                             pm_processes=num_cores, pm_pbar=True)

    def gather(k):
        # Concatenate field k of every worker's result, preserving worker order.
        return list(itertools.chain.from_iterable(r[k] for r in results))

    nodes = gather(0)
    p_values = gather(1)
    genes = gather(2)
    s_factors = gather(3)
    pred_labels = gather(4)
    model_list = gather(5)
    model_labels = gather(6)

    best_p_values = [min(i) for i in p_values]
    fdr = multi.multipletests(np.array(best_p_values), method='fdr_bh')[1]

    labels_array = np.array(pred_labels).reshape(len(genes), pred_labels[0].shape[0])
    data_array = np.array((genes, p_values, fdr, s_factors, nodes, model_list, model_labels),
                          dtype=object).T
    t_array = np.hstack((data_array, labels_array))
    c_labels = ['p_value', 'fdr', 'smooth_factor', 'nodes', 'model', 'model_labels']
    c_labels += ['label_cell_' + str(i) for i in range(1, labels_array.shape[1] + 1)]
    result_df = pd.DataFrame(t_array[:, 1:], index=t_array[:, 0], columns=c_labels)

    return result_df

def compute_single_gene_gmm(locs, exp, cellGraph, gmm, smooth_factor=10,
                           unary_scale_factor=100, label_cost=10, algorithm='expansion'):
    '''
    Search for a smoothness factor for one gene's GMM-based graph cut.

    The objective of a cut is ``-log10(min p) - noise`` where ``noise`` is
    ``sum(count_isolate(...))`` scaled by ``200 / len(exp)``.
    Starting at ``smooth_factor``, the factor is increased (by 20 when the
    previous noise value was >= 20, otherwise by 10) while it is below 30, or
    while the objective is negative and some component has p < 0.01; the
    search stops as soon as the factor exceeds 120. The cut is then redone at
    the factor with the largest objective (first one on ties). Finally, if the
    most significant component of that cut has fewer than 9 nodes, the factor
    is raised in steps of 50 while min p < 0.01, stopping once it exceeds 500.

    :param locs: spatial coordinates; exp: expression values of one gene;
        cellGraph: cell graph; gmm: fitted mixture model for this gene.
    :return: (p_values, node_lists, labels, smooth_factor_used, components,
        label_pred) describing the final cut, as produced by
        cut_graph_general / compute_p_CSR.
    '''
    size_factor = 200
    noise_size_estimate = 9

    def _cut(factor):
        # One graph cut at ``factor`` followed by the per-component p-values.
        labels, pred = cut_graph_general(cellGraph, exp, gmm, unary_scale_factor,
                                         factor, label_cost, algorithm)
        pvals, nodes, comps = compute_p_CSR(locs, labels, gmm, exp, cellGraph)
        return labels, pred, pvals, nodes, comps

    def _objective(labels, pvals):
        # Significance of the best component minus a size-normalised noise penalty.
        isolated = count_isolate(locs, cellGraph, labels)
        noise = sum(isolated) * (size_factor / len(exp))
        return (0 - np.log10(min(pvals))) - noise, noise

    factor = smooth_factor
    newLabels, label_pred, p, node, com = _cut(factor)
    obj_val, noise_norm = _objective(newLabels, p)
    tried = [factor]
    scores = [obj_val]

    # Coarse search: keep smoothing while the factor is small, or while the cut
    # is significant but its objective is still dominated by noise.
    while factor < 30 or (obj_val < 0 and len(p) > 0 and min(p) < 0.01):
        factor += 20 if noise_norm >= 20 else 10
        newLabels, label_pred, p, node, com = _cut(factor)
        obj_val, noise_norm = _objective(newLabels, p)
        tried.append(factor)
        scores.append(obj_val)
        if factor > 120:
            break

    # Redo the cut at the best-scoring factor.
    if len(scores) >= 2:
        factor = tried[np.argmax(scores)]
        newLabels, label_pred, p, node, com = _cut(factor)

    # A significant but tiny pattern: smooth harder while anything is significant.
    if len(p) > 0 and len(node[np.argmin(p)]) < noise_size_estimate:
        while min(p) < 0.01:
            factor += 50
            newLabels, label_pred, p, node, com = _cut(factor)
            if factor > 500:
                break

    return p, node, newLabels, factor, com, label_pred


def compute_single_gene_otsu(locs, exp, cellGraph, smooth_factor=10,
                            unary_scale_factor=100, label_cost=10, algorithm='expansion'):
    '''
    Search for a smoothness factor for one gene's Otsu-based graph cut.

    The objective of a cut is ``-log10(min p) - noise`` where ``noise`` is
    ``sum(count_isolate(...))`` scaled by ``200 / len(exp)``.
    Starting at ``smooth_factor``, the factor is increased (by 20 when the
    previous noise value was >= 20, otherwise by 10) while it is below 30, or
    while the objective is below 1 and some component has p < 0.01; the search
    stops as soon as the factor exceeds 120. The cut is then redone at the
    factor with the largest objective (first one on ties). Finally, if the
    most significant component of that cut has fewer than 9 nodes, the factor
    is raised in steps of 50 while min p < 0.01, stopping once it exceeds 500.

    :param locs: spatial coordinates; exp: expression values of one gene;
        cellGraph: cell graph.
    :return: (p_values, node_lists, labels, smooth_factor_used, components,
        label_pred) describing the final cut, as produced by
        cut_graph_general_otsu / compute_p_CSR_otsu.
    '''
    size_factor = 200
    noise_size_estimate = 9

    def _cut(factor):
        # One Otsu graph cut at ``factor`` followed by the per-component p-values.
        labels, _thresholds, pred = cut_graph_general_otsu(cellGraph, exp, unary_scale_factor,
                                                           factor, label_cost, algorithm)
        pvals, nodes, comps = compute_p_CSR_otsu(locs, labels, pred, exp, cellGraph)
        return labels, pred, pvals, nodes, comps

    def _objective(labels, pvals):
        # Significance of the best component minus a size-normalised noise penalty.
        isolated = count_isolate(locs, cellGraph, labels)
        noise = sum(isolated) * (size_factor / len(exp))
        return (0 - np.log10(min(pvals))) - noise, noise

    factor = smooth_factor
    newLabels, label_pred, p, node, com = _cut(factor)
    obj_val, noise_norm = _objective(newLabels, p)
    tried = [factor]
    scores = [obj_val]

    # Coarse search: keep smoothing while the factor is small, or while the cut
    # is significant but its objective is still too low.
    while factor < 30 or (obj_val < 1 and len(p) > 0 and min(p) < 0.01):
        factor += 20 if noise_norm >= 20 else 10
        newLabels, label_pred, p, node, com = _cut(factor)
        obj_val, noise_norm = _objective(newLabels, p)
        tried.append(factor)
        scores.append(obj_val)
        if factor > 120:
            break

    # Redo the cut at the best-scoring factor.
    if len(scores) >= 2:
        factor = tried[np.argmax(scores)]
        newLabels, label_pred, p, node, com = _cut(factor)

    # A significant but tiny pattern: smooth harder while anything is significant.
    if len(p) > 0 and len(node[np.argmin(p)]) < noise_size_estimate:
        while min(p) < 0.01:
            factor += 50
            newLabels, label_pred, p, node, com = _cut(factor)
            if factor > 500:
                break

    return p, node, newLabels, factor, com, label_pred


def identify_spatial_genes_fixed_sf(locs, data_norm, cellGraph, gmmDict, smooth_factor=30,
                      unary_scale_factor=100, label_cost=10, algorithm='expansion',
                      model = 'gmm'):
    '''
    Main function to identify spatially variable genes with one fixed smooth factor.

    The gene columns of ``data_norm`` are split into chunks that are scored in
    parallel by compute_single_fixed_sf. The per-gene results are concatenated
    and Benjamini-Hochberg FDR is computed from each gene's smallest
    component p-value.

    :param locs: spatial coordinates (n, 2).
    :param data_norm: normalized expression DataFrame; rows are cells, columns genes.
    :param cellGraph: cell graph, passed through to the per-gene scoring.
    :param gmmDict: mapping gene -> fitted model (used unless model == 'otsu').
    :param smooth_factor, unary_scale_factor, label_cost, algorithm: graph-cut settings.
    :param model: 'otsu' or 'gmm' (any other value is treated as 'gmm').
    :rtype: DataFrame indexed by gene with columns p_value (list of component
        p-values), fdr, smooth_factor, nodes, model, model_labels and one
        label_cell_<k> column per cell.
    :raises ValueError: if data_norm has no gene columns.
    '''
    n_genes = data_norm.shape[1]
    if n_genes == 0:
        raise ValueError('data_norm has no gene columns')

    # Half of the CPUs, at most one worker per two genes, and never fewer than
    # one worker (int(x/2) is 0 on a single-CPU machine or for a single gene,
    # which made the split below fail).
    num_cores = max(1, min(mp.cpu_count() // 2, n_genes // 2))

    # Split column positions rather than the DataFrame itself: np.array_split on
    # a DataFrame goes through the deprecated DataFrame.swapaxes.
    chunks = [data_norm.iloc[:, cols]
              for cols in np.array_split(np.arange(n_genes), num_cores)]
    tuples = [(locs, chunk, cellGraph, gmmDict, smooth_factor, unary_scale_factor,
               label_cost, algorithm, model) for chunk in chunks]

    results = parmap.starmap(compute_single_fixed_sf, tuples,
                             pm_processes=num_cores, pm_pbar=True)

    # Each worker returns (nodes, p_values, genes, smooth_factors, pred_labels,
    # model_list, model_labels); concatenate every field across workers.
    (nodes, p_values, genes, s_factors, pred_labels,
     model_list, model_labels) = [reduce(operator.add, field) for field in zip(*results)]

    best_p_values = [min(p) for p in p_values]
    fdr = multi.multipletests(np.array(best_p_values), method='fdr_bh')[1]

    labels_array = np.array(pred_labels).reshape(len(genes), pred_labels[0].shape[0])
    data_array = np.array((genes, p_values, fdr, s_factors, nodes, model_list, model_labels),
                          dtype=object).T
    t_array = np.hstack((data_array, labels_array))
    c_labels = ['p_value', 'fdr', 'smooth_factor', 'nodes', 'model', 'model_labels']
    c_labels += ['label_cell_' + str(i) for i in range(1, labels_array.shape[1] + 1)]
    result_df = pd.DataFrame(t_array[:, 1:], index=t_array[:, 0], columns=c_labels)

    return result_df

def compute_single_fixed_sf(locs, data_norm, cellGraph,gmmDict, smooth_factor=30, 
                        unary_scale_factor=100,label_cost = 10, algorithm ='expansion', 
                        model =  'gmm',**kw):
    '''
    Score every gene (column) of ``data_norm`` with one fixed smooth factor.

    When ``model == 'otsu'`` each gene is scored with
    compute_single_otsu_fixed_sf; for any other value the model is reported
    as 'gmm' and compute_single_gmm_fixed_sf is used with ``gmmDict[gene]``.
    Extra keyword arguments are accepted and ignored.

    :return: seven parallel lists (nodes, p_values, genes, smooth_factors,
        pred_labels, model_list, model_labels), one entry per gene.
    '''
    use_otsu = model == 'otsu'
    model_name = model if use_otsu else 'gmm'

    nodes, p_values, genes = [], [], []
    smooth_factors, pred_labels, model_list, model_labels = [], [], [], []

    for gene in data_norm.columns:
        exp = data_norm.loc[:, gene].values
        if use_otsu:
            outcome = compute_single_otsu_fixed_sf(locs, exp, cellGraph, smooth_factor,
                                                   unary_scale_factor, label_cost, algorithm)
        else:
            outcome = compute_single_gmm_fixed_sf(locs, exp, cellGraph, gmmDict[gene],
                                                  smooth_factor, unary_scale_factor,
                                                  label_cost, algorithm)
        p_best, node_best, labels_best, _factor, _com, label_pred_best = outcome

        nodes.append(node_best)
        p_values.append(p_best)
        genes.append(gene)
        smooth_factors.append(smooth_factor)
        pred_labels.append(labels_best)
        model_list.append(model_name)
        model_labels.append(list(label_pred_best))

    return nodes, p_values, genes, smooth_factors, pred_labels, model_list, model_labels


def compute_single_gmm_fixed_sf(locs,exp,cellGraph,gmm, smooth_factor=30, 
                    unary_scale_factor=100, label_cost=10, algorithm='expansion'):
    '''
    Run one GMM-based graph cut at a fixed smooth factor and score it.

    :return: (p_values, node_lists, labels, smooth_factor, components,
        label_pred), where labels/label_pred come from cut_graph_general and
        p_values/node_lists/components from compute_p_CSR.
    '''
    labels, label_pred = cut_graph_general(cellGraph, exp, gmm, unary_scale_factor,
                                           smooth_factor, label_cost, algorithm)
    p_values, node_lists, components = compute_p_CSR(locs, labels, gmm, exp, cellGraph)
    return p_values, node_lists, labels, smooth_factor, components, label_pred


def compute_single_otsu_fixed_sf(locs,exp,cellGraph,smooth_factor=30, 
                          unary_scale_factor=100, label_cost=10, algorithm='expansion'):
    '''
    Run one Otsu-based graph cut at a fixed smooth factor and score it.

    :return: (p_values, node_lists, labels, smooth_factor, components,
        label_pred), where labels/label_pred come from cut_graph_general_otsu
        and p_values/node_lists/components from compute_p_CSR_otsu.
    '''
    labels, _thresholds, label_pred = cut_graph_general_otsu(cellGraph, exp, unary_scale_factor,
                                                             smooth_factor, label_cost, algorithm)
    p_values, node_lists, components = compute_p_CSR_otsu(locs, labels, label_pred, exp, cellGraph)
    return p_values, node_lists, labels, smooth_factor, components, label_pred


def _otsu_component_p(label_pred_com, unique, counts, com_size, n_cells):
    '''
    Poisson tail p-value that a component is enriched for its majority Otsu label.

    :param label_pred_com: Otsu labels of the cells in the component.
    :param unique, counts: distinct Otsu labels over all cells and their counts.
    :param com_size: size used for the expected count (the component size,
        possibly enlarged by absorbed noise cells).
    :param n_cells: total number of cells.
    :return: P(X >= k) for X ~ Poisson(com_size * C / n_cells), where k is the
        majority-label count inside the component and C the total number of
        cells carrying that label; capped at 1 and floored at 10e-30 (= 1e-29).
    '''
    unique_com, counts_com = np.unique(label_pred_com, return_counts=True)
    # First (smallest) label among those with the maximal count.
    major_idx = np.argmax(counts_com)
    major_label = unique_com[major_idx]
    count_in_com = counts_com[major_idx]
    label_count = counts[np.where(unique == major_label)[0]]
    mu = com_size * (label_count / n_cells)
    # sf(k) + pmf(k) = P(X > k) + P(X = k) = P(X >= k)
    prob = min(poisson.sf(count_in_com, mu)[0] + poisson.pmf(count_in_com, mu)[0], 1)
    if prob < 10e-30:
        prob = 10e-30
    return prob


def compute_p_CSR_otsu(locs, newLabels , label_pred, exp, cellGraph): 
    '''
    Returns p-values of the connected components of an Otsu-based graph cut.

    Components come from count_component on ``newLabels``. Each component with
    at least ``com_cutoff`` cells (3 when there are <= 1000 cells, otherwise 9)
    gets a Poisson tail p-value for enrichment of its majority Otsu label
    (see _otsu_component_p); smaller components get p = 1.
    With <= 1000 cells, components with p >= 0.1 and fewer than 10 cells are
    treated as noise: each component with p < 0.1 absorbs every remaining
    noise component for which noise_inside(...) is true (each noise component
    is absorbed at most once), and its p-value is recomputed with the absorbed
    cells added to its size.

    :param locs: spatial coordinates (passed to count_component).
    :param newLabels: graph-cut labels, shape (n,).
    :param label_pred: per-cell Otsu threshold labels, shape (n,).
    :param exp: expression values of one gene, shape (n,).
    :param cellGraph: cell graph (passed to count_component / noise_inside).
    :rtype: (p_values, node_lists, con_components): one p-value and one array
        of node indices per component, plus the components themselves.
    '''
    n_cells = exp.shape[0]
    unique, counts = np.unique(label_pred, return_counts=True)
    con_components = count_component(locs, cellGraph, newLabels)
    n_com = len(con_components)

    # Decided once from the dataset size (previously assigned inside the loop,
    # which left addNoise undefined when there were no components).
    if len(exp) <= 1000:
        com_cutoff, add_noise = 3, True
    else:
        com_cutoff, add_noise = 9, False

    # First pass: p-value of every component without considering noise.
    p_values = list()
    node_lists = list()
    for j in range(n_com):
        node_list = con_components[j]
        nodes = np.array(list(node_list))
        if len(node_list) >= com_cutoff:
            p_values.append(_otsu_component_p(label_pred[nodes], unique, counts,
                                              len(node_list), n_cells))
        else:
            p_values.append(1)  # too small to test
        node_lists.append(nodes)

    if add_noise:
        noise = {j: con_components[j] for j in range(n_com)
                 if p_values[j] >= 0.1 and len(con_components[j]) < 10}
        # Second pass: let significant components absorb the noise inside them.
        for j in range(n_com):
            if p_values[j] < 0.1:
                used_com = [jj for jj, cc in noise.items()
                            if noise_inside(cc, con_components[j], cellGraph)]
                noise_size = sum(len(noise[jj]) for jj in used_com)
                if noise_size > 0:
                    for jj in used_com:
                        noise.pop(jj)
                    node_list = con_components[j]
                    nodes = np.array(list(node_list))
                    p_values[j] = _otsu_component_p(label_pred[nodes], unique, counts,
                                                    len(node_list) + noise_size, n_cells)

    return p_values, node_lists, con_components

def compute_unary_cost_simple_otsu(count, thresholds, unary_scale_factor = 100):
    '''
    Returns per-cell threshold labels and the unary cost energy for an Otsu cut.

    Each cell gets label i+1 for the last threshold i (in the given order) its
    value exceeds, else 0. The mean and variance (ddof=1) of every class are
    estimated, and the boundary between adjacent classes i and i+1 is placed
    at the standard-deviation weighted point
    (m_i*s_{i+1} + m_{i+1}*s_i) / (s_i + s_{i+1}). The unary cost is
    ``unary_scale_factor * calc_u_cost(...)`` cast to int32.
    When at most 10% of the cells (and at most 30 cells) have a value > 1, the
    most frequent cost value v and its negation -v are replaced, if their
    magnitude is below 10, by +/- median(|distinct cost values|).

    :param count: 1-D expression values, shape (n,).
    :param thresholds: scalar Otsu threshold (any numpy/Python float type) or a
        1-D array of thresholds.
    :param unary_scale_factor: scalar multiplier for the costs.
    :rtype: (labels_pred int32 shape (n,), unary energy matrix int32).
    '''
    # threshold_otsu returns a scalar whose numpy type follows the input dtype
    # (np.float32 only for float32 data); treat any scalar as a one-element
    # threshold array instead of special-casing np.float32.
    cuts = np.atleast_1d(thresholds)
    n_classes = cuts.shape[0] + 1

    labels_pred = np.zeros(count.shape[0])
    for i, cut in enumerate(cuts):
        labels_pred[count > cut] = i + 1

    temp_means = np.zeros(n_classes)
    temp_covs = np.zeros(n_classes)
    for i in range(n_classes):
        members = count[labels_pred == i]
        temp_means[i] = np.mean(members)
        temp_covs[i] = np.var(members, ddof=1)

    mid_points = np.zeros(n_classes - 1)
    for i in range(n_classes - 1):
        sd_lo = np.sqrt(temp_covs[i])
        sd_hi = np.sqrt(temp_covs[i + 1])
        mid_points[i] = (temp_means[i] * sd_hi + temp_means[i + 1] * sd_lo) / (sd_lo + sd_hi)

    temp = count[:, np.newaxis] - temp_means[1:]
    neg_indices = np.apply_along_axis(first_neg_index, 1, temp)
    ind_count_arr = np.vstack((neg_indices, count)).T
    unary_cost = (unary_scale_factor * np.apply_along_axis(calc_u_cost, 1,
                                    ind_count_arr, mid_points)).astype(np.int32)

    # Very sparse genes: push the dominant small-magnitude cost (and its
    # negation) up to the median magnitude.
    n_expressed = np.count_nonzero(count > 1)
    if n_expressed / len(count) <= 0.1 and n_expressed <= 30:
        unique_cost, counts_cost = np.unique(unary_cost, return_counts=True)
        dominant = unique_cost[np.argmax(counts_cost)]
        fill = np.median(abs(unique_cost))
        for val in (dominant, 0 - dominant):
            if 0 < val < 10:
                np.place(unary_cost, unary_cost == val, fill)
            elif -10 < val < 0:
                np.place(unary_cost, unary_cost == val, 0 - fill)
    return labels_pred.astype(np.int32), unary_cost

## wang
def compute_unary_cost_simple_otsu_mid(count, thresholds, unary_scale_factor = 100):
    '''
    Deprecated.

    Returns per-cell threshold labels and the unary cost energy, using the
    maximum value of each lower class as the boundary to the next class.

    :param count: 1-D expression values, shape (n,).
    :param thresholds: scalar Otsu threshold (any numpy/Python float type) or a
        1-D array of thresholds.
    :param unary_scale_factor: scalar multiplier for the costs.
    :rtype: (labels_pred int32 shape (n,), unary energy matrix int32).
    '''
    # Any scalar threshold (not only np.float32) is one threshold / two classes.
    cuts = np.atleast_1d(thresholds)
    n_classes = cuts.shape[0] + 1

    labels_pred = np.zeros(count.shape[0])
    for i, cut in enumerate(cuts):
        labels_pred[count > cut] = i + 1

    mid_points = np.zeros(n_classes - 1)
    for i in range(n_classes - 1):
        mid_points[i] = np.max(count[labels_pred == i])

    temp = count[:, np.newaxis] - mid_points
    neg_indices = np.apply_along_axis(first_neg_index, 1, temp)
    ind_count_arr = np.vstack((neg_indices, count)).T
    unary_cost = (unary_scale_factor * np.apply_along_axis(calc_u_cost, 1,
                                    ind_count_arr, mid_points)).astype(np.int32)
    return labels_pred.astype(np.int32), unary_cost


def compute_unary_cost_simple_profile_otsu(count, thresholds, unary_scale_factor):
    '''
    Returns unary cost energy built only from the threshold label profile.

    Each cell gets label i+1 for the last threshold i (in the given order) its
    value exceeds, else 0. The cost of assigning a cell to one of the labels
    actually present is -unary_scale_factor when it equals the cell's own
    label and +unary_scale_factor otherwise.

    :param count: 1-D expression values, shape (n,).
    :param thresholds: scalar Otsu threshold (any numpy/Python float type) or a
        1-D array of thresholds.
    :param unary_scale_factor: scalar multiplier for the costs.
    :rtype: (labels_pred int32 shape (n,), unary cost int32 shape
        (n, number of distinct labels present)).
    '''
    # Any scalar threshold (not only np.float32) is one threshold / two classes.
    cuts = np.atleast_1d(thresholds)
    labels_pred = np.zeros(count.shape[0])
    for level, cut in enumerate(cuts, start=1):
        labels_pred[count > cut] = level

    # Columns follow the sorted distinct labels that actually occur.
    observed = np.unique(labels_pred)
    unary_mat = np.where(labels_pred[:, np.newaxis] == observed[np.newaxis, :], -1.0, 1.0)
    return labels_pred.astype(np.int32), (unary_scale_factor * unary_mat).astype(np.int32)


def cut_graph_general_otsu(cellGraph, exp, unary_scale_factor=100, 
                      smooth_factor=10, label_cost=10, algorithm='expansion',
                      profile=False):
    '''
    Returns graph-cut labels from an Otsu threshold of one gene's expression.

    The Otsu threshold of ``exp`` defines per-cell threshold labels and the
    unary cost (compute_unary_cost_simple_otsu, or the profile-only variant
    when ``profile`` is truthy and not equal to False); the pairwise cost comes
    from compute_pairwise_cost, and the cut is solved with
    pygco.cut_from_graph on the edges in columns 0-1 of ``cellGraph``.

    :param cellGraph: array whose first two columns are edge endpoints.
    :param exp: 1-D expression values, shape (n,).
    :param unary_scale_factor, smooth_factor, label_cost: scalars.
    :param algorithm: kept for interface compatibility; not passed to pygco here.
    :param profile: use the profile-only unary cost.
    :rtype: (labels shape (n,), thresholds, label_pred shape (n,)).
    '''
    count = exp
#    thresholds = threshold_multiotsu(exp) #3D
    thresholds = threshold_otsu(count)  #2D
    # threshold_otsu returns a scalar whose numpy type follows the input dtype
    # (np.float32 only for float32 data); any scalar means two labels.
    pair_size = np.atleast_1d(thresholds).shape[0] + 1
    if profile == False:
        label_pred, unary_cost = compute_unary_cost_simple_otsu(count, thresholds, unary_scale_factor)
    else:
        label_pred, unary_cost = compute_unary_cost_simple_profile_otsu(count, thresholds, unary_scale_factor)
    pairwise_cost = compute_pairwise_cost(pair_size, smooth_factor)
    edges = cellGraph[:, 0:2].astype(np.int32)
    labels = pygco.cut_from_graph(edges, unary_cost, pairwise_cost, label_cost)

    return labels, thresholds, label_pred


def calc_zero_part_p(locs, cellGraph, exp):
    '''
    Poisson p-value of the first connected component against the zero fraction.

    ``exp`` is multiplied by 1000 and truncated to int; cells whose value is
    then 0 are counted as zero-labelled (C of them out of N cells).
    count_component is run on these integer labels and only the first
    returned component (size s) is tested: p = P(X >= s) for
    X ~ Poisson(s * C / N) (p = 1 when C == 0), multiplied by N / s and capped
    at 1.
    NOTE(review): this assumes con_components[0] is the zero-labelled
    component; confirm count_component's ordering.

    :param locs: spatial coordinates (passed to count_component).
    :param cellGraph: cell graph (passed to count_component).
    :param exp: 1-D expression values, shape (n,).
    :return: the corrected p-value; 1 if count_component returns no components
        (previously this raised UnboundLocalError).
    '''
    n_cells = exp.shape[0]
    temp_newLabels = (exp * 1000).astype(int)
    con_components = count_component(locs, cellGraph, temp_newLabels)
    if len(con_components) == 0:
        return 1

    com_size = len(con_components[0])
    label_count = len(np.where(temp_newLabels == 0)[0])
    cover = n_cells / com_size
    if label_count == 0:
        prob_sf = 1
    else:
        mu = com_size * (label_count / n_cells)
        # sf(k) + pmf(k) = P(X >= k) with k = com_size
        prob_sf = poisson.sf(com_size, mu) + poisson.pmf(com_size, mu)
    return min(prob_sf * cover, 1)

