/*
 * Decompiled with CFR 0.152.
 */
package org.pytorch.serve.snapshot;

import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.model.ModelNotFoundException;
import org.pytorch.serve.servingsdk.snapshot.Snapshot;
import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer;
import org.pytorch.serve.snapshot.InvalidSnapshotException;
import org.pytorch.serve.snapshot.SnapshotReadException;
import org.pytorch.serve.snapshot.SnapshotSerializerFactory;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class SnapshotManager {
    private static final Logger logger = LoggerFactory.getLogger(SnapshotManager.class);
    private static SnapshotManager snapshotManager;
    private ConfigManager configManager;
    private ModelManager modelManager;
    private SnapshotSerializer snapshotSerializer;

    public static void init(ConfigManager configManager) {
        snapshotManager = new SnapshotManager(configManager);
    }

    public static SnapshotManager getInstance() {
        return snapshotManager;
    }

    private SnapshotManager(ConfigManager configManager) {
        this.configManager = configManager;
        this.modelManager = ModelManager.getInstance();
        this.snapshotSerializer = SnapshotSerializerFactory.getSerializer();
    }

    private void saveSnapshot(String snapshotName) {
        if (this.configManager.isSnapshotDisabled()) {
            return;
        }
        Map<String, Model> defModels = this.modelManager.getDefaultModels(true);
        HashMap<String, Map<String, JsonObject>> modelNameMap = new HashMap<String, Map<String, JsonObject>>();
        try {
            int modelCount = 0;
            for (Map.Entry<String, Model> m3 : defModels.entrySet()) {
                if (m3.getValue().isWorkflowModel()) continue;
                Set<Map.Entry<String, Model>> versionModels = this.modelManager.getAllModelVersions(m3.getKey());
                HashMap<String, JsonObject> modelInfoMap = new HashMap<String, JsonObject>();
                for (Map.Entry<String, Model> versionedModel : versionModels) {
                    String version = String.valueOf(versionedModel.getKey());
                    boolean isDefaultVersion = m3.getValue().getVersion().equals(versionedModel.getValue().getVersion());
                    modelInfoMap.put(version, versionedModel.getValue().getModelState(isDefaultVersion));
                    ++modelCount;
                }
                modelNameMap.put(m3.getKey(), modelInfoMap);
            }
            Snapshot snapshot = new Snapshot(snapshotName, modelCount);
            snapshot.setModels(modelNameMap);
            this.snapshotSerializer.saveSnapshot(snapshot, this.configManager.getConfiguration());
        }
        catch (ModelNotFoundException e) {
            logger.error("Model not found while saving snapshot {}", (Object)snapshotName);
        }
        catch (IOException e) {
            logger.error("Error while saving snapshot to file {}", (Object)snapshotName);
        }
    }

    public void saveSnapshot() {
        this.saveSnapshot(this.getSnapshotName("snapshot"));
    }

    public void saveStartupSnapshot() {
        this.saveSnapshot(this.getSnapshotName("startup"));
    }

    public void saveShutdownSnapshot() {
        this.saveSnapshot(this.getSnapshotName("shutdown"));
    }

    public Snapshot getSnapshot(String snapshotName) throws SnapshotReadException {
        try {
            return this.snapshotSerializer.getSnapshot(snapshotName);
        }
        catch (IOException e) {
            throw new SnapshotReadException("Error while retrieving snapshot details. Cause : " + e.getCause());
        }
    }

    public void restore(String modelSnapshot) throws InvalidSnapshotException, IOException {
        logger.info("Started restoring models from snapshot {}", (Object)modelSnapshot);
        Snapshot snapshot = this.snapshotSerializer.getSnapshot(modelSnapshot);
        this.validate(snapshot);
        this.initModels(snapshot);
    }

    private void initModels(Snapshot snapshot) {
        try {
            Map<String, Map<String, JsonObject>> models = snapshot.getModels();
            if (snapshot.getModelCount() <= 0) {
                logger.warn("Model snapshot is empty. Starting TorchServe without initial models.");
                return;
            }
            for (Map.Entry<String, Map<String, JsonObject>> modelMap : models.entrySet()) {
                String modelName = modelMap.getKey();
                for (Map.Entry<String, JsonObject> versionModel : modelMap.getValue().entrySet()) {
                    JsonObject modelInfo = versionModel.getValue();
                    this.modelManager.registerAndUpdateModel(modelName, modelInfo);
                }
            }
        }
        catch (IOException e) {
            logger.error("Error while retrieving snapshot details. Details: {}", (Object)e.getMessage());
        }
        catch (InterruptedException | DownloadArchiveException | ModelException | WorkerInitializationException e) {
            logger.error("Error while registering model. Details: {}", (Object)e.getMessage());
        }
    }

    private boolean validate(Snapshot snapshot) throws IOException, InvalidSnapshotException {
        logger.info("Validating snapshot {}", (Object)snapshot.getName());
        String modelStore = this.configManager.getModelStore();
        Map<String, Map<String, JsonObject>> models = snapshot.getModels();
        for (Map.Entry<String, Map<String, JsonObject>> modelMap : models.entrySet()) {
            String modelName = modelMap.getKey();
            for (Map.Entry<String, JsonObject> versionModel : modelMap.getValue().entrySet()) {
                String versionId = versionModel.getKey();
                String marName = versionModel.getValue().get("marName").getAsString();
                File marFile = new File(modelStore + "/" + marName);
                if (marFile.exists()) continue;
                logger.error("Model archive file for model {}, version {} not found in model store", (Object)modelName, (Object)versionId);
                throw new InvalidSnapshotException("Model archive file for model :" + modelName + ", version :" + versionId + " not found in model store");
            }
        }
        logger.info("Snapshot {} validated successfully", (Object)snapshot.getName());
        return true;
    }

    private String getSnapshotName(String snapshotType) {
        return new SimpleDateFormat("yyyyMMddHHmmssSSS'-" + snapshotType + ".cfg'").format(new Date());
    }
}

