#!/home/travis/miniconda/bin/python

# Copyright (C) 2012 Atsushi Togo
# All rights reserved.
#
# This file is part of phonopy.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in
#   the documentation and/or other materials provided with the
#   distribution.
#
# * Neither the name of the phonopy project nor the names of its
#   contributors may be used to endorse or promote products derived
#   from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import sys
from phonopy import Phonopy
from phonopy import PhonopyGruneisen
from phonopy.interface import read_crystal_structure, get_default_cell_filename
from phonopy.file_IO import (parse_FORCE_SETS, parse_BORN,
                             parse_FORCE_CONSTANTS, read_force_constants_hdf5)
from phonopy.units import VaspToTHz, Wien2kToTHz, AbinitToTHz, PwscfToTHz
import numpy as np

def fracval(frac):
    """Convert a number given as a decimal or as a fraction to a float.

    ``frac`` is a string such as ``"0.5"`` or ``"1/2"``. Without a slash
    the whole string is passed to ``float``; otherwise the text before the
    first slash is divided by the text between the first and second slash
    (anything after a second slash is ignored).
    """
    pieces = frac.split('/')
    if len(pieces) == 1:
        return float(frac)
    numerator, denominator = pieces[0], pieces[1]
    return float(numerator) / float(denominator)

def get_cell(cell_filename,
             interface_mode):
    """Read a crystal structure, or terminate the script if it is missing.

    Calls ``read_crystal_structure`` with the given file name and interface
    mode, plus the ``yaml_mode`` flag of the module-level ``options``.
    Returns the cell it yields. When that cell is ``None``, a message naming
    the first item of the accompanying file information is printed and the
    process exits with status 1.
    """
    cell, file_info = read_crystal_structure(
        filename=cell_filename,
        interface_mode=interface_mode,
        yaml_mode=options.yaml_mode)
    if cell is not None:
        return cell

    print("Crystal structure file of %s could not be found." % file_info[0])
    sys.exit(1)
        
def get_phonon(cell_filename,
               force_sets_filename,
               dim,
               primitive_matrix=np.eye(3),
               born_filename=None,
               factor=VaspToTHz,
               interface_mode='vasp',
               symprec=1e-5):
    """Build a ``Phonopy`` object from a structure file and a force-sets file.

    The cell is read with ``get_cell`` (which exits the script on failure),
    the displacement dataset is parsed from ``force_sets_filename`` and
    force constants are produced from it. If ``born_filename`` is truthy,
    non-analytical term correction parameters are parsed from that file and
    attached. The dynamical matrix is set up before the object is returned.
    """
    unitcell = get_cell(cell_filename, interface_mode)
    phonon = Phonopy(unitcell,
                     dim,
                     primitive_matrix=primitive_matrix,
                     factor=factor,
                     symprec=symprec)

    phonon.set_displacement_dataset(
        parse_FORCE_SETS(filename=force_sets_filename))
    phonon.produce_force_constants()

    if born_filename:
        phonon.set_nac_params(parse_BORN(phonon.get_primitive(),
                                         filename=born_filename,
                                         symprec=symprec))

    phonon.set_dynamical_matrix()
    return phonon

def get_phonon_from_force_constants(cell_filename,
                                    force_constants_filename,
                                    dim,
                                    primitive_matrix=np.eye(3),
                                    born_filename=None,
                                    factor=VaspToTHz,
                                    interface_mode='vasp',
                                    symprec=1e-5,
                                    is_hdf5=False):
    """Build a ``Phonopy`` object from a structure file and force constants.

    The cell is read with ``get_cell`` (which exits the script on failure).
    Force constants are loaded from ``force_constants_filename`` with the
    hdf5 reader when ``is_hdf5`` is true and with the plain-text
    ``FORCE_CONSTANTS`` parser otherwise. If ``born_filename`` is truthy,
    non-analytical term correction parameters are parsed from that file and
    attached. The dynamical matrix is set up before the object is returned.
    """
    unitcell = get_cell(cell_filename, interface_mode)
    phonon = Phonopy(unitcell,
                     dim,
                     primitive_matrix=primitive_matrix,
                     factor=factor,
                     symprec=symprec)

    # Both readers take the file name through the same keyword.
    reader = read_force_constants_hdf5 if is_hdf5 else parse_FORCE_CONSTANTS
    phonon.set_force_constants(reader(filename=force_constants_filename))

    if born_filename:
        phonon.set_nac_params(parse_BORN(phonon.get_primitive(),
                                         filename=born_filename,
                                         symprec=symprec))

    phonon.set_dynamical_matrix()
    return phonon

# Command-line interface. The options are declared with optparse and parsed
# immediately at module level; the results are the module-level ``options``
# (option values) and ``args`` (positional arguments) used by the code below
# and by get_cell().
from optparse import OptionParser
parser = OptionParser()
# Defaults for the option destinations. ``output_filename`` is not listed
# here; optparse leaves a destination at None when no default is supplied.
# Note that tmax/tmin/tstep default to ints although the corresponding
# options are declared with type="float".
parser.set_defaults(
    abinit_mode=False,
    band_paths=None,
    cell_filename=None,
    color_scheme=None,
    cutoff_frequency=None,
    cutoff_wave_vector=1e-4,
    factor=None,
    is_gamma_center=False,
    is_hdf5=False,
    is_nac=False,
    is_mesh_symmetry=True,
    plot_graph=False,
    reads_force_constants=False,                    
    band_points=51,
    marker='o',
    markersize=None,
    primitive_axis=None,
    pwscf_mode=False,
    sampling_mesh=None,
    save_graph=False,
    supercell_dimension=None,
    symprec=1e-5,
    title=None,
    tmax=2004,
    tmin=0,
    tstep=2,
    volumes_filename=None,
    wien2k_mode=False,
    yaml_mode=False)
# Option declarations. The order below is the order in which optparse lists
# them in --help.
parser.add_option("--abinit", dest="abinit_mode",
                  action="store_true", help="Invoke Abinit mode")
# --band is later split on ',' into separate paths and on whitespace into
# coordinates (three per point).
parser.add_option("--band", dest="band_paths",
                  action="store", type="string",
                  help="Band paths in reduced coordinates")
parser.add_option("--band_points", dest="band_points", type="int",
                   help="Number of sampling points in a segment of band path")
parser.add_option("-c", "--cell", dest="cell_filename",
                  action="store", type="string",
                  help="Read unit cell", metavar="FILE")
parser.add_option("--color", dest="color_scheme",
                  action="store", type="string",
                  help="Color scheme")
parser.add_option("--cutoff", dest="cutoff_frequency", type="float",
                  help="Plot above this cutoff frequency for mesh sampling mode.")
# --dim is later parsed as 3 or 9 whitespace-separated integers.
parser.add_option("--dim", dest="supercell_dimension",
                  action="store", type="string",
                  help="Same behavior as DIM tag")
# When given, --factor overrides the factor chosen from the interface mode.
parser.add_option("--factor", dest="factor", type="float",
                  help="Conversion factor to favorite frequency unit")
parser.add_option("--hdf5", dest="is_hdf5", action="store_true",
                  help="Use hdf5 to read force constants and store results")
parser.add_option("--gc", "--gamma_center", dest="is_gamma_center",
                  action="store_true",
                  help="Set mesh as Gamma center")
parser.add_option("--marker", dest="marker",
                  action="store", type="string",
                  help="Marker for plot (matplotlib)")
parser.add_option("--markersize", dest="markersize", type="float",
                  help="Markersize for plot in points (matplotlib)")
# --mesh is later parsed as whitespace-separated integers.
parser.add_option("--mp", "--mesh", dest="sampling_mesh",
                  action="store", type="string",
                  help="Sampling mesh")
# With --nac, a BORN file is read from each of the three directories.
parser.add_option("--nac", dest="is_nac",
                  action="store_true",
                  help="Non-analytical term correction")
parser.add_option("--nomeshsym", dest="is_mesh_symmetry",
                  action="store_false",
                  help="Symmetry is not imposed for mesh sampling.")
# Used only together with --plot and --save; "gruneisen.pdf" is used when
# no output file name is given.
parser.add_option("-o", "--output", dest="output_filename",
                  action="store", type="string",
                  help="Output filename of PDF plot")
parser.add_option("-p", "--plot", dest="plot_graph",
                  action="store_true",
                  help="Plot data")
# --pa is later parsed as 9 whitespace-separated numbers; fractions such as
# 1/2 are accepted.
parser.add_option("--pa", "--primitive_axis", dest="primitive_axis",
                  action="store", type="string",
                  help="Same as PRIMITIVE_AXIS tags")
parser.add_option("--pwscf", dest="pwscf_mode",
                  action="store_true", help="Invoke Pwscf mode")
parser.add_option("--q_cutoff", dest="cutoff_wave_vector", type="float",
                  help="Acoustic modes inside cutoff wave vector is treated.")
# With --readfc, force constants are read (FORCE_CONSTANTS, or
# force_constants.hdf5 with --hdf5) instead of FORCE_SETS.
parser.add_option("--readfc", dest="reads_force_constants",
                  action="store_true",
                  help="Read FORCE_CONSTANTS")
parser.add_option("-s", "--save", dest="save_graph",
                  action="store_true",
                  help="Save plot data in pdf")
parser.add_option("-t", "--title", dest="title", action="store",
                  type="string", help="Title of plot")
parser.add_option("--tmax", dest="tmax", type="float",
                  help="Maximum calculated temperature")
parser.add_option("--tmin", dest="tmin", type="float",
                  help="Minimum calculated temperature")
parser.add_option("--tolerance", dest="symprec", type="float",
                  help="Symmetry tolerance to search")
parser.add_option("--tstep", dest="tstep", type="float",
                  help="Calculated temperature step")
# The volumes file is read only in mesh mode; the first whitespace-separated
# field of each line is taken as a volume.
parser.add_option("--vf", "--volumes_filename", dest="volumes_filename",
                  action="store", type="string",
                  help="Filename of volume is contained.")
parser.add_option("--wien2k", dest="wien2k_mode",
                  action="store_true", help="Invoke Wien2k mode")
parser.add_option("--yaml", dest="yaml_mode",
                  action="store_true", help="Activate phonopy YAML mode")
(options, args) = parser.parse_args()

# --hdf5 needs the h5py package. Check for it up front so that a missing
# installation is reported before any calculation starts.
if options.is_hdf5:
    try:
        import h5py
    except ImportError:
        # The previous handler called print_error_message(), log_level and
        # print_end(), none of which exist in this script, so it raised
        # NameError instead of printing this message.
        sys.stderr.write("You need to install python-h5py.\n")
        sys.exit(1)

# Exactly three positional arguments are required: the directories of the
# original (equilibrium), plus and minus calculations, in that order.
# Compare by value; ``is not`` tests object identity and is a SyntaxWarning
# with a literal on Python 3.8+.
if len(args) != 3:
    sys.stderr.write("Three directory names (original, plus, minus) "
                     "have to be specified.\n")
    sys.exit(1)

# Primitive axes: nine whitespace-separated numbers (fractions such as 1/2
# are accepted) filled row by row into a 3x3 matrix; identity when --pa is
# not given.
if options.primitive_axis:
    primitive_axis_values = [
        fracval(x) for x in options.primitive_axis.split()]
    if len(primitive_axis_values) != 9:
        # Report the problem instead of failing inside reshape().
        print("Number of elements of --pa option has to be 9.")
        sys.exit(1)
    primitive_axis = np.array(primitive_axis_values).reshape(3, 3)
else:
    primitive_axis = np.eye(3)

# Supercell matrix: --dim is mandatory and holds either three integers (the
# diagonal) or nine integers (the full 3x3 matrix, row by row).
if options.supercell_dimension:
    dim = [int(x) for x in options.supercell_dimension.split()]
    if len(dim) == 9:
        dim = np.array(dim).reshape(3, 3)
    elif len(dim) == 3:
        dim = np.diag(dim)
    else:
        print("Number of elements of --dim option has to be 3 or 9.")
        sys.exit(1)
    # The determinant of an integer matrix is an integer, but
    # np.linalg.det() computes it in floating point. Round before comparing
    # so that a valid matrix with determinant 1 is not rejected because of
    # round-off (e.g. 0.9999999999999998).
    if int(round(np.linalg.det(dim))) < 1:
        print("Determinant of supercell matrix has to be positive.")
        sys.exit(1)
else:
    print("--dim option has to be specified.")
    sys.exit(1)


#
# Phonopy interface mode
#
# Physical units: energy,  distance,  atomic mass, force
# vasp          : eV,      Angstrom,  AMU,         eV/Angstrom
# wien2k        : Ry,      au,        AMU,         mRy/au
# pwscf         : Ry,      au,        AMU,         Ry/au
# abinit        : hartree, au(=bohr), AMU,         eV/Angstrom
#
# Pick the interface mode and its frequency conversion factor. The first
# enabled switch wins, checked in the order wien2k, abinit, pwscf; VASP is
# used when none is given.
interface_mode, factor = 'vasp', VaspToTHz
for enabled, mode_name, unit_factor in (
        (options.wien2k_mode, 'wien2k', Wien2kToTHz),
        (options.abinit_mode, 'abinit', AbinitToTHz),
        (options.pwscf_mode, 'pwscf', PwscfToTHz)):
    if enabled:
        interface_mode, factor = mode_name, unit_factor
        break

# An explicit --factor overrides the factor implied by the interface mode.
factor = factor if options.factor is None else options.factor

# Structure file name: the one given with -c/--cell, otherwise the default
# name for the selected interface mode.
cell_filename = (options.cell_filename or
                 get_default_cell_filename(interface_mode))
    
    
# Build one Phonopy object per input directory, in command-line order
# (original, plus, minus). All three share the same supercell matrix,
# primitive axes, conversion factor, interface mode and symmetry tolerance.
phonons = []
for directory in args[:3]:
    # BORN is read from the same directory only when --nac is given.
    born_filename = ("%s/BORN" % directory) if options.is_nac else None
    structure_path = "%s/%s" % (directory, cell_filename)
    shared_kwargs = dict(primitive_matrix=primitive_axis,
                         born_filename=born_filename,
                         factor=factor,
                         interface_mode=interface_mode,
                         symprec=options.symprec)

    if not options.reads_force_constants:
        # Default: derive force constants from FORCE_SETS.
        phonon = get_phonon(structure_path,
                            "%s/FORCE_SETS" % directory,
                            dim,
                            **shared_kwargs)
    elif options.is_hdf5:
        # --readfc with --hdf5: read force_constants.hdf5.
        phonon = get_phonon_from_force_constants(
            structure_path,
            "%s/force_constants.hdf5" % directory,
            dim,
            is_hdf5=True,
            **shared_kwargs)
    else:
        # --readfc alone: read the text FORCE_CONSTANTS file.
        phonon = get_phonon_from_force_constants(
            structure_path,
            "%s/FORCE_CONSTANTS" % directory,
            dim,
            **shared_kwargs)

    phonons.append(phonon)

# The three phonon objects were collected in command-line order.
equilibrium_phonon, plus_phonon, minus_phonon = phonons
gruneisen = PhonopyGruneisen(equilibrium_phonon,
                             plus_phonon,
                             minus_phonon)

# When the plot is to be saved rather than shown, select matplotlib's Agg
# backend before any plotting call is made.
if options.plot_graph and options.save_graph:
    import matplotlib as mpl
    mpl.use('Agg')
        
# Run mode: --band takes precedence over --mesh. When neither is given,
# nothing is calculated here.
if options.band_paths:
    # --band holds comma-separated paths. Each path is a whitespace-separated
    # list of coordinates, three per point (fractions such as 1/2 are
    # accepted). Every pair of consecutive points becomes one segment.
    band_paths = []
    for path_str in options.band_paths.split(','):
        paths = np.array(
            [fracval(x) for x in path_str.split()]).reshape(-1, 3)
        band_paths.extend(
            [start, end] for start, end in zip(paths[:-1], paths[1:]))

    gruneisen.set_band_structure(band_paths,
                                 options.band_points)
    gruneisen.write_yaml_band_structure()
    if options.plot_graph:
        # ``plt`` is kept at module level; the plot is saved or shown at the
        # end of the script.
        plt = gruneisen.plot_band_structure(epsilon=options.cutoff_wave_vector,
                                            color_scheme=options.color_scheme)
        if options.title is not None:
            plt.suptitle(options.title)

elif options.sampling_mesh:
    mesh_numbers = np.array([int(x) for x in options.sampling_mesh.split()])
    gruneisen.set_mesh(mesh_numbers,
                       is_gamma_center=options.is_gamma_center,
                       is_mesh_symmetry=options.is_mesh_symmetry)

    # --hdf5 selects the output format of the mesh results as well.
    if options.is_hdf5:
        gruneisen.write_hdf5_mesh()
    else:
        gruneisen.write_yaml_mesh()

    if options.plot_graph:
        plt = gruneisen.plot_mesh(cutoff_frequency=options.cutoff_frequency,
                                  color_scheme=options.color_scheme,
                                  marker=options.marker,
                                  markersize=options.markersize)
        if options.title is not None:
            plt.suptitle(options.title)

    if options.volumes_filename is not None:
        # The first whitespace-separated field of each line is a volume.
        # Lines starting with '#' are comments. Blank lines are skipped too;
        # indexing the stripped line used to raise IndexError on them.
        volumes = []
        with open(options.volumes_filename) as f:
            for line in f:
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    continue
                volumes.append(float(stripped.split()[0]))
        # Thermal properties are calculated only if at least one volume
        # was read.
        if volumes:
            gruneisen.set_thermal_properties(
                volumes,
                t_step=options.tstep,
                t_max=options.tmax,
                t_min=options.tmin,
                cutoff_frequency=options.cutoff_frequency)
            gruneisen.write_yaml_thermal_properties()

# Save or show the figure prepared by the band-structure or mesh branch.
if options.plot_graph:
    # ``plt`` is bound only when --band or --mesh was processed. Without
    # either there is nothing to plot, so report that instead of failing
    # with NameError on ``plt``.
    if not (options.band_paths or options.sampling_mesh):
        sys.stderr.write(
            "--plot requires either --band or --mesh to be specified.\n")
        sys.exit(1)

    if options.save_graph:
        # Fall back to the default file name when -o/--output is not given.
        plt.savefig(options.output_filename or "gruneisen.pdf")
    else:
        plt.show()
