#!/usr/bin/env python

# Copyright (C) 2012 Atsushi Togo
# All rights reserved.
#
# This file is part of phonopy.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in
#   the documentation and/or other materials provided with the
#   distribution.
#
# * Neither the name of the phonopy project nor the names of its
#   contributors may be used to endorse or promote products derived
#   from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import sys
import numpy as np
from phonopy import Phonopy
from phonopy import PhonopyGruneisen
from phonopy.interface import (read_crystal_structure,
                               get_default_cell_filename,
                               get_default_physical_units)
from phonopy.file_IO import (parse_FORCE_SETS, parse_BORN,
                             parse_FORCE_CONSTANTS, read_force_constants_hdf5)
from phonopy.units import VaspToTHz

def fracval(frac):
    """Convert a numeric string, optionally written as "a/b", to a float.

    When the string contains a slash, the first two slash-separated
    fields are taken as numerator and denominator; any further fields
    are ignored.
    """
    if '/' in frac:
        fields = frac.split('/')
        numerator = float(fields[0])
        denominator = float(fields[1])
        return numerator / denominator
    return float(frac)

def get_cell(cell_filename, interface_mode, yaml_mode=False):
    """Read a crystal structure, terminating the program if it is missing.

    The three arguments are forwarded as keywords to
    ``read_crystal_structure``.  When that call yields ``None`` for the
    cell, the first item of the accompanying file information is printed
    in an error message and the process exits with status 1.  Otherwise
    the cell is returned.
    """
    cell, file_info = read_crystal_structure(filename=cell_filename,
                                             interface_mode=interface_mode,
                                             yaml_mode=yaml_mode)
    if cell is not None:
        return cell

    message = "Crystal structure file of %s could not be found." % file_info[0]
    print(message)
    sys.exit(1)

def get_phonon(cell_filename,
               force_sets_filename,
               dim,
               primitive_matrix=np.eye(3),
               born_filename=None,
               physical_units=None,
               interface_mode='vasp',
               yaml_mode=False,
               symprec=1e-5):
    """Build a Phonopy object from a cell file and a force-sets file.

    Parameters
    ----------
    cell_filename : str
        Crystal structure file, read through ``get_cell``.
    force_sets_filename : str
        File handed to ``parse_FORCE_SETS``.
    dim : array_like
        Supercell matrix passed to ``Phonopy``.
    primitive_matrix : array_like, optional
        Passed to ``Phonopy`` as ``primitive_matrix``.
    born_filename : str, optional
        When given, non-analytical-term parameters are read from this
        file with ``parse_BORN`` and set on the phonon object.
    physical_units : dict, optional
        Mapping providing the keys ``'factor'`` and ``'nac_factor'``.
        When ``None``, the defaults for ``interface_mode`` are obtained
        from ``get_default_physical_units``.
    interface_mode : str, optional
        Calculator interface name used to read the cell file.
    yaml_mode : bool, optional
        Forwarded to ``get_cell``.
    symprec : float, optional
        Symmetry tolerance passed to ``Phonopy`` and ``parse_BORN``.

    Returns
    -------
    Phonopy
        Object with force constants produced and the dynamical matrix
        set.
    """
    # The signature default is None, which cannot be subscripted below;
    # fall back to the interface defaults instead of raising TypeError.
    if physical_units is None:
        physical_units = get_default_physical_units(interface_mode)

    cell = get_cell(cell_filename, interface_mode, yaml_mode=yaml_mode)
    phonon = Phonopy(cell,
                     dim,
                     primitive_matrix=primitive_matrix,
                     factor=physical_units['factor'],
                     symprec=symprec)
    force_sets = parse_FORCE_SETS(filename=force_sets_filename)
    phonon.set_displacement_dataset(force_sets)
    phonon.produce_force_constants()
    if born_filename:
        nac_params = parse_BORN(phonon.get_primitive(),
                                filename=born_filename,
                                symprec=symprec)
        # Use the interface default only when the BORN data carries no
        # factor of its own.
        if nac_params['factor'] is None:
            nac_params['factor'] = physical_units['nac_factor']
        phonon.set_nac_params(nac_params)
    phonon.set_dynamical_matrix()
    return phonon

def get_phonon_from_force_constants(cell_filename,
                                    force_constants_filename,
                                    dim,
                                    primitive_matrix=np.eye(3),
                                    born_filename=None,
                                    physical_units=None,
                                    interface_mode='vasp',
                                    symprec=1e-5,
                                    is_hdf5=False,
                                    yaml_mode=False):
    """Build a Phonopy object from a cell file and stored force constants.

    Parameters
    ----------
    cell_filename : str
        Crystal structure file, read through ``get_cell``.
    force_constants_filename : str
        Force constants file.  It is read with
        ``read_force_constants_hdf5`` when ``is_hdf5`` is true and with
        ``parse_FORCE_CONSTANTS`` otherwise.
    dim : array_like
        Supercell matrix passed to ``Phonopy``.
    primitive_matrix : array_like, optional
        Passed to ``Phonopy`` as ``primitive_matrix``.
    born_filename : str, optional
        When given, non-analytical-term parameters are read from this
        file with ``parse_BORN`` and set on the phonon object.
    physical_units : dict, optional
        Mapping providing the keys ``'factor'`` and ``'nac_factor'``.
        When ``None``, the defaults for ``interface_mode`` are obtained
        from ``get_default_physical_units``.
    interface_mode : str, optional
        Calculator interface name used to read the cell file.
    symprec : float, optional
        Symmetry tolerance passed to ``Phonopy`` and ``parse_BORN``.
    is_hdf5 : bool, optional
        Selects the hdf5 reader for the force constants.
    yaml_mode : bool, optional
        Forwarded to ``get_cell``.  Defaults to ``False``, which is what
        ``get_cell`` uses when the argument is omitted.

    Returns
    -------
    Phonopy
        Object with force constants and the dynamical matrix set.
    """
    # The signature default is None, which cannot be subscripted below;
    # fall back to the interface defaults instead of raising TypeError.
    if physical_units is None:
        physical_units = get_default_physical_units(interface_mode)

    cell = get_cell(cell_filename, interface_mode, yaml_mode=yaml_mode)
    phonon = Phonopy(cell,
                     dim,
                     primitive_matrix=primitive_matrix,
                     factor=physical_units['factor'],
                     symprec=symprec)
    if is_hdf5:
        force_constants = read_force_constants_hdf5(
            filename=force_constants_filename)
    else:
        force_constants = parse_FORCE_CONSTANTS(
            filename=force_constants_filename)
    phonon.set_force_constants(force_constants)
    if born_filename:
        nac_params = parse_BORN(phonon.get_primitive(),
                                filename=born_filename,
                                symprec=symprec)
        # Identity test: None is a singleton, and '==' could be
        # overridden by the stored value's type.
        if nac_params['factor'] is None:
            nac_params['factor'] = physical_units['nac_factor']
        phonon.set_nac_params(nac_params)
    phonon.set_dynamical_matrix()
    return phonon

def get_options(argv=None):
    """Parse the command line of the phonopy Gruneisen tool.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse.  When ``None`` (the default), argparse
        reads ``sys.argv[1:]``, so a call without arguments behaves as a
        normal command-line invocation.

    Returns
    -------
    argparse.Namespace
        Parsed options.  Every destination has an explicit default in the
        ``set_defaults`` table below; ``dirnames`` is the list of
        positional arguments (empty when none are given).
    """
    import argparse
    parser = argparse.ArgumentParser(
        description="Phonopy gruneisen command-line-tool")
    # Explicit defaults for every option destination.  elk_mode and
    # output_filename carry the same values argparse would assign
    # implicitly; they are listed so the table is complete.
    parser.set_defaults(
        abinit_mode=False,
        band_paths=None,
        cell_filename=None,
        color_scheme=None,
        crystal_mode=False,
        cutoff_frequency=None,
        cutoff_wave_vector=1e-4,
        elk_mode=False,
        factor=None,
        is_gamma_center=False,
        is_hdf5=False,
        is_nac=False,
        is_mesh_symmetry=True,
        output_filename=None,
        plot_graph=False,
        reads_force_constants=False,
        band_points=51,
        marker='o',
        markersize=None,
        primitive_axis=None,
        qe_mode=False,
        sampling_mesh=None,
        save_graph=False,
        siesta_mode=False,
        supercell_dimension=None,
        symprec=1e-5,
        title=None,
        tmax=2004,
        tmin=0,
        tstep=2,
        volumes_filename=None,
        wien2k_mode=False,
        yaml_mode=False)
    parser.add_argument(
        "--abinit", dest="abinit_mode", action="store_true",
        help="Invoke Abinit mode")
    parser.add_argument(
        "--band", dest="band_paths",
        help="Band paths in reduced coordinates")
    parser.add_argument(
        "--band_points", dest="band_points", type=int,
        help="Number of sampling points in a segment of band path")
    parser.add_argument(
        "-c", "--cell", dest="cell_filename",
        help="Read unit cell", metavar="FILE")
    parser.add_argument(
        "--color", dest="color_scheme",
        help="Color scheme")
    parser.add_argument(
        "--crystal", dest="crystal_mode", action="store_true",
        help="Invoke CRYSTAL mode")
    parser.add_argument(
        "--cutoff", dest="cutoff_frequency", type=float,
        help="Plot above this cutoff frequency for mesh sampling mode.")
    parser.add_argument(
        "--dim", dest="supercell_dimension",
        help="Same behavior as DIM tag")
    parser.add_argument(
        "--elk", dest="elk_mode", action="store_true",
        help="Invoke Elk mode")
    parser.add_argument(
        "--factor", dest="factor", type=float,
        help="Conversion factor to favorite frequency unit")
    parser.add_argument(
        "--hdf5", dest="is_hdf5", action="store_true",
        help="Use hdf5 to read force constants and store results")
    parser.add_argument(
        "--gc", "--gamma_center", dest="is_gamma_center", action="store_true",
        help="Set mesh as Gamma center")
    parser.add_argument(
        "--marker", dest="marker",
        help="Marker for plot (matplotlib)")
    parser.add_argument(
        "--markersize", dest="markersize", type=float,
        help="Markersize for plot in points (matplotlib)")
    parser.add_argument(
        "--mp", "--mesh", dest="sampling_mesh",
        help="Sampling mesh")
    parser.add_argument(
        "--nac", dest="is_nac", action="store_true",
        help="Non-analytical term correction")
    parser.add_argument(
        "--nomeshsym", dest="is_mesh_symmetry", action="store_false",
        help="Symmetry is not imposed for mesh sampling.")
    parser.add_argument(
        "-o", "--output", dest="output_filename",
        help="Output filename of PDF plot")
    parser.add_argument(
        "-p", "--plot", dest="plot_graph", action="store_true",
        help="Plot data")
    parser.add_argument(
        "--pa", "--primitive_axis", dest="primitive_axis",
        help="Same as PRIMITIVE_AXIS tags")
    parser.add_argument(
        "--qe", "--pwscf", dest="qe_mode", action="store_true",
        help="Invoke Quantum espresso mode")
    parser.add_argument(
        "--q_cutoff", dest="cutoff_wave_vector", type=float,
        help="Acoustic modes inside cutoff wave vector is treated.")
    parser.add_argument(
        "--readfc", dest="reads_force_constants", action="store_true",
        help="Read FORCE_CONSTANTS")
    parser.add_argument(
        "-s", "--save", dest="save_graph", action="store_true",
        help="Save plot data in pdf")
    parser.add_argument(
        "--siesta", dest="siesta_mode", action="store_true",
        help="Invoke Siesta mode")
    parser.add_argument(
        "-t", "--title", dest="title",
        help="Title of plot")
    parser.add_argument(
        "--tmax", dest="tmax", type=float,
        help="Maximum calculated temperature")
    parser.add_argument(
        "--tmin", dest="tmin", type=float,
        help="Minimum calculated temperature")
    parser.add_argument(
        "--tolerance", dest="symprec", type=float,
        help="Symmetry tolerance to search")
    parser.add_argument(
        "--tstep", dest="tstep", type=float,
        help="Calculated temperature step")
    parser.add_argument(
        "--vf", "--volumes_filename", dest="volumes_filename",
        help="Filename of volume is contained.")
    parser.add_argument(
        "--wien2k", dest="wien2k_mode", action="store_true",
        help="Invoke Wien2k mode")
    parser.add_argument(
        "--yaml", dest="yaml_mode", action="store_true",
        help="Activate phonopy YAML mode")
    parser.add_argument(
        "dirnames", nargs='*',
        help=("Directory names of phonons with three different volumes "
              "(0, +, -)"))
    return parser.parse_args(argv)

def _parse_supercell_matrix(dim_string):
    """Convert the --dim string into a 3x3 integer supercell matrix.

    Three integers give a diagonal matrix and nine integers a full
    matrix.  The program is terminated with status 1 for any other count
    or when the determinant is smaller than one.
    """
    numbers = [int(x) for x in dim_string.split()]
    if len(numbers) == 9:
        matrix = np.array(numbers).reshape(3, 3)
    elif len(numbers) == 3:
        matrix = np.diag(numbers)
    else:
        print("Number of elements of --dim option has to be 3 or 9.")
        sys.exit(1)
    # The determinant is evaluated in floating point; round it so that a
    # valid integer matrix is not rejected because of a result such as
    # 0.9999999999999999.
    if np.rint(np.linalg.det(matrix)) < 1:
        print("Determinant of supercell matrix has to be positive.")
        sys.exit(1)
    return matrix


def _interface_mode_from_args(args):
    """Return the interface name selected by the calculator mode flags.

    The flags are checked in a fixed priority order and the first one
    that is set wins; 'vasp' is returned when none is set.
    """
    for flag, mode in (('wien2k_mode', 'wien2k'),
                       ('abinit_mode', 'abinit'),
                       ('qe_mode', 'qe'),
                       ('elk_mode', 'elk'),
                       ('siesta_mode', 'siesta'),
                       ('crystal_mode', 'crystal')):
        if getattr(args, flag):
            return mode
    return 'vasp'


def _read_volumes(filename):
    """Read the first column of a text file as a list of floats.

    Blank lines and lines whose first non-blank character is '#' are
    skipped.
    """
    volumes = []
    with open(filename) as f:
        for line in f:
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            volumes.append(float(stripped.split()[0]))
    return volumes


def main(args):
    """Run the Gruneisen workflow described by the parsed options.

    Three phonon objects are built from the three directories in
    ``args.dirnames`` (in the order equilibrium, plus, minus) and passed
    to ``PhonopyGruneisen``.  Depending on the options, a band structure
    (``--band``) or a mesh sampling (``--mesh``) is computed and written;
    in mesh mode, thermal properties are additionally computed when a
    volumes file is given.  With ``-p`` the result is shown, or saved
    when ``-s`` is also given.

    Invalid input terminates the process through ``sys.exit(1)``.
    """
    import warnings
    # Every warning raised from here on is turned into an exception.
    warnings.filterwarnings('error')

    if args.is_hdf5:
        try:
            import h5py  # noqa: F401  (imported only to check availability)
        except ImportError:
            print("You need to install python-h5py.")
            sys.exit(1)

    if len(args.dirnames) != 3:
        sys.stderr.write("Three directory names (original, plus, minus) "
                         "have to be specified.\n")
        sys.exit(1)

    if args.primitive_axis:
        primitive_axis = np.array(
            [fracval(x) for x in args.primitive_axis.split()]).reshape(3, 3)
    else:
        primitive_axis = np.eye(3)

    if args.supercell_dimension:
        dim = _parse_supercell_matrix(args.supercell_dimension)
    else:
        print("--dim option has to be specified.")
        sys.exit(1)

    # Plotting needs data from either the band or the mesh calculation;
    # reject the request before doing any expensive work.
    if args.plot_graph and not (args.band_paths or args.sampling_mesh):
        sys.stderr.write("--band or --mesh option has to be specified "
                         "to plot.\n")
        sys.exit(1)

    #
    # Phonopy gruneisen interface mode
    #
    interface_mode = _interface_mode_from_args(args)
    physical_units = get_default_physical_units(interface_mode)

    if args.factor is not None:
        physical_units['factor'] = args.factor

    if args.cell_filename:
        cell_filename = args.cell_filename
    else:
        cell_filename = get_default_cell_filename(interface_mode)

    phonons = []
    for directory in args.dirnames:
        born_filename = "%s/BORN" % directory if args.is_nac else None
        cell_path = "%s/%s" % (directory, cell_filename)

        if args.reads_force_constants:
            if args.is_hdf5:
                fc_filename = "%s/force_constants.hdf5" % directory
            else:
                fc_filename = "%s/FORCE_CONSTANTS" % directory
            phonon = get_phonon_from_force_constants(
                cell_path,
                fc_filename,
                dim,
                primitive_matrix=primitive_axis,
                born_filename=born_filename,
                physical_units=physical_units,
                interface_mode=interface_mode,
                symprec=args.symprec,
                is_hdf5=args.is_hdf5)
        else:
            phonon = get_phonon(cell_path,
                                "%s/FORCE_SETS" % directory,
                                dim,
                                primitive_matrix=primitive_axis,
                                born_filename=born_filename,
                                physical_units=physical_units,
                                interface_mode=interface_mode,
                                yaml_mode=args.yaml_mode,
                                symprec=args.symprec)
        phonons.append(phonon)

    gruneisen = PhonopyGruneisen(phonons[0],  # equilibrium
                                 phonons[1],  # plus
                                 phonons[2])  # minus

    if args.plot_graph and args.save_graph:
        # Select the non-interactive backend before any figure is made.
        import matplotlib as mpl
        mpl.use('Agg')

    plt = None
    if args.band_paths:
        from phonopy.phonon.band_structure import get_band_qpoints

        band_paths = []
        for path_str in args.band_paths.split(','):
            paths = np.array([fracval(x) for x in path_str.split()])
            # Each path needs at least two q-points of three components.
            if len(paths) % 3 != 0 or len(paths) < 6:
                print("Band path is incorrectly set.")
                sys.exit(1)
            band_paths.append(paths.reshape(-1, 3))
        bands = get_band_qpoints(band_paths, args.band_points)
        gruneisen.set_band_structure(bands)
        gruneisen.write_yaml_band_structure()
        if args.plot_graph:
            plt = gruneisen.plot_band_structure(
                epsilon=args.cutoff_wave_vector,
                color_scheme=args.color_scheme)
            if args.title is not None:
                plt.suptitle(args.title)

    elif args.sampling_mesh:
        mesh_numbers = np.array([int(x) for x in args.sampling_mesh.split()])
        gruneisen.set_mesh(mesh_numbers,
                           is_gamma_center=args.is_gamma_center,
                           is_mesh_symmetry=args.is_mesh_symmetry)

        if args.is_hdf5:
            gruneisen.write_hdf5_mesh()
        else:
            gruneisen.write_yaml_mesh()

        if args.plot_graph:
            plt = gruneisen.plot_mesh(cutoff_frequency=args.cutoff_frequency,
                                      color_scheme=args.color_scheme,
                                      marker=args.marker,
                                      markersize=args.markersize)
            if args.title is not None:
                plt.suptitle(args.title)

        if args.volumes_filename is not None:
            volumes = _read_volumes(args.volumes_filename)
            if volumes:
                gruneisen.set_thermal_properties(
                    volumes,
                    t_step=args.tstep,
                    t_max=args.tmax,
                    t_min=args.tmin,
                    cutoff_frequency=args.cutoff_frequency)
                gruneisen.write_yaml_thermal_properties()

    if args.plot_graph and plt is not None:
        if args.save_graph:
            if args.output_filename:
                plt.savefig(args.output_filename)
            else:
                plt.savefig("gruneisen.pdf")
        else:
            plt.show()

# Script entry point: parse the command-line options and hand them to main().
if __name__ == "__main__":
    main(get_options())
