import underworld as uw
import underworld.function as fn
import numpy as np
from mpi4py import MPI
from .LecodeIsostasy import LecodeIsostasy
from .scaling import nonDimensionalize as nd
from .scaling import UnitRegistry as u
from ._utils import Balanced_InflowOutflow
from ._utils import MovingWall

# Module-wide MPI handles: the world communicator and this process's rank.
comm = MPI.COMM_WORLD
rank = MPI.COMM_WORLD.Get_rank()

def _is_neumann(val):
    """ Return True if ``val`` is a quantity with units of stress.

    Anything that is not a ``u.Quantity`` (plain floats, ints, functions...)
    returns False. Quantities are compared after conversion to base units
    against kg / (m s^2), i.e. force per area.
    """

    if isinstance(val, u.Quantity):
        stress_units = u.kilogram / (u.meter * u.second**2)
        return val.to_base_units().units == stress_units
    return False


class VelocityBCs(object):
    """ Class to define the mechanical boundary conditions

    Collects per-wall velocity / stress conditions for a Model and converts
    them into Underworld Dirichlet (and, when needed, Neumann) conditions via
    ``get_conditions``.
    """

    def __init__(self, Model, left=None, right=None, top=None, bottom=None,
                 front=None, back=None, indexSets=None,
                 order_wall_conditions=None):
        """ Defines mechanical boundary conditions

        The type of conditions is determined through the units used to define
        the parameters:
            * Units of velocity ([length] / [time]) represent a kinematic
            condition (Dirichlet)
            * Units of stress / pressure ([force] / [area]) are set as
            stress condition (Neumann).

        parameters
        ----------

        Model: (UWGeodynamics.Model)
            An UWGeodynamics Model (See UWGeodynamics.Model)
        left:(tuple) with length 2 in 2D, and length 3 in 3D.
            Define mechanical conditions on the left side of the Model.
            Conditions are defined for each Model direction (x, y, [z])
        right:(tuple) with length 2 in 2D, and length 3 in 3D.
            Define mechanical conditions on the right side of the Model.
            Conditions are defined for each Model direction (x, y, [z])
        top:(tuple) with length 2 in 2D, and length 3 in 3D.
            Define mechanical conditions on the top side of the Model.
            Conditions are defined for each Model direction (x, y, [z])
        bottom:(tuple) with length 2 in 2D, and length 3 in 3D.
            Define mechanical conditions on the bottom side of the Model.
            Conditions are defined for each Model direction (x, y, [z])
        indexSets: (list)
            List of node where to apply predefined velocities.
            (Stored as ``self.indexSets``; not used by this class itself.)
        order_wall_conditions: (list of str), optional
            Order in which the wall conditions are applied. Walls applied
            later overwrite values written by earlier walls on nodes they
            share. Valid names are "bottom", "top", "left", "right" and, in
            3D, "front" and "back". Defaults to
            ["bottom", "top", "left", "right"] in 2D and
            ["bottom", "top", "front", "back", "left", "right"] in 3D.

        Only valid for 3D Models:

        front:(tuple) with length 3.
            Define mechanical conditions on the front side of the Model.
            Conditions are defined for each Model direction (x, y, z)
        back:(tuple) with length 3.
            Define mechanical conditions on the back side of the Model.
            Conditions are defined for each Model direction (x, y, z)

        raises
        ------

        ValueError
            If ``order_wall_conditions`` contains an unknown wall name.

        examples:
        ---------

        The following example defines a (2x1) meter Underworld model with
        freeslip conditions on all the sides.

        >>> import UWGeodynamics as GEO
        >>> u = GEO.u
        >>> Model = GEO.Model(elementRes=(64, 64),
                              minCoord=(-1. * u.meter, -50. * u.centimeter),
                              maxCoord=(1. * u.meter, 50. * u.centimeter))
        >>> Model.set_velocityBCs(left=[0, None], right=[0,None], top=[None,0],
                                  bottom=[None, 0])

        """

        self.Model = Model
        self.left = left
        self.right = right
        self.top = top
        self.bottom = bottom
        self.back = back
        self.front = front
        self.indexSets = indexSets

        # Rebuilt by get_conditions(): one IndexSet per velocity component.
        self.dirichlet_indices = []
        self.neumann_indices = []

        # wall name -> (user condition, Model wall IndexSet)
        walls = {"bottom": (self.bottom, self.Model._bottom_wall),
                 "top": (self.top, self.Model._top_wall),
                 "left": (self.left, self.Model._left_wall),
                 "right": (self.right, self.Model._right_wall)}
        default_order = ["bottom", "top", "left", "right"]

        if self.Model.mesh.dim == 3:
            walls["front"] = (self.front, self.Model._front_wall)
            walls["back"] = (self.back, self.Model._back_wall)
            default_order = ["bottom", "top", "front", "back",
                             "left", "right"]

        self._wall_indexSets = walls

        if order_wall_conditions:
            # Fail early: an unknown name would otherwise only surface later
            # as a KeyError inside get_conditions().
            unknown = [name for name in order_wall_conditions
                       if name not in walls]
            if unknown:
                raise ValueError(
                    "Unknown wall name(s) in order_wall_conditions: {0}; "
                    "valid names are {1}".format(unknown, sorted(walls)))
            self.order_wall_conditions = order_wall_conditions
        else:
            self.order_wall_conditions = default_order

        # Give MovingWall conditions a reference to the Model.
        for arg in [self.left, self.right, self.top, self.bottom, self.front,
                    self.back]:
            if isinstance(arg, MovingWall):
                arg.Model = self.Model

    def __getitem__(self, name):
        """ Dictionary-style access to instance attributes (e.g. bcs["left"])."""
        return self.__dict__[name]

    def _set_dirichlet(self, nodes, dim, values):
        """ Write ``values`` as a kinematic condition on component ``dim``."""
        self.Model.velocityField.data[nodes.data, dim] = values
        self.Model.boundariesField.data[nodes.data, dim] = values
        self.dirichlet_indices[dim] += nodes

    @staticmethod
    def _select_component(values, dim):
        """ Pick the column for component ``dim`` from an evaluated function.

        Vector-valued functions contribute column ``dim``; a function with a
        single output column is used as-is for whichever component it was
        given for (indexing ``[:, dim]`` would raise IndexError for dim > 0).
        """
        if values.shape[1] > dim:
            return values[:, dim]
        return values[:, 0]

    def _link_isostasy(self, isostasy):
        """ Attach ``isostasy`` to the Model and hand it the Model objects it
        reads (mesh, swarm, fields, density function, side-wall conditions).
        """
        self.Model._isostasy = isostasy
        iso = self.Model._isostasy
        iso.mesh = self.Model.mesh
        iso.swarm = self.Model.swarm
        iso._mesh_advector = self.Model._advector
        iso.velocityField = self.Model.velocityField
        iso.boundariesField = self.Model.boundariesField
        iso.materialIndexField = self.Model.materialField
        iso._densityFn = self.Model._densityFn
        iso.vertical_walls_conditions = {
            "left": self.left,
            "right": self.right,
            "front": self.front,
            "back": self.back
        }

    def apply_condition_nodes(self, condition, nodes):
        """ Apply condition to a set of nodes

        Parameters:
        -----------
            condition:
                velocity condition. One of:
                  * a LecodeIsostasy object (bottom wall only);
                  * a MovingWall object;
                  * a list/tuple with one entry per velocity component, each
                    entry being None (ignored), an underworld Function, a
                    list/tuple of (condition, value) pairs passed to
                    ``fn.branching.conditional``, a number / Quantity
                    (Neumann if it has stress units, Dirichlet otherwise),
                    a Balanced_InflowOutflow or a LecodeIsostasy object
                    (last component only). Entries of other types are
                    silently ignored.
            nodes:
                set of nodes

        Expects ``self.dirichlet_indices`` / ``self.neumann_indices`` to hold
        one IndexSet per component, as prepared by ``get_conditions``.

        Raises:
        -------
            ValueError: if a LecodeIsostasy entry is given for a component
            other than the last one.
        """

        if not nodes:
            return

        # Special case: a LecodeIsostasy object as the whole bottom condition.
        if (isinstance(condition, LecodeIsostasy) and
                nodes == self.Model._bottom_wall):
            self._link_isostasy(condition)
            self.dirichlet_indices[-1] += self.Model._bottom_wall
            return

        if isinstance(condition, MovingWall):
            condition.wall = nodes
            indices = condition.get_wall_indices()
            func = condition.velocityFn
            for dim in range(self.Model.mesh.dim):
                set_ = indices[dim]
                if set_.data.size > 0:
                    self._set_dirichlet(set_, dim, func.evaluate(set_)[:, 0])
            return

        # Expect a list or tuple of dimension mesh.dim; skip processors whose
        # part of the domain holds no node of this wall.
        if not (isinstance(condition, (list, tuple)) and
                nodes.data.size > 0):
            return

        coords = self.Model.mesh.data[nodes.data]

        for dim in range(self.Model.mesh.dim):
            component = condition[dim]

            if isinstance(component, fn.Function):
                values = component.evaluate(coords)
                self._set_dirichlet(nodes, dim,
                                    self._select_component(values, dim))

            # User defined (condition, value) pairs
            elif isinstance(component, (list, tuple)):
                func = fn.branching.conditional(component)
                values = func.evaluate(coords)
                self._set_dirichlet(nodes, dim,
                                    self._select_component(values, dim))

            # Scalar condition
            elif isinstance(component, (u.Quantity, float, int)):
                value = nd(component)
                if _is_neumann(component):
                    self.Model.tractionField.data[nodes.data, dim] = value
                    self.neumann_indices[dim] += nodes
                else:
                    self._set_dirichlet(nodes, dim, value)

            # Inflow Outflow
            elif isinstance(component, Balanced_InflowOutflow):
                component.ynodes = self.Model.mesh.data[nodes.data, 1]
                self._set_dirichlet(nodes, dim, component._get_side_flow())

            elif isinstance(component, LecodeIsostasy):
                if self.Model.mesh.dim - 1 != dim:
                    raise ValueError(
                        "Can not apply LecodeIsostasy on that dimension")
                self._link_isostasy(component)
                self.dirichlet_indices[dim] += nodes

    def get_conditions(self):
        """ Get the mechanical boundary conditions

        Resets the Dirichlet/Neumann index sets, applies every wall condition
        in ``self.order_wall_conditions`` order (writing into the Model's
        velocityField, boundariesField and tractionField), then builds the
        Underworld condition objects.

        This performs an MPI collective (allreduce on the module ``comm``),
        so it must be called on every rank.

        Returns
        -------

        List of conditions as:
            [<underworld.conditions._conditions.DirichletCondition,
             <underworld.conditions._conditions.NeumannCondition]
        or
            [<underworld.conditions._conditions.DirichletCondition]

        """

        Model = self.Model
        # NOTE(review): every entry is first filled with random values and
        # wall conditions then overwrite the boundary nodes; the reason is
        # not visible in this file.
        Model.boundariesField.data[...] = np.random.random(
            Model.boundariesField.data.shape)

        # Reinitialise neumann and dirichlet conditions: one (separately
        # fetched) empty IndexSet per component, since they are grown in place.
        self.dirichlet_indices = [Model.mesh.specialSets["Empty"]
                                  for _ in range(Model.mesh.dim)]
        self.neumann_indices = [Model.mesh.specialSets["Empty"]
                                for _ in range(Model.mesh.dim)]

        for set_ in self.order_wall_conditions:
            (condition, nodes) = self._wall_indexSets[set_]
            self.apply_condition_nodes(condition, nodes)

        conditions = [uw.conditions.DirichletCondition(
            variable=Model.velocityField,
            indexSetsPerDof=self.dirichlet_indices)]

        # Empty sets are replaced by None.
        self.neumann_indices = tuple(
            val if val.data.size > 0 else None
            for val in self.neumann_indices)

        # Only create a Neumann condition if there is a stress condition
        # somewhere, on any of the procs.
        local_has_neumann = any(val is not None
                                for val in self.neumann_indices)
        global_has_neumann = comm.allreduce(int(local_has_neumann),
                                            op=MPI.MAX)

        if global_has_neumann:
            conditions.append(uw.conditions.NeumannCondition(
                fn_flux=Model.tractionField,
                variable=Model.velocityField,
                indexSetsPerDof=self.neumann_indices))

        return conditions
