/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.model.deep;

import java.io.BufferedReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.deep.DeepModel;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.ConstantType;
import org.linqs.psl.util.FileUtils;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.StringUtils;

/**
 * A {@link DeepModel} bound to a single {@link Predicate}.
 *
 * <p>Entities are read from a tab-delimited entity-data map file. For every entity line and every
 * class index in {@code [0, classSize)}, the ground atom of the predicate whose last argument is that
 * class index is looked up in the {@link AtomStore}. Data indexes, gradients, and predictions for
 * those atoms are then exchanged with the external model through the shared buffer inherited from
 * {@link DeepModel}.
 */
public class DeepModelPredicate
extends DeepModel {
    private static final Logger log = Logger.getLogger(DeepModelPredicate.class);
    private static final String DELIM = "\t";
    public static final String CONFIG_ENTITY_DATA_MAP_PATH = "entity-data-map-path";
    public static final String CONFIG_ENTITY_ARGUMENT_INDEXES = "entity-argument-indexes";
    public static final String CONFIG_CLASS_SIZE = "class-size";

    private AtomStore atomStore = null;
    private Predicate predicate;
    private int classSize;
    // Atom-store index of every (entity, class) atom, grouped by entity and ordered by class index.
    private int[] atomIndexes;
    // Zero-based position (among non-empty lines of the map file) of each entity that has atoms.
    private int[] dataIndexes;
    // Per-atom gradients sent to the external model; parallel to atomIndexes.
    private float[] gradients;
    // Gradients indexed by atom-store index; supplied via setSymbolicGradients().
    private float[] symbolicGradients;
    // Scratch lists filled while parsing the entity data map; emptied again at the end of init().
    private ArrayList<Integer> validAtomIndexes;
    private ArrayList<Integer> validDataIndexes;

    public DeepModelPredicate(Predicate predicate) {
        super("DeepModelPredicate");
        this.predicate = predicate;
        this.classSize = -1;
        this.atomIndexes = null;
        this.dataIndexes = null;
        this.gradients = null;
        this.symbolicGradients = null;
        this.validAtomIndexes = new ArrayList<>();
        this.validDataIndexes = new ArrayList<>();
    }

    /**
     * Creates a copy of this model.
     *
     * <p>The copy shares this instance's Python server process, socket, streams, shared file, and
     * shared buffer (the references are copied, not the objects). It receives its own copies of the
     * index and gradient arrays and of the scratch lists.
     * The port held by the freshly constructed copy is released via {@code freePort} and replaced by
     * this instance's port (presumably the constructor allocated one -- confirm in DeepModel).
     *
     * @return a new DeepModelPredicate sharing this instance's server-side resources
     */
    public DeepModelPredicate copy() {
        DeepModelPredicate copy = new DeepModelPredicate(this.predicate);
        copy.pythonOptions = this.pythonOptions;
        copy.application = this.application;
        DeepModelPredicate.freePort(copy.port);
        copy.port = this.port;
        copy.pythonModule = this.pythonModule;
        copy.sharedMemoryPath = this.sharedMemoryPath;
        copy.pythonServerProcess = this.pythonServerProcess;
        copy.sharedFile = this.sharedFile;
        copy.sharedBuffer = this.sharedBuffer;
        copy.socket = this.socket;
        copy.socketInput = this.socketInput;
        copy.socketOutput = this.socketOutput;
        copy.serverOpen = this.serverOpen;
        copy.atomStore = this.atomStore;
        copy.classSize = this.classSize;
        copy.atomIndexes = copyOrNull(this.atomIndexes);
        copy.dataIndexes = copyOrNull(this.dataIndexes);
        copy.validAtomIndexes = new ArrayList<>(this.validAtomIndexes);
        copy.validDataIndexes = new ArrayList<>(this.validDataIndexes);
        copy.gradients = copyOrNull(this.gradients);
        copy.symbolicGradients = copyOrNull(this.symbolicGradients);
        return copy;
    }

    /**
     * Reads the predicate options, parses the entity-data map file, and builds the atom/data index
     * arrays used for all later buffer exchanges.
     *
     * @return {@code 32 + rows * 32 + rows * classSize * 32}, where {@code rows} is the number of
     *     non-empty lines in the entity-data map file. NOTE(review): presumably the requested
     *     shared-buffer size; int arithmetic can overflow for very large maps -- confirm with DeepModel.
     * @throws IllegalArgumentException if a required option is missing or the class size is not positive
     * @throws RuntimeException if the map file cannot be read or an entity is malformed
     */
    @Override
    public int init() {
        log.debug("Initializing deep model predicate: {}", this.predicate.getName());
        this.validateOptions();

        this.classSize = Integer.parseInt((String)this.pythonOptions.get(CONFIG_CLASS_SIZE));
        // A non-positive class size would divide by zero (or skip every entity) while mapping atoms.
        if (this.classSize <= 0) {
            throw new IllegalArgumentException(String.format(
                    "A DeepPredicate must have a positive class size (\"%s\"), got %d.",
                    CONFIG_CLASS_SIZE, this.classSize));
        }

        String entityDataMapPath = FileUtils.makePath((String)this.pythonOptions.get("relative-dir"), (String)this.pythonOptions.get(CONFIG_ENTITY_DATA_MAP_PATH));
        // Only the number of configured entity argument indexes is used here.
        int numEntityArgs = StringUtils.splitInt((String)this.pythonOptions.get(CONFIG_ENTITY_ARGUMENT_INDEXES), ",").length;
        int maxDataIndex = this.mapEntitiesFromFileToAtoms(entityDataMapPath, this.atomStore, numEntityArgs);

        this.atomIndexes = toIntArray(this.validAtomIndexes);
        this.dataIndexes = toIntArray(this.validDataIndexes);
        // New float arrays are zero-filled, matching the explicit 0.0f initialization of gradients.
        this.gradients = new float[this.atomIndexes.length];

        this.validAtomIndexes.clear();
        this.validDataIndexes.clear();

        return 32 + maxDataIndex * 32 + maxDataIndex * this.classSize * 32;
    }

    /**
     * Writes the data indexes followed by one gradient per tracked atom into the shared buffer.
     *
     * <p>Each gradient is taken from the symbolic gradients at that atom's atom-store index.
     *
     * @throws IllegalStateException if {@link #setSymbolicGradients(float[])} has not been called
     */
    @Override
    public void writeFitData() {
        log.debug("Writing fit data for deep model predicate: {}", this.predicate.getName());
        if (this.symbolicGradients == null) {
            throw new IllegalStateException(String.format(
                    "Symbolic gradients must be set before writing fit data for deep model predicate: %s.",
                    this.predicate.getName()));
        }
        for (int index = 0; index < this.gradients.length; ++index) {
            this.gradients[index] = this.symbolicGradients[this.atomIndexes[index]];
        }
        this.writeDataIndexData();
        this.writeGradientData(this.gradients);
    }

    /** Writes the data indexes of the entities to predict into the shared buffer. */
    @Override
    public void writePredictData() {
        log.debug("Writing predict data for deep model predicate: {}", this.predicate.getName());
        this.writeDataIndexData();
    }

    /**
     * Reads the external model's predictions from the shared buffer and applies them to the atoms.
     *
     * <p>The buffer must hold an int count equal to the number of tracked atoms, followed by one
     * float per atom. Each float overwrites both the atom store's value array entry and the
     * corresponding {@link RandomVariableAtom}'s value.
     *
     * @return the sum of absolute differences between the old and new atom values
     * @throws RuntimeException if the reported count does not match the number of tracked atoms
     */
    @Override
    public float readPredictData() {
        log.debug("Reading predict data for deep model predicate: {}", this.predicate.getName());
        int count = this.sharedBuffer.getInt();
        if (count != this.atomIndexes.length) {
            throw new RuntimeException(String.format("External model did not make the desired number of predictions, got %d, expected %d.", count, this.atomIndexes.length));
        }
        float[] atomValues = this.atomStore.getAtomValues();
        float change = 0.0f;
        for (int index = 0; index < this.atomIndexes.length; ++index) {
            float deepPrediction = this.sharedBuffer.getFloat();
            int atomIndex = this.atomIndexes[index];
            change += Math.abs(atomValues[atomIndex] - deepPrediction);
            atomValues[atomIndex] = deepPrediction;
            ((RandomVariableAtom)this.atomStore.getAtom(atomIndex)).setValue(deepPrediction);
        }
        return change;
    }

    /** Writes the data indexes of the entities to evaluate into the shared buffer. */
    @Override
    public void writeEvalData() {
        log.debug("Writing eval data for deep model predicate: {}", this.predicate.getName());
        this.writeDataIndexData();
    }

    /** Closes the underlying model and resets all predicate-specific state. */
    @Override
    public void close() {
        super.close();
        this.classSize = -1;
        this.atomIndexes = null;
        this.dataIndexes = null;
        this.gradients = null;
        this.symbolicGradients = null;
        this.validAtomIndexes.clear();
        this.validDataIndexes.clear();
    }

    /** Sets the atom store used to resolve atoms, without initializing. */
    public void setAtomStore(AtomStore atomStore) {
        this.setAtomStore(atomStore, false);
    }

    /**
     * Sets the atom store used to resolve atoms.
     *
     * @param init when true, {@link #init()} is called immediately (its return value is discarded)
     */
    public void setAtomStore(AtomStore atomStore, boolean init) {
        this.atomStore = atomStore;
        if (init) {
            this.init();
        }
    }

    /** Sets the gradients (indexed by atom-store index) consumed by {@link #writeFitData()}. */
    public void setSymbolicGradients(float[] symbolicGradients) {
        this.symbolicGradients = symbolicGradients;
    }

    /**
     * Merges the predicate options into the Python options and checks the required keys.
     *
     * <p>It then copies every option whose key contains {@code "<application>::"} to the key formed by
     * the text between the first and second {@code "::"}.
     *
     * @throws IllegalArgumentException if a required option is missing or an application-specific
     *     key has no option name after {@code "::"}
     */
    private void validateOptions() {
        for (Map.Entry<String, Object> entry : this.predicate.getPredicateOptions().entrySet()) {
            this.pythonOptions.put(entry.getKey(), (String)entry.getValue());
        }

        String rawEntityDataMapPath = (String)this.pythonOptions.get(CONFIG_ENTITY_DATA_MAP_PATH);
        if (rawEntityDataMapPath == null
                || FileUtils.makePath((String)this.pythonOptions.get("relative-dir"), rawEntityDataMapPath) == null) {
            throw new IllegalArgumentException(String.format("A DeepPredicate must have an entity to data map path (\"%s\") specified in predicate config.", CONFIG_ENTITY_DATA_MAP_PATH));
        }
        if (this.pythonOptions.get(CONFIG_ENTITY_ARGUMENT_INDEXES) == null) {
            throw new IllegalArgumentException(String.format("A DeepPredicate must have entity argument indexes (\"%s\") specified in predicate config.", CONFIG_ENTITY_ARGUMENT_INDEXES));
        }
        if (this.pythonOptions.get(CONFIG_CLASS_SIZE) == null) {
            throw new IllegalArgumentException(String.format("A DeepPredicate must have a class size (\"%s\") specified in predicate config.", CONFIG_CLASS_SIZE));
        }

        // Iterate over a snapshot of the keys: putting a previously absent key while iterating the
        // live entry set is a structural modification and can throw ConcurrentModificationException.
        // NOTE(review): contains() also matches keys where "<application>::" is not a prefix; confirm intent.
        String applicationPrefix = this.application + "::";
        for (String key : new ArrayList<String>(this.pythonOptions.keySet())) {
            if (!key.contains(applicationPrefix)) {
                continue;
            }
            String[] optionParts = key.split("::");
            if (optionParts.length < 2) {
                throw new IllegalArgumentException(String.format(
                        "Application-specific option \"%s\" is missing an option name after \"::\".", key));
            }
            this.pythonOptions.put(optionParts[1], this.pythonOptions.get(key));
        }
    }

    /**
     * Parses the entity-data map file and records atom and data indexes for every fully grounded entity.
     *
     * <p>Each non-empty (trimmed) line is split on tabs. The first {@code numEntityArgs} columns become
     * the leading predicate arguments. The last argument is set to each class index in
     * {@code [0, classSize)} in turn. An entity whose class-0 atom is absent is skipped. An entity
     * with only some of its class atoms present is an error.
     *
     * @param filePath path of the entity-data map file
     * @param atomStore store used to look up atom indexes
     * @param numEntityArgs number of leading entity columns per line
     * @return the number of non-empty lines in the file
     * @throws RuntimeException on read failure, too few columns, or partially present class atoms
     */
    private int mapEntitiesFromFileToAtoms(String filePath, AtomStore atomStore, int numEntityArgs) {
        Constant[] arguments = new Constant[numEntityArgs + 1];
        int classArgIndex = arguments.length - 1;
        int lineNumber = 0;
        int dataIndex = 0;
        try (BufferedReader reader = FileUtils.getBufferedReader(filePath)) {
            String line;
            while ((line = reader.readLine()) != null) {
                ++lineNumber;
                line = line.trim();
                if (line.isEmpty()) {
                    continue;
                }

                String[] parts = line.split(DELIM);
                if (parts.length < numEntityArgs) {
                    throw new RuntimeException(String.format("Entity found on line (%d) must contain %d arguments for predicate %s.", lineNumber, numEntityArgs, this.predicate.getName()));
                }
                for (int index = 0; index < classArgIndex; ++index) {
                    arguments[index] = ConstantType.getConstant(parts[index], this.predicate.getArgumentType(index));
                }

                // Look up the atom for each class index, stopping at the first missing one.
                ConstantType classType = this.predicate.getArgumentType(classArgIndex);
                int found = 0;
                while (found < this.classSize) {
                    arguments[classArgIndex] = ConstantType.getConstant(String.valueOf(found), classType);
                    int atomIndex = atomStore.getAtomIndex(this.predicate, arguments);
                    if (atomIndex == -1) {
                        break;
                    }
                    this.validAtomIndexes.add(atomIndex);
                    ++found;
                }

                if (found != 0 && found != this.classSize) {
                    throw new RuntimeException(String.format("Entity found on line (%d) has unspecified class values for predicate %s.", lineNumber, this.predicate.getName()));
                }
                if (found == this.classSize) {
                    this.validDataIndexes.add(dataIndex);
                }
                ++dataIndex;
            }
        }
        catch (IOException ex) {
            throw new RuntimeException("Unable to parse entity data map file: " + filePath, ex);
        }
        return dataIndex;
    }

    /** Writes each value to the shared buffer as a float (no count prefix). */
    private void writeGradientData(float[] data) {
        for (float value : data) {
            this.sharedBuffer.putFloat(value);
        }
    }

    /** Writes the number of data indexes, then each data index, to the shared buffer as ints. */
    private void writeDataIndexData() {
        this.sharedBuffer.putInt(this.dataIndexes.length);
        for (int dataIndex : this.dataIndexes) {
            this.sharedBuffer.putInt(dataIndex);
        }
    }

    /** Unboxes the list into a new int array with the same length and order. */
    private static int[] toIntArray(ArrayList<Integer> values) {
        int[] result = new int[values.size()];
        for (int i = 0; i < result.length; ++i) {
            result[i] = values.get(i);
        }
        return result;
    }

    /** Returns a copy of the array, or null when the array is null. */
    private static int[] copyOrNull(int[] values) {
        return (values == null) ? null : Arrays.copyOf(values, values.length);
    }

    /** Returns a copy of the array, or null when the array is null. */
    private static float[] copyOrNull(float[] values) {
        return (values == null) ? null : Arrays.copyOf(values, values.length);
    }
}

