#!python
import argparse

from ase.db import connect
import numpy as np
from tqdm import tqdm


def convert_legacy_atomrefs(meta):
    """Rewrite the ``atomrefs`` entry of ``meta`` as a dict keyed by label.

    * If ``meta`` has no ``"atomrefs"`` key, an empty dict is stored.
    * If both ``"atomrefs"`` and ``"atref_labels"`` are present, ``atomrefs``
      is read as a 2-D array whose column ``i`` belongs to label ``i``. It is
      replaced by ``{label: column_as_list}`` and ``"atref_labels"`` is
      removed. (Presumably this is the SchNetPack 1.0 layout -- confirm
      against the datasets you convert.)
    * Otherwise ``meta`` is left untouched.

    Args:
        meta: metadata dictionary; modified in place.

    Returns:
        The same ``meta`` object, for convenience.
    """
    if "atomrefs" not in meta:
        meta["atomrefs"] = {}
    elif "atref_labels" in meta:
        old_atref = np.array(meta["atomrefs"])
        labels = meta["atref_labels"]
        # A single label may be stored as a bare string instead of a list.
        if isinstance(labels, str):
            labels = [labels]

        new_atomrefs = {}
        for i, label in enumerate(labels):
            column = old_atref[:, i]
            print(i, label, column)
            # tolist() yields native Python numbers rather than NumPy scalars,
            # so the metadata stays plain-JSON friendly.
            new_atomrefs[label] = column.tolist()

        meta["atomrefs"] = new_atomrefs
        del meta["atref_labels"]
    return meta


def parse_property_units(spec):
    """Parse a ``property1:unit1,property2:unit2`` string into a dict.

    Whitespace around entries, property names and units is ignored, and empty
    entries (e.g. caused by a trailing comma) are skipped.

    Args:
        spec: comma-separated list of ``property:unit`` pairs.

    Returns:
        Dictionary mapping property name to unit string.

    Raises:
        ValueError: if an entry does not contain exactly one ``:``.
    """
    units = {}
    for entry in spec.split(","):
        entry = entry.strip()
        if not entry:
            continue
        parts = entry.split(":")
        if len(parts) != 2:
            raise ValueError(
                "Invalid property unit entry `{}`: expected the form "
                "`property:unit`.".format(entry)
            )
        prop, unit = parts
        units[prop.strip()] = unit.strip()
    return units


def expand_property_dims(data, properties):
    """Return a copy of ``data`` with a new leading axis on selected entries.

    Args:
        data: mapping of property name to array-like value.
        properties: collection of property names to expand.

    Returns:
        New dict with the same keys as ``data``. Values whose key is in
        ``properties`` are passed through ``np.expand_dims(value, 0)``;
        all other values are kept as they are.
    """
    return {
        name: np.expand_dims(value, 0) if name in properties else value
        for name, value in data.items()
    }


def main(argv=None):
    """Command-line entry point: update metadata (and optionally data) of an ASE DB.

    Args:
        argv: argument list to parse; ``None`` lets argparse use ``sys.argv``.

    Raises:
        ValueError: if ``--distunit`` is ``A`` or ``--propunit`` is malformed.
    """
    parser = argparse.ArgumentParser(
        description="Set units of an ASE dataset, e.g. to convert from SchNetPack 1.0 to the new format."
    )
    parser.add_argument(
        "data_path",
        help="Path to ASE DB dataset",
    )
    parser.add_argument(
        "--distunit",
        help="Distance unit as string, corresponding to ASE units (e.g. `Ang`)",
    )
    parser.add_argument(
        "--propunit",
        help="Property units as string, corresponding "
        "to ASE units (e.g. `kcal/mol/Ang`), in the form: `property1:unit1,property2:unit2`",
    )
    parser.add_argument(
        "--expand_property_dims",
        default=[],
        nargs='+',
        help="Expanding the first dimension of the given property "
             "(required for example for old FieldSchNet datasets). "
             "Add property names here in the form 'property1 property2 property3'",
    )
    args = parser.parse_args(argv)

    with connect(args.data_path) as db:
        meta = db.metadata
    print(meta)

    convert_legacy_atomrefs(meta)

    if args.distunit:
        if args.distunit == "A":
            raise ValueError(
                "The provided unit (A for Ampere) is not a valid distance unit according to the ASE unit"
                " definitions. You probably mean `Ang`/`Angstrom`. Please also check your property units!"
            )
        meta["_distance_unit"] = args.distunit

    if args.propunit:
        # Parse first so a malformed argument fails before meta is modified.
        new_units = parse_property_units(args.propunit)
        meta.setdefault("_property_unit_dict", {}).update(new_units)

    with connect(args.data_path) as db:
        db.metadata = meta

    if args.expand_property_dims:
        targets = set(args.expand_property_dims)

        with connect(args.data_path) as db:
            # Use the ids that actually exist: they need not be the contiguous
            # range 1..len(db) once rows have been deleted. Collect them up
            # front so no select is in progress while rows are updated.
            row_ids = [row.id for row in db.select(include_data=False)]
            for row_id in tqdm(row_ids):
                atoms_row = db.get(row_id)
                db.update(
                    row_id, data=expand_property_dims(atoms_row.data, targets)
                )


if __name__ == "__main__":
    main()