#!python
import torch
import torch.nn as nn
from schnetpack.transform import CastTo64, CastTo32, AddOffsets
import argparse


# This script takes a PyTorch model and saves a just-in-time (JIT) compiled version of it.
# The compiled model is needed to run the model with LAMMPS.
# For further info see examples/howtos/lammps.rst

# Note that this script is designed for models that predict atomic forces via automatic
# differentiation (utilizing response modules).
# Hence this script will not work for models without response modules.


def get_jit_model(model):
    """Prepare ``model`` for TorchScript and return the scripted module.

    The postprocessor list is rebuilt without the type-casting transforms
    (``CastTo64`` / ``CastTo32``), and the ``mean`` of every ``AddOffsets``
    postprocessor is converted with ``.float()`` before scripting.

    Note:
        ``model`` is modified in place: its ``postprocessors`` attribute is
        replaced by the filtered ``nn.ModuleList`` and the ``mean`` attribute
        of ``AddOffsets`` postprocessors is overwritten.

    Args:
        model: module exposing an iterable ``postprocessors`` attribute.

    Returns:
        The result of ``torch.jit.script(model)``.
    """
    # fix invalid operations in postprocessing
    jit_postprocessors = nn.ModuleList()
    for postprocessor in model.postprocessors:
        # ignore type casting; isinstance also catches subclasses of the
        # cast transforms, which an exact type comparison would let through
        if isinstance(postprocessor, (CastTo64, CastTo32)):
            continue
        # ensure offset mean is float
        if isinstance(postprocessor, AddOffsets):
            postprocessor.mean = postprocessor.mean.float()

        jit_postprocessors.append(postprocessor)
    model.postprocessors = jit_postprocessors

    return torch.jit.script(model)


def save_jit_model(model, model_path):
    """Script ``model`` and store it at ``model_path`` together with its cutoff.

    The model is compiled via ``get_jit_model``. The value of
    ``representation.cutoff`` of the compiled model is attached to the saved
    archive as an extra file named ``"cutoff"``, holding the ASCII-encoded
    string form of that number.

    Args:
        model: model to compile; forwarded unchanged to ``get_jit_model``.
        model_path: destination handed to ``torch.jit.save``.
    """
    scripted = get_jit_model(model)

    # metadata shipped alongside the serialized model
    cutoff_value = scripted.representation.cutoff.item()
    extra_files = {"cutoff": str(cutoff_value).encode("ascii")}

    torch.jit.save(scripted, model_path, _extra_files=extra_files)


if __name__ == "__main__":
    # Command line entry point: load a stored model, compile it and write the
    # deployable TorchScript archive.
    parser = argparse.ArgumentParser(
        description="Save a just-in-time compiled version of a stored PyTorch model."
    )
    parser.add_argument("model_path", help="path of the stored PyTorch model")
    parser.add_argument(
        "deployed_model_path", help="path the compiled model is written to"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="device the model is mapped to while loading (default: cpu)",
    )
    args = parser.parse_args()

    # NOTE: the file holds a complete pickled model, so loading it executes
    # arbitrary pickled code -- only use model files from a trusted source.
    # Recent PyTorch releases default to ``weights_only=True``, which refuses
    # such files, hence full unpickling is requested explicitly.
    try:
        model = torch.load(
            args.model_path, map_location=args.device, weights_only=False
        )
    except TypeError:
        # Older PyTorch releases do not accept the ``weights_only`` keyword;
        # retry with the legacy call. A TypeError with a different cause is
        # simply raised again by this second attempt.
        model = torch.load(args.model_path, map_location=args.device)
    save_jit_model(model, args.deployed_model_path)

    print(f"stored deployed model at {args.deployed_model_path}.")
