#!/usr/bin/env python

# Copyright (C) 2012 Atsushi Togo
# All rights reserved.
#
# This file is part of phonopy.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in
#   the documentation and/or other materials provided with the
#   distribution.
#
# * Neither the name of the phonopy project nor the names of its
#   contributors may be used to endorse or promote products derived
#   from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import sys

try:
    import yaml
except ImportError:
    print("You need to install python-yaml.")
    sys.exit(1)
    
try:
    from yaml import CLoader as Loader
    from yaml import CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

import numpy as np

def print_band(data):
    """Print band-path Gruneisen data as plain-text columns on stdout.

    The first output line is a header starting with ``#``.  Every following
    line describes one q-point: the three q-position components, the
    distance along the path, and then one (Gruneisen, frequency) pair per
    band.

    Parameters
    ----------
    data : dict
        Must provide ``data['path']``, a sequence of path segments.  Each
        segment has a ``'phonon'`` list whose items carry ``'q-position'``,
        ``'distance'`` and ``'band'`` (a list of dicts with ``'gruneisen'``
        and ``'frequency'``).
    """
    header = ("#%-9s %-10s %-10s    " % ("    q_a", "    q_b", "    q_c")
              + "%-12s    " % "   Distance"
              + "%-12s %-12s     ..." % ("   Gruneisen", "   Frequency"))
    rows = [header]

    # Distances are shifted by the final distance of every segment already
    # printed, so the printed values keep increasing across segments.
    offset = 0
    for segment in data['path']:
        qpoints = segment['phonon']
        for qpt in qpoints:
            pieces = ["%10.7f %10.7f %10.7f    " % tuple(qpt['q-position']),
                      "%12.7f    " % (qpt['distance'] + offset)]
            pieces.extend("%12.7f %12.7f  " % (mode['gruneisen'],
                                               mode['frequency'])
                          for mode in qpt['band'])
            rows.append("".join(pieces))
        offset += qpoints[-1]['distance']

    print("\n".join(rows))

def print_mesh(data):
    """Print mesh-sampling Gruneisen data as plain-text columns on stdout.

    The first output line is a header starting with ``#``.  Every following
    line describes one q-point: the three q-position components, the
    integer multiplicity, and then one (Gruneisen, frequency) pair per band.

    Parameters
    ----------
    data : dict
        Must provide ``data['phonon']``, a sequence of q-point dicts with
        ``'q-position'``, ``'multiplicity'`` and ``'band'`` (a list of dicts
        with ``'gruneisen'`` and ``'frequency'``).
    """
    header = ("#%-9s %-10s %-10s  " % ("    q_a", "    q_b", "    q_c")
              + "%-12s  " % "Multiplicity"
              + "%-12s %-12s     ..." % ("   Gruneisen", "   Frequency"))
    rows = [header]

    for point in data['phonon']:
        pieces = ["%10.7f %10.7f %10.7f  " % tuple(point['q-position']),
                  "%12d  " % point['multiplicity']]
        pieces.extend("%12.7f %12.7f  " % (mode['gruneisen'],
                                           mode['frequency'])
                      for mode in point['band'])
        rows.append("".join(pieces))

    print("\n".join(rows))

def plot_band(data, is_fg, g_max, g_min):
    """Plot Gruneisen parameters and frequencies along the band paths.

    Two stacked panels are drawn (Gruneisen parameter on top, frequency at
    the bottom).  When ``is_fg`` is true a third panel showing the
    element-wise product Gruneisen x frequency is inserted between them.

    Parameters
    ----------
    data : dict
        Must provide ``data['path']``, a sequence of segments whose
        ``'phonon'`` items carry ``'distance'`` and ``'band'`` (a list of
        dicts with ``'gruneisen'`` and ``'frequency'``).
    is_fg : bool
        Add the Gruneisen x frequency panel.
    g_max, g_min : float or None
        Upper / lower y-limit of the top (Gruneisen) panel; ``None`` leaves
        the respective limit untouched.

    Returns
    -------
    module
        ``matplotlib.pyplot``, with the figure prepared but not yet shown.
    """
    import matplotlib.pyplot as plt

    distances = []
    gammas = []
    freqs = []

    # Shift each segment by the final distance of the segments before it so
    # the x values keep increasing across segments.
    offset = 0.0
    for segment in data['path']:
        qpoints = segment['phonon']
        for qpt in qpoints:
            modes = qpt['band']
            distances.append(qpt['distance'] + offset)
            gammas.append([mode['gruneisen'] for mode in modes])
            freqs.append([mode['frequency'] for mode in modes])
        offset += qpoints[-1]['distance']

    n_rows = 3 if is_fg else 2

    gruneisen_axes = plt.subplot(n_rows, 1, 1)
    plt.plot(distances, gammas, '-')
    if is_fg:
        plt.subplot(n_rows, 1, 2)
        plt.plot(distances, np.array(gammas) * np.array(freqs), '-')
    plt.subplot(n_rows, 1, n_rows)
    plt.plot(distances, freqs, '-')

    # The limits apply to the Gruneisen panel only.
    if g_max is not None:
        gruneisen_axes.set_ylim(ymax=g_max)
    if g_min is not None:
        gruneisen_axes.set_ylim(ymin=g_min)

    return plt

def plot_yaml_mesh(data, is_fg, cutoff_frequency=None):
    """Plot mesh-sampling data that was loaded from a YAML file.

    Collects, for every q-point in ``data['phonon']``, the list of band
    frequencies and the list of band Gruneisen parameters, and forwards
    both to ``plot_mesh``.

    Parameters
    ----------
    data : dict
        Must provide ``data['phonon']``, a sequence of q-point dicts with a
        ``'band'`` list of dicts holding ``'frequency'`` and ``'gruneisen'``.
    is_fg : bool
        Forwarded to ``plot_mesh``.
    cutoff_frequency : float or None
        Forwarded to ``plot_mesh``.

    Returns
    -------
    Whatever ``plot_mesh`` returns.
    """
    qpoints = data['phonon']
    frequencies = [[mode['frequency'] for mode in qpt['band']]
                   for qpt in qpoints]
    gammas = [[mode['gruneisen'] for mode in qpt['band']]
              for qpt in qpoints]

    return plot_mesh(frequencies, gammas, is_fg,
                     cutoff_frequency=cutoff_frequency)

def plot_hdf5_mesh(data, is_fg, cutoff_frequency=None):
    """Plot mesh-sampling data that was read from an HDF5 file.

    Parameters
    ----------
    data : mapping
        Must provide the entries ``'frequency'`` and ``'gruneisen'``; both
        are sliced with ``[:]`` before use.
        NOTE(review): presumably an open ``h5py.File`` whose datasets are
        two-dimensional (q-point, band) -- confirm against the file writer.
    is_fg : bool
        Forwarded to ``plot_mesh``.
    cutoff_frequency : float or None
        Forwarded to ``plot_mesh``.

    Returns
    -------
    Whatever ``plot_mesh`` returns.
    """
    # Slice both entries explicitly, as
    # get_thermodynamic_Gruneisen_parameter_hdf5 does, so plot_mesh gets the
    # values themselves rather than depending on an implicit conversion of
    # the stored object.
    frequencies = data['frequency'][:]
    gammas = data['gruneisen'][:]
    return plot_mesh(frequencies, gammas, is_fg,
                     cutoff_frequency=cutoff_frequency)

def plot_mesh(x, y, is_fg, cutoff_frequency=None):
    """Scatter-plot Gruneisen parameters against frequencies.

    ``x`` and ``y`` are transposed and walked in lockstep, so every plotted
    series corresponds to one column of the inputs (one band when the rows
    are q-points, as built by ``plot_yaml_mesh``).

    Parameters
    ----------
    x : array_like
        Frequencies.
    y : array_like
        Gruneisen parameters, same layout as ``x``.
    is_fg : bool
        If true, plot frequency x Gruneisen on the y axis instead of the
        Gruneisen parameter itself.
    cutoff_frequency : float or None
        When not ``None``, only points whose frequency is strictly greater
        than this value are plotted.  A cutoff of ``0`` is honoured.

    Returns
    -------
    module
        ``matplotlib.pyplot``, with the points plotted but not yet shown.
    """
    import matplotlib.pyplot as plt

    for (g, freqs) in zip(np.transpose(y), np.transpose(x)):
        # Compare against None rather than relying on truthiness: a cutoff
        # of 0 is a meaningful request and must not be skipped.
        if cutoff_frequency is not None:
            above_cutoff = freqs > cutoff_frequency
            g = np.extract(above_cutoff, g)
            freqs = np.extract(above_cutoff, freqs)

        if is_fg:
            plt.plot(freqs, np.array(freqs) * np.array(g), 'o')
        else:
            plt.plot(freqs, g, 'o')

    return plt

def get_thermodynamic_Gruneisen_parameter_yaml(data, cutoff_frequency=None):
    """Report the thermodynamic Gruneisen parameter for YAML mesh data.

    Gathers the multiplicity, band frequencies and band Gruneisen
    parameters of every q-point in ``data['phonon']`` and hands them to
    ``get_thermodynamic_Gruneisen_parameter``.  Nothing is returned.

    Parameters
    ----------
    data : dict
        Must provide ``data['phonon']``, a sequence of q-point dicts with
        ``'multiplicity'`` and a ``'band'`` list of dicts holding
        ``'frequency'`` and ``'gruneisen'``.
    cutoff_frequency : float or None
        Forwarded unchanged.
    """
    qpoints = data['phonon']
    multiplicities = [qpt['multiplicity'] for qpt in qpoints]
    frequencies = [[mode['frequency'] for mode in qpt['band']]
                   for qpt in qpoints]
    gammas = [[mode['gruneisen'] for mode in qpt['band']]
              for qpt in qpoints]
    get_thermodynamic_Gruneisen_parameter(frequencies, gammas, multiplicities,
                                          cutoff_frequency=cutoff_frequency)
    
def get_thermodynamic_Gruneisen_parameter_hdf5(data, cutoff_frequency=None):
    """Report the thermodynamic Gruneisen parameter for HDF5 mesh data.

    Reads the ``'frequency'``, ``'gruneisen'`` and ``'weight'`` entries of
    ``data`` (each sliced with ``[:]``) and hands them to
    ``get_thermodynamic_Gruneisen_parameter``.  Nothing is returned.

    Parameters
    ----------
    data : mapping
        Must provide the three entries named above.
    cutoff_frequency : float or None
        Forwarded unchanged.
    """
    frequencies, gammas, weights = [
        data[key][:] for key in ('frequency', 'gruneisen', 'weight')]
    get_thermodynamic_Gruneisen_parameter(frequencies, gammas, weights,
                                          cutoff_frequency=cutoff_frequency)

def get_thermodynamic_Gruneisen_parameter(x, y, m,
                                          cutoff_frequency=None):
    """Print the thermodynamic Gruneisen parameter evaluated at 300 K.

    The value is computed by
    ``phonopy.gruneisen.mesh.get_thermodynamic_Gruneisen_parameter`` and
    written to stdout; this function itself returns ``None``.

    Parameters
    ----------
    x : array_like
        Frequencies.
    y : array_like
        Gruneisen parameters, same layout as ``x``.
    m : array_like
        Weights (multiplicities) of the q-points.
    cutoff_frequency : float or None
        When not ``None``, every frequency that is not strictly greater
        than this value is replaced by ``-1`` before the calculation.  A
        cutoff of ``0`` is honoured.
        NOTE(review): presumably the phonopy routine ignores negative
        frequencies -- confirm against phonopy.gruneisen.mesh.
    """
    freqs = np.array(x)
    # Compare against None rather than relying on truthiness so that an
    # explicit cutoff of 0 is applied instead of being silently skipped.
    if cutoff_frequency is not None:
        freqs = np.where(freqs > cutoff_frequency, freqs, -1)
    gammas = np.array(y)
    weights = np.array(m)

    from phonopy.gruneisen.mesh import get_thermodynamic_Gruneisen_parameter as gtgp
    print("Thermodynamic gruneisen parameter at 300K: %s" %
          gtgp(gammas, freqs, weights, 300))

# Command-line interface.  Every option states its own default, so the value
# seen in ``options`` when a flag is omitted can be read off its definition.
# The remaining positional arguments end up in ``args``.
from optparse import OptionParser

parser = OptionParser()
parser.add_option("--gnuplot", action="store_true", dest="is_gnuplot",
                  default=False,
                  help="Output in gnuplot data style")
parser.add_option("-o", "--output", action="store", type="string",
                  dest="output_filename", default=None,
                  help="Output filename of PDF plot")
parser.add_option("--fg", action="store_true", dest="is_fg",
                  default=False,
                  help="Plot omega x gamma")
parser.add_option("-t", "--title", action="store", type="string",
                  dest="title", default=None,
                  help="Title of plot")
parser.add_option("--cutoff", type="float", dest="cutoff_frequency",
                  default=None,
                  help=("Plot above this cutoff frequency for mesh sampling "
                        "mode."))
parser.add_option("--fmax", type="float", dest="f_max", default=None,
                  help="Maximum frequency plotted")
parser.add_option("--fmin", type="float", dest="f_min", default=None,
                  help="Minimum frequency plotted")
parser.add_option("--gmax", type="float", dest="g_max", default=None,
                  help="Maximum Gruneisen params plotted")
parser.add_option("--gmin", type="float", dest="g_min", default=None,
                  help="Minimum Gruneisen params plotted")
parser.add_option("--hdf5", action="store_true", dest="is_hdf5",
                  default=False,
                  help="Use hdf5 to read results")
(options, args) = parser.parse_args()

def _show_or_save(plt, output_filename):
    """Save the current figure to output_filename if given, else show it."""
    if output_filename is not None:
        plt.savefig(output_filename)
    else:
        plt.show()


# Load the input data: HDF5 when --hdf5 is given, YAML otherwise.  The first
# positional argument, if any, overrides the default file name.
if options.is_hdf5:
    try:
        import h5py
    except ImportError:
        print("You need to install python-h5py.")
        sys.exit(1)

    filename = args[0] if args else "gruneisen.hdf5"
    # Open read-only explicitly: the file is only read here, and the default
    # mode has differed between h5py releases.
    data = h5py.File(filename, 'r')
else:
    filename = args[0] if args else "gruneisen.yaml"
    with open(filename) as f:
        # NOTE(review): Loader is PyYAML's full loader, which can construct
        # arbitrary Python objects; only run this on trusted input files.
        data = yaml.load(f.read(), Loader=Loader)

# Band-path mode.
if 'path' in data:
    if options.is_gnuplot:
        print_band(data)
    else:
        plt = plot_band(data, options.is_fg, options.g_max, options.g_min)
        if options.title is not None:
            # pyplot has no ``subtitle``; the figure-level title is
            # ``suptitle``.
            plt.suptitle(options.title)
        _show_or_save(plt, options.output_filename)

# Mesh-sampling mode.
if 'mesh' in data:
    # Stays None when only text output is produced, so that no plotting
    # call is attempted on an unbound name afterwards.
    plt = None
    if options.is_hdf5:
        plt = plot_hdf5_mesh(data,
                             options.is_fg,
                             cutoff_frequency=options.cutoff_frequency)
    elif options.is_gnuplot:
        print_mesh(data)
    else:
        plt = plot_yaml_mesh(data,
                             options.is_fg,
                             cutoff_frequency=options.cutoff_frequency)

    if plt is not None:
        if options.title is not None:
            plt.suptitle(options.title)

        if options.f_max is not None:
            plt.xlim(xmax=options.f_max)
        if options.f_min is not None:
            plt.xlim(xmin=options.f_min)
        if options.g_max is not None:
            plt.ylim(ymax=options.g_max)
        if options.g_min is not None:
            plt.ylim(ymin=options.g_min)

    # Reported in every mesh mode, text output included.
    if options.is_hdf5:
        get_thermodynamic_Gruneisen_parameter_hdf5(
            data, cutoff_frequency=options.cutoff_frequency)
    else:
        get_thermodynamic_Gruneisen_parameter_yaml(
            data, cutoff_frequency=options.cutoff_frequency)

    if plt is not None:
        _show_or_save(plt, options.output_filename)
