#!python

# Bring the yt convenience names and `os` into module scope, then snapshot the
# names defined so far.  The snapshot is passed as `user_ns` to the IPython
# shell created further down, so it must be taken after these imports.
from yt.mods import *
import os
namespace = locals().copy()

# Banner text; it is handed to `ip_shell.mainloop(banner=doc)` at the bottom of
# this file.  This is a runtime string, not a docstring.
doc = """\

Welcome to Enzo-embedded yt!

The different processors are accessible via the 'mec' variable.  To get grid
data, try using the get_grid_field function.  When done, be sure to kill the
processes with 'mec.kill()'!

Information about the mec variable, an instance of MultiEngineClient, can be
found in the IPython documentation:

http://ipython.scipy.org/doc/manual/html/parallel/parallel_multiengine.html

You can use the '%px' command to issue commands on all the engines
simultaneously.

"""

# NOTE(review): IPython.Shell and IPython.kernel.client look like the legacy
# IPython embedding/parallel API -- confirm the IPython version this targets.
import IPython.Shell

# Choose the interactive shell.  When a DISPLAY is set, try the WX-backed
# matplotlib shell first and fall back to the plain matplotlib shell if that
# raises ImportError; with no DISPLAY, use the plain matplotlib shell directly.
# Every variant receives the namespace snapshot taken above as its user_ns.
if "DISPLAY" in os.environ:
    try:
        ip_shell = IPython.Shell.IPShellMatplotlibWX(user_ns=namespace)
    except ImportError:
        ip_shell = IPython.Shell.IPShellMatplotlib(user_ns=namespace)
else:
    ip_shell = IPython.Shell.IPShellMatplotlib(user_ns=namespace)

# API handle for the shell; YTClient.refresh uses it to publish names into the
# interactive namespace via `ip.to_user_ns`.
ip = ip_shell.IP.getapi()

# NOTE(review): `os` is already imported above, and `glob` / `itertools` are
# not referenced in the visible part of this file.
import os   
import glob
import itertools

# NOTE(review): same call as the `ip = ...` assignment a few lines up; it
# looks redundant -- confirm before removing.
ip = ip_shell.IP.getapi()
# Issue the same star-import through the shell API -- presumably so that
# interactive commands see the yt names; confirm against the IPython `ex` docs.
ip.ex("from yt.mods import *")
# `client` supplies MultiEngineClient, which YTClient.refresh instantiates.
from IPython.kernel import client

class YTClient(object):
    """
    Front end for talking to the remote engines from the interactive shell.

    Holds a MultiEngineClient in *mec* plus local proxies (*enzo*, *pf*) that
    are rebuilt every time refresh() is called.
    """
    # Class-level default so refresh() can tell whether a client exists yet.
    mec = None

    def __init__(self):
        self.refresh()

    def eval(self, varname, targets = None):
        """
        Evaluate the expression *varname* on *targets* and pull the result back.

        The value is staged in a remote variable called __tmp, so anything
        already stored under that name on the engines is overwritten.  Staging
        through a plain name is what allows nested attributes and properties
        to be fetched.
        """
        staging_command = "__tmp = %s" % varname
        self.mec.execute(staging_command, targets=targets)
        return self.mec.pull("__tmp", targets=targets)

    def get_grid_field(self, grid_index, field_name, raw=False):
        """
        Fetch the array holding *field_name* for one grid.

        *grid_index* is the zero-based position in the grid array, i.e. the
        grid ID minus one.  The request is sent only to the processor listed
        for that grid in hierarchy_information["GridProcs"].

        With *raw* set to True the data comes straight from the enzo module:
        only original hierarchy fields are available (no derived fields),
        ghost zones are included, and axes 0 and 2 of the result are swapped
        before it is returned.  Otherwise the field is read through yt.
        """
        owner = int(self.enzo.hierarchy_information["GridProcs"][grid_index])
        if raw:
            # enzo.grid_data is keyed by grid ID, hence the "+ 1".
            expression = "enzo.grid_data[%s + 1]['%s']" % (
                            grid_index, field_name)
            return self.eval(expression, [owner])[0].swapaxes(0,2)
        expression = "pf.h.grids[%s]['%s']" % (grid_index, field_name)
        return self.eval(expression, [owner])[0]

    def refresh(self):
        """
        (Re)connect to the engines and rebuild the local proxy objects.

        Any existing client is killed first.  The new client, this object and
        the parameter-file proxy are then published to the interactive
        namespace as 'mec', 'ytc' and 'pf'.
        """
        if self.mec is not None:
            self.mec.kill()
        engines = client.MultiEngineClient()
        self.mec = engines
        engines.activate()
        # Hierarchy instantiation can block on the engines, so touch pf.h up
        # front instead of on the first real request.
        engines.execute("pf.h")
        self.enzo = enzo_module_proxy(self)
        self.pf = EnzoStaticOutputProxy(ytc=self)
        published = dict(mec=self.mec, ytc=self, pf=self.pf)
        ip.to_user_ns(published)

class enzo_module_proxy(object):
    """
    Local copy of selected attributes of the remote `enzo` module.

    Each attribute is fetched once, at construction time, from target 0
    through the supplied client's eval() method.
    """
    def __init__(self, ytc):
        # (local attribute name, remote expression), fetched in this order.
        remote_attrs = (
            ("hierarchy_information", "enzo.hierarchy_information"),
            ("conversion_factors", "enzo.conversion_factors"),
            ("yt_parameter_file", "enzo.yt_parameter_file"),
        )
        for attr_name, expression in remote_attrs:
            # eval() returns one entry per target; only target 0 was asked.
            setattr(self, attr_name, ytc.eval(expression, [0])[0])

from yt.lagos import EnzoStaticOutputInMemory, EnzoHierarchyInMemory
from yt.lagos.HierarchyType import _data_style_funcs
from yt.lagos.DataReadingFuncs import BaseDataQueue

class EnzoHierarchyProxy(EnzoHierarchyInMemory):
    """
    Hierarchy for the 'proxy' data style.

    The field list and the enzo handle both come from the client object
    stored as `ytc` on the parameter file.
    """
    _data_style = 'proxy'

    def _setup_field_lists(self):
        """Store the field list reported by target 0 as self.field_list."""
        remote = self.parameter_file.ytc
        per_target = remote.eval("pf.h.field_list", [0])
        self.field_list = per_target[0]

    def _obtain_enzo(self):
        """Return the enzo proxy held by the parameter file's client."""
        remote = self.parameter_file.ytc
        return remote.enzo

class EnzoStaticOutputProxy(EnzoStaticOutputInMemory):
    """
    Parameter file for the 'proxy' data style.

    Requires a `ytc` keyword argument (the client object); it is kept on the
    instance and not forwarded to the base class.
    """
    _data_style = 'proxy'
    _hierarchy_class = EnzoHierarchyProxy

    def __init__(self, *args, **kwargs):
        """
        Remove `ytc` from the keyword arguments, store it, then run the base
        initializer with everything that is left.  Raises KeyError when `ytc`
        is not supplied.
        """
        # Stored before the base initializer runs -- presumably because that
        # initializer may call _obtain_enzo(); confirm against the base class.
        client_handle = kwargs.pop("ytc")
        self.ytc = client_handle
        EnzoStaticOutputInMemory.__init__(self, *args, **kwargs)

    def _obtain_enzo(self):
        """Return the enzo proxy held by the stored client."""
        client_handle = self.ytc
        return client_handle.enzo

def _read_proxy_slice(self, grid, field, axis, coord):
    """
    Read a one-cell-thick slab of *field* from *grid* through the global
    client `ytc`.

    The raw array is fetched first.  It is then indexed with a three-slot
    selector that trims three cells from each end of every axis, except that
    slot *axis* is narrowed to the single cell at coord + 3.  The selector is
    applied in reversed order and axes 0 and 2 of the result are swapped.

    NOTE(review): get_grid_field(raw=True) already swaps axes 0 and 2 before
    returning, so this is a second swap -- confirm the axis order is intended.
    """
    raw_field = ytc.get_grid_field(grid.id - 1, field, raw=True)
    selector = [slice(3,-3) for _ in range(3)]
    selector[axis] = slice(coord + 3, coord + 4)
    selector.reverse()
    slab = raw_field[tuple(selector)]
    return slab.swapaxes(0,2)

class DataQueueProxy(BaseDataQueue):
    """
    Data queue for the 'proxy' data style; reads go through the global
    client `ytc` rather than a local file.
    """
    def __init__(self, ghost_zones = 3):
        """
        Precompute the index that trims *ghost_zones* cells from both ends of
        each of the three axes, then run the base initializer.
        """
        interior = slice(ghost_zones, -ghost_zones)
        self.my_slice = (interior, interior, interior)
        BaseDataQueue.__init__(self)

    def _read_set(self, grid, field):
        """Fetch the raw *field* array for *grid* and apply my_slice to it."""
        raw_field = ytc.get_grid_field(grid.id - 1, field, raw=True)
        trimmed = raw_field[self.my_slice]
        return trimmed

    def modify(self, field):
        """Return *field* with axes 0 and 2 swapped."""
        swapped = field.swapaxes(0,2)
        return swapped

def proxy_exception(*args, **kwargs):
    """
    Return the KeyError class, ignoring every argument.

    Registered in the 'proxy' entry of _data_style_funcs -- presumably as the
    exception type to catch on a failed read; confirm against yt.lagos.
    """
    exception_type = KeyError
    return exception_type

# TODO(review): original author's note, kept verbatim -- "things like compare
# buffers over time".

# Register the 'proxy' data style.  Only the last three slots are filled in
# (slice reader, exception hook, data queue class).
# NOTE(review): the meaning of each position is defined in
# yt.lagos.HierarchyType, not here -- confirm the slot order there.
_data_style_funcs['proxy'] = \
    (None, None, None, _read_proxy_slice, proxy_exception, DataQueueProxy)

# Build the module-level client.  _read_proxy_slice and DataQueueProxy look up
# the global name `ytc` at call time, so it has to exist before any proxy read.
ytc = YTClient()
mec = ytc.mec

# Hand control to the interactive shell, showing `doc` as the banner.
ip_shell.mainloop(sys_exit=1,banner=doc)
