#!/home/travis/miniconda/bin/python

# Copyright (C) 2011 Atsushi Togo
# All rights reserved.
#
# This file is part of phonopy.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in
#   the documentation and/or other materials provided with the
#   distribution.
#
# * Neither the name of the phonopy project nor the names of its
#   contributors may be used to endorse or promote products derived
#   from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import sys
import numpy as np

try:
    import yaml
except ImportError:
    print("You need to install python-yaml.")
    sys.exit(1)

try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

from phonopy.units import VaspToTHz

def read_band_yaml(filename):
    """Read a phonopy band-structure YAML file.

    Args:
        filename: Path of the YAML file (typically ``band.yaml``).

    Returns:
        A 4-tuple ``(distances, frequencies, segment_nqpoint, labels)``:

        * ``distances``: array built from the ``distance`` value of every
          entry under ``phonon``.
        * ``frequencies``: array built from the ``frequency`` values of the
          ``band`` list of every entry under ``phonon`` (one row per entry).
        * ``segment_nqpoint``: the ``segment_nqpoint`` value of the file,
          returned as loaded.
        * ``labels``: list with the ``label`` value of every entry under
          ``phonon``, or ``None`` for entries without one.
    """
    # NOTE(security): yaml.load with the full (C)Loader can construct
    # arbitrary Python objects; only use this on trusted band.yaml files.
    # The context manager makes sure the file handle is closed again.
    with open(filename) as f:
        data = yaml.load(f, Loader=Loader)

    frequencies = []
    distances = []
    labels = []
    for v in data['phonon']:
        # Entries without a 'label' key are recorded as None so that
        # labels stays aligned with distances/frequencies.
        labels.append(v['label'] if 'label' in v else None)
        frequencies.append([f['frequency'] for f in v['band']])
        distances.append(v['distance'])

    return (np.array(distances),
            np.array(frequencies),
            data['segment_nqpoint'],
            labels)

def read_dos_dat(filename, pdos_indices=None, dos_factor=None):
    """Read a whitespace-separated DOS data file.

    Every data line holds a frequency in the first column followed by one
    or more DOS columns. Lines whose first non-blank character is ``#``
    and blank lines are skipped.

    Args:
        filename: Path of the DOS file.
        pdos_indices: Optional string selecting groups of DOS columns to be
            summed, e.g. ``"1 2, 3 4 5 6"``. Groups are separated by commas,
            column numbers inside a group by whitespace, and numbering
            starts at 1. When given, the returned DOS has one column per
            group plus a last column holding the sum over all columns.
        dos_factor: Optional factor multiplied with all DOS values.

    Returns:
        ``(frequencies, dos)`` as numpy arrays; ``dos`` has one row per
        frequency.
    """
    dos = []
    frequencies = []
    # Use a context manager so the file is closed even if parsing fails.
    with open(filename) as f:
        for line in f:
            stripped = line.strip()
            # Blank lines used to raise IndexError on stripped[0].
            if not stripped or stripped[0] == '#':
                continue
            ary = [float(x) for x in stripped.split()]
            frequencies.append(ary[0])
            dos.append(ary[1:])
    dos = np.array(dos)
    frequencies = np.array(frequencies)

    if pdos_indices:
        # Convert the 1-based column numbers of each group to 0-based ones.
        groups = [[int(x) - 1 for x in nums.split()]
                  for nums in pdos_indices.split(',')]
        dos_sum = [dos[:, indices].sum(axis=1) for indices in groups]
        # The total over all columns is always appended as the last curve.
        dos_sum.append(dos.sum(axis=1))
        dos = np.transpose(dos_sum)

    # Compare against None so that an explicit factor of 0 is honoured.
    if dos_factor is not None:
        dos = dos * dos_factor

    return frequencies, dos

def get_options(argv=None):
    """Parse the bandplot command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the arguments from ``sys.argv``,
            which is the behaviour of the original zero-argument call.

    Returns:
        The ``argparse.Namespace`` with the parsed options. Positional
        arguments are collected in ``filenames``.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description="Phonopy bandplot command-line-tool")
    # NOTE: the attribute name ``is_vertial_line`` is misspelled but is kept
    # as is because main() reads the option under this name.
    parser.set_defaults(band_labels=None,
                        dos=None,
                        dos_max=None,
                        dos_min=None,
                        dos_factor=None,
                        factor=1.0,
                        f_max=None,
                        f_min=None,
                        is_gnuplot=False,
                        is_points=False,
                        is_vertial_line=False,
                        output_filename=None,
                        pdos_indices=None,
                        xlabel=None,
                        ylabel=None,
                        show_legend=False,
                        title=None)
    parser.add_argument(
        "--band-labels", dest="band_labels",
        help="Show labels at band segments")
    parser.add_argument(
        "--dmax", dest="dos_max", type=float,
        help="Maximum DOS plotted")
    parser.add_argument(
        "--dmin", dest="dos_min", type=float,
        help="Minimum DOS plotted")
    parser.add_argument(
        "--dos", dest="dos",
        help="Read dos.dat type file and plot with band structure")
    parser.add_argument(
        "--dos-factor", dest="dos_factor", type=float,
        help="Factor to be multiplied with DOS.")
    parser.add_argument(
        "--factor", dest="factor", type=float,
        help="Conversion factor to favorite frequency unit")
    parser.add_argument(
        "--fmax", dest="f_max", type=float,
        help="Maximum frequency plotted")
    parser.add_argument(
        "--fmin", dest="f_min", type=float,
        help="Minimum frequency plotted")
    parser.add_argument(
        "--gnuplot", dest="is_gnuplot", action="store_true",
        help="Output in gnuplot data style")
    parser.add_argument(
        "-i", "--indices", dest="pdos_indices",
        help="Indices like 1 2, 3 4 5 6...")
    parser.add_argument(
        "--legend", dest="show_legend", action="store_true",
        help="Show legend")
    parser.add_argument(
        "--line", "-l", dest="is_vertial_line", action="store_true",
        help="Vertical line is drawn between paths")
    parser.add_argument(
        "-o", "--output", dest="output_filename", action="store",
        help="Output filename of PDF plot")
    parser.add_argument(
        "--xlabel", dest="xlabel",
        help="Specify x-label")
    parser.add_argument(
        "--ylabel", dest="ylabel",
        help="Specify y-label")
    parser.add_argument(
        "--points", dest="is_points", action="store_true",
        help="Draw points")
    parser.add_argument(
        "-t", "--title", dest="title",
        help="Title of plot")
    parser.add_argument(
        "filenames", nargs='*',
        help="Filenames of phonon band structure result (band.yaml)")
    # parse_args(None) falls back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args

def main(args):
    """Plot phonon band structures, or print them as gnuplot-style text.

    The band files named in ``args.filenames`` (``band.yaml`` when none is
    given) are read with ``read_band_yaml``. With ``args.is_gnuplot`` the
    distance/frequency pairs are printed to standard output; otherwise the
    bands are drawn with matplotlib and either saved to
    ``args.output_filename`` or shown in a window. When ``args.dos`` is
    set, the DOS file is read with ``read_dos_dat`` and drawn in a second
    panel sharing the frequency axis.

    Tick labels at segment boundaries and the optional vertical lines are
    only drawn when exactly one band file is plotted.

    Args:
        args: Option namespace as returned by ``get_options``.

    Raises:
        SystemExit: With status 1 when the number of ``--band-labels``
            entries differs from the number of segment boundaries.
    """
    if args.output_filename:
        # Choose the Agg backend before pyplot is imported because the
        # figure is written to a file rather than displayed.
        import matplotlib
        matplotlib.use('Agg')

    if not args.is_gnuplot:
        import matplotlib.pyplot as plt
        if args.band_labels:
            # User supplied labels are rendered with usetex enabled.
            from matplotlib import rc
            rc('text', usetex=True)
        if args.dos:
            # Two panels side by side: bands (wide) and DOS (narrow).
            import matplotlib.gridspec as gridspec
            plt.figure(figsize=(10, 6))
            gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1])
            ax1 = plt.subplot(gs[0, 0])
            ax1.xaxis.set_ticks_position('both')
            ax1.yaxis.set_ticks_position('both')
            ax1.xaxis.set_tick_params(which='both', direction='in')
            ax1.yaxis.set_tick_params(which='both', direction='in')
        else:
            fig, ax = plt.subplots()
            ax.xaxis.set_ticks_position('both')
            ax.yaxis.set_ticks_position('both')
            ax.xaxis.set_tick_params(which='both', direction='in')
            ax.yaxis.set_tick_params(which='both', direction='in')

    # One line style per band file.
    colors = ['r-', 'b-', 'g-', 'c-', 'm-', 'y-', 'k-',
              'r--', 'b--', 'g--', 'c--', 'm--', 'y--', 'k--']
    if args.is_points:
        colors = [x + 'o' for x in colors]

    if len(args.filenames) == 0:
        filenames = ['band.yaml']
    else:
        filenames = args.filenames

    if args.is_gnuplot:
        print("# distance  frequency (bands are separated by blank lines)")

    if args.dos:
        dos_frequencies, dos = read_dos_dat(args.dos,
                                            pdos_indices=args.pdos_indices,
                                            dos_factor=args.dos_factor)

    curves = []
    for i, filename in enumerate(filenames):
        (distances,
         frequencies,
         segment_nqpoint,
         labels) = read_band_yaml(filename)

        # Cycle through the styles so that more files than styles do not
        # raise IndexError.
        style = colors[i % len(colors)]

        # Indices of the q-points at the segment boundaries. The cumulative
        # count after the last segment is one past the end, hence the -1.
        end_points = [0,]
        for nq in segment_nqpoint:
            end_points.append(nq + end_points[-1])
        end_points[-1] -= 1
        segment_positions = distances[end_points]

        if all(x is None for x in labels):
            labels_at_ends = None
        else:
            labels_at_ends = [labels[n] for n in end_points]

        if args.is_gnuplot:
            print("# End points of segments: ")
            print("#   " + "%10.8f " * len(segment_positions) %
                  tuple(segment_positions))
        elif args.is_vertial_line and len(filenames) == 1:
            # Only the interior boundaries get a vertical line.
            for v in segment_positions[1:-1]:
                plt.axvline(x=v, linewidth=0.5, color='b')

        if args.is_gnuplot:
            # One block per band, segments separated by a blank line.
            for j, freqs in enumerate(frequencies.T):
                q = 0
                for nq in segment_nqpoint:
                    for d, f in zip(distances[q:(q + nq)],
                                    freqs[q:(q + nq)] * args.factor):
                        print("%f %f" % (d, f))
                    q += nq
                    print('')
                print('')
        else:
            q = 0
            for j, nq in enumerate(segment_nqpoint):
                if j == 0:
                    # Only the first line of the first segment carries the
                    # legend entry for this file.
                    curves.append(
                        plt.plot(distances[q:(q + nq)],
                                 frequencies[q:(q + nq)] * args.factor,
                                 style,
                                 label=filename)[0])
                else:
                    plt.plot(distances[q:(q + nq)],
                             frequencies[q:(q + nq)] * args.factor,
                             style)
                q += nq

        if args.is_gnuplot:
            print('')


    if not args.is_gnuplot:
        if args.xlabel is None:
            plt.xlabel('Wave vector')
        else:
            plt.xlabel(args.xlabel)
        if args.ylabel is None:
            plt.ylabel('Frequency')
        else:
            plt.ylabel(args.ylabel)

        # The x-range is taken from the band file that was read last.
        plt.xlim(distances[0], distances[-1])
        if args.f_max is not None:
            plt.ylim(top=args.f_max)
        if args.f_min is not None:
            plt.ylim(bottom=args.f_min)
        plt.axhline(y=0, linestyle=':', linewidth=0.5, color='b')
        if len(filenames) == 1:
            xticks = segment_positions
            if args.band_labels:
                band_labels = [x for x in args.band_labels.split()]
                if len(band_labels) == len(xticks):
                    plt.xticks(xticks, band_labels)
                else:
                    print("Numbers of labels and band segments don't match.")
                    sys.exit(1)
            elif labels_at_ends:
                plt.xticks(xticks, labels_at_ends)
            else:
                plt.xticks(xticks, [''] * len(xticks))
        else:
            plt.xticks([])

        if args.title is not None:
            plt.title(args.title)

        if args.show_legend:
            plt.legend(handles=curves)

        if args.dos:
            # arg_fmax: index of the first DOS frequency above f_max
            # (exclusive upper slice bound).
            arg_fmax = len(dos_frequencies)
            if args.f_max is not None:
                for i, f in enumerate(dos_frequencies):
                    if f > args.f_max:
                        arg_fmax = i
                        break
            # arg_fmin: index just before the first DOS frequency above
            # f_min (inclusive lower slice bound).
            arg_fmin = 0
            if args.f_min is not None:
                for i, f in enumerate(dos_frequencies):
                    if f > args.f_min:
                        if i > 0:
                            arg_fmin = i - 1
                        break
            ax2 = plt.subplot(gs[0, 1], sharey=ax1)
            ax2.xaxis.set_ticks_position('both')
            ax2.yaxis.set_ticks_position('both')
            ax2.xaxis.set_tick_params(which='both', direction='in')
            ax2.yaxis.set_tick_params(which='both', direction='in')

            plt.subplots_adjust(wspace=0.03)
            plt.setp(ax2.get_yticklabels(), visible=False)

            for pdos in dos.T:
                if args.dos_max:
                    # _plot_dos returns rows (frequency, dos) clipped at
                    # dos_max; DOS goes on the x-axis.
                    _pdos = _plot_dos(pdos[arg_fmin:arg_fmax],
                                      dos_frequencies[arg_fmin:arg_fmax],
                                      args.dos_max)
                    plt.plot(_pdos[1], _pdos[0])
                else:
                    plt.plot(pdos[arg_fmin:arg_fmax],
                             dos_frequencies[arg_fmin:arg_fmax])
            plt.xlabel('DOS')

            ax2.set_xlim((0, None))

            if args.f_max is not None:
                plt.ylim(top=args.f_max)
            if args.f_min is not None:
                plt.ylim(bottom=args.f_min)
            if args.dos_max is not None:
                plt.xlim(right=args.dos_max)
            if args.dos_min is not None:
                plt.xlim(left=args.dos_min)

        if args.output_filename is not None:
            plt.rcParams['pdf.fonttype'] = 42
            plt.rcParams['font.family'] = 'serif'
            plt.savefig(args.output_filename)
        else:
            plt.show()

def _plot_dos(d, f, dmax):
    """Build the polyline of a DOS curve clipped at ``dmax``.

    Each pair of consecutive samples is passed to ``_cut_dos`` and the
    resulting points are concatenated.

    Args:
        d: Sequence of DOS values.
        f: Sequence of frequencies, same length as ``d``.
        dmax: Upper limit applied to the DOS values.

    Returns:
        Array with two rows: row 0 holds the frequencies and row 1 the
        clipped DOS values. With fewer than two samples there is no
        segment to cut and an empty array of shape ``(2, 0)`` is returned,
        so that callers can still index both rows.
    """
    points = []
    for f1, f2, d1, d2 in zip(f[:-1], f[1:], d[:-1], d[1:]):
        points.extend(_cut_dos([f1, d1], [f2, d2], dmax))
    if not points:
        # np.transpose([]) would be 1-D and break row indexing.
        return np.empty((2, 0))
    return np.transpose(points)

def _cut_dos(p1, p2, dmax):
    """Cut the DOS line segment between two points at dmax.

    ** Cutting by fmin and fmax should be implemented someday. **

    Args:
        p1: (f1, d1), frequency and DOS of the first point.
        p2: (f2, d2), frequency and DOS of the second point.
        dmax: Upper limit of the DOS.

    Returns:
        List of two or three (f, d) points describing the clipped segment:

        When d1 <= dmax and d2 <= dmax
            p1: (f1, d1)
            p2: (f2, d2)
        When d1 <= dmax and d2 > dmax
            p1: (f1, d1)
            pi: (fi, dmax), f1 <= fi < f2
            p2: (f2, dmax)
        When d1 > dmax and d2 <= dmax
            p1: (f1, dmax)
            pi: (fi, dmax), f1 < fi <= f2
            p2: (f2, d2)
        When d1 > dmax and d2 > dmax
            p1: (f1, dmax)
            p2: (f2, dmax)

        fi is the frequency at which the straight line through p1 and p2
        reaches dmax.
    """
    f1, d1 = p1[0], p1[1]
    f2, d2 = p2[0], p2[1]

    # Nothing exceeds the limit (values equal to dmax need no clipping).
    if d1 <= dmax and d2 <= dmax:
        return [p1, p2]

    # Both ends exceed the limit: flat line at dmax from f1 to f2.
    if d1 > dmax and d2 > dmax:
        return [[f1, dmax], [f2, dmax]]

    # Exactly one end exceeds the limit, so d1 != d2 and the linear
    # interpolation below cannot divide by zero.
    fi = (dmax - d1) / (d2 - d1) * (f2 - f1) + f1
    if d1 <= dmax:
        return [p1, [fi, dmax], [f2, dmax]]
    return [[f1, dmax], [fi, dmax], p2]

if __name__ == "__main__":
    # Script entry point: parse the command line, then run the plotter.
    options = get_options()
    main(options)
