#!python

# Copyright (C) 2011 Atsushi Togo
# All rights reserved.
#
# This file is part of phonopy.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
#   notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
#   notice, this list of conditions and the following disclaimer in
#   the documentation and/or other materials provided with the
#   distribution.
#
# * Neither the name of the phonopy project nor the names of its
#   contributors may be used to endorse or promote products derived
#   from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

# Displacement manager ( dispmanager )
#
# Usage:
#   dispmanager disp.yaml

import sys
import os
import numpy as np
from phonopy.structure.atoms import PhonopyAtoms as Atoms
from phonopy.interface.vasp import write_vasp
import phonopy.file_IO as file_IO

# PyYAML is needed to parse the displacement file; stop early with a hint
# when it is not installed.
try:
    import yaml
except ImportError:
    print("You need to install python-yaml.")
    # Use sys.exit instead of the bare exit() helper: the latter is injected
    # by the site module for interactive sessions and is not guaranteed to
    # exist (e.g. when the interpreter is started with -S).
    sys.exit(1)

# Use the LibYAML-based loader when PyYAML was built with it; otherwise fall
# back to the pure-Python loader, which exposes the same interface.
try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader


def get_options():
    """Parse the dispmanager command line and return the resulting namespace.

    The arguments are taken from ``sys.argv``.  The returned namespace has
    the attributes ``output_filename``, ``is_overwrite``,
    ``is_create_structure_file``, ``add_disp``, ``select_disp``,
    ``amplitude``, ``is_d2d`` and ``filename`` (a possibly empty list of
    positional file names).
    """
    import argparse

    # Values used for every option that is not given on the command line.
    fallback_values = {
        "output_filename": None,
        "is_overwrite": False,
        "is_create_structure_file": False,
        "is_d2d": False,
        "add_disp": None,
        "select_disp": None,
        "amplitude": 0.01,
    }

    # One entry per argument: (flags, keyword arguments for add_argument).
    # The order here is the order shown in the --help output.
    argument_table = (
        (("-o", "--output"),
         {"dest": "output_filename",
          "help": "Output filename"}),
        (("--overwrite",),
         {"dest": "is_overwrite", "action": "store_true",
          "help": "Overwrite input file"}),
        (("-w",),
         {"dest": "is_create_structure_file", "action": "store_true",
          "help": "Create structure files"}),
        (("-a", "--add"),
         {"dest": "add_disp",
          "help": "Direction of added displacement"}),
        (("-s", "--select"),
         {"dest": "select_disp",
          "help": "Select displacements and write input file"}),
        (("--amplitude",),
         {"dest": "amplitude", "type": float,
          "help": "Amplitude of displacement"}),
        (("--d2d",),
         {"dest": "is_d2d", "action": "store_true",
          "help": "Show the order of calculated files for disp.yaml"}),
        (("filename",),
         {"nargs": '*',
          "help": "Filename of displacement file (disp.yaml)"}),
    )

    cli = argparse.ArgumentParser(
        description="Phonopy dispmanager command-line-tool")
    cli.set_defaults(**fallback_values)
    for flags, spec in argument_table:
        cli.add_argument(*flags, **spec)

    return cli.parse_args()


def _abort(message):
    """Print *message* and terminate the script with exit status 1."""
    print(message)
    sys.exit(1)


def _load_disp_yaml(filename):
    """Return the parsed content of the displacement file *filename*.

    Exits with status 1 when the file does not exist.
    """
    if not os.path.exists(filename):
        _abort("%s could not be found." % filename)

    # NOTE(review): Loader is PyYAML's full loader, which can construct
    # arbitrary Python objects.  Only feed it trusted files; consider
    # SafeLoader if disp.yaml may come from an untrusted source.
    with open(filename) as f:  # closed deterministically, unlike open().read()
        return yaml.load(f, Loader=Loader)


def _parse_added_displacement(spec, cell, num_atoms, amplitude):
    """Turn the ``--add`` string into a ``[atom, dx, dy, dz]`` entry.

    *spec* is "ATOM X Y Z": a 1-based atom number followed by a direction.
    The direction is multiplied by the lattice matrix of *cell* (presumably
    converting fractional to Cartesian coordinates -- confirm against the
    disp.yaml convention), normalised and scaled to *amplitude*.  Tokens
    after the fourth are ignored.  Exits with status 1 on malformed input.
    """
    tokens = spec.split()
    print("%s" % tokens[1:4])

    if len(tokens) < 4:
        _abort("Added displacement must be given as 'ATOM X Y Z'.")
    try:
        atom = int(tokens[0]) - 1
        direction = np.array([float(x) for x in tokens[1:4]])
    except ValueError:
        _abort("Added displacement must be given as 'ATOM X Y Z' "
               "with an integer ATOM and numeric X Y Z.")
    if not 0 <= atom < num_atoms:
        # Without this check atom number 0 would silently become index -1.
        _abort("Atom number %s is out of range (1-%d)."
               % (tokens[0], num_atoms))

    vector = np.dot(direction, cell.get_cell())
    norm = np.linalg.norm(vector)
    if not (np.isfinite(norm) and norm > 0):
        # A null direction cannot be normalised; it used to produce NaNs.
        _abort("Direction of added displacement must be a finite, "
               "non-zero vector.")
    vector = vector / norm * amplitude
    return [atom, vector[0], vector[1], vector[2]]


def _select_displacements(spec, displacements):
    """Return the entries of *displacements* named by the ``--select`` string.

    *spec* holds whitespace-separated 1-based indices.  Exits with status 1
    when an index is not an integer or lies outside ``1..len(displacements)``.
    """
    selected = []
    for token in spec.split():
        try:
            number = int(token)
        except ValueError:
            _abort("Displacement number '%s' is not an integer." % token)
        if not 1 <= number <= len(displacements):
            # Python would accept 0 and negatives as "count from the end".
            _abort("Displacement number %d is out of range (1-%d)."
                   % (number, len(displacements)))
        selected.append(displacements[number - 1])
    return selected


def main(args):
    """Run dispmanager with the parsed command-line options *args*.

    *args* must provide ``filename`` (list), ``is_create_structure_file``,
    ``is_overwrite``, ``output_filename``, ``add_disp``, ``select_disp`` and
    ``amplitude``.  The displacement file is ``args.filename[0]`` or
    ``disp.yaml`` when no name was given.  Exactly one action is performed,
    checked in this order:

    * ``-w``: write one ``DPOSCAR-NNN`` structure file per displacement;
    * ``-a``: append a displacement and write a new displacement file;
    * ``-s``: write a displacement file holding only the selected entries.

    Each action ends with ``sys.exit(0)``; invalid input ends with
    ``sys.exit(1)``.  When no action is requested a message is printed and
    the function returns normally.
    """
    filename = args.filename[0] if len(args.filename) > 0 else 'disp.yaml'
    dataset = _load_disp_yaml(filename)

    # Internal representation: [0-based atom index, dx, dy, dz].
    displacements = []
    for entry in dataset['displacements']:
        d = entry['displacement']
        displacements.append([entry['atom'] - 1, d[0], d[1], d[2]])

    lattice = dataset['lattice']
    scaled_positions = [x['coordinates'] for x in dataset['points']]
    symbols = [x['symbol'] for x in dataset['points']]
    cell = Atoms(cell=lattice,
                 scaled_positions=scaled_positions,
                 symbols=symbols,
                 pbc=True)

    ######################
    # Create DPOSCAR-xxx #
    ######################
    if args.is_create_structure_file:
        for i, displacement in enumerate(displacements):
            # Fetch the positions anew for every file so that each structure
            # carries exactly one displacement.
            positions = cell.get_positions()
            positions[displacement[0]] += displacement[1:4]
            write_vasp("%s-%03d" % ("DPOSCAR", i + 1),
                       Atoms(numbers=cell.get_atomic_numbers(),
                             masses=cell.get_masses(),
                             positions=positions,
                             cell=cell.get_cell(),
                             pbc=True), direct=True)
        sys.exit(0)

    ####################
    # Modify disp.yaml #
    ####################
    if args.is_overwrite:
        output_filename = filename
    else:
        output_filename = args.output_filename

    # Add displacements
    if args.add_disp is not None:
        if output_filename is None:
            _abort("Output filename (-o or --overwrite) is required.")

        displacements.append(
            _parse_added_displacement(args.add_disp,
                                      cell,
                                      len(symbols),
                                      args.amplitude))
        file_IO.write_disp_yaml(displacements,
                                cell,
                                filename=output_filename)
        sys.exit(0)

    # Select displacements
    if args.select_disp is not None:
        if output_filename is None:
            _abort("Output filename (-o or --overwrite) is required.")

        file_IO.write_disp_yaml(
            _select_displacements(args.select_disp, displacements),
            cell,
            filename=output_filename)
        sys.exit(0)

    print("Nothing has been done.")


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the tool.
    options = get_options()
    main(options)
