/*
 * Decompiled with CFR 0.152.
 */
package statsmodels.tsa.arima;

import java.util.List;
import org.dmg.pmml.Array;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.OpType;
import org.dmg.pmml.time_series.InterceptVector;
import org.dmg.pmml.time_series.MeasurementMatrix;
import org.dmg.pmml.time_series.StateSpaceModel;
import org.dmg.pmml.time_series.StateVector;
import org.dmg.pmml.time_series.TimeSeriesModel;
import org.dmg.pmml.time_series.TransitionMatrix;
import org.jpmml.converter.CMatrix;
import org.jpmml.converter.Label;
import org.jpmml.converter.Matrix;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.python.HasArray;
import org.jpmml.statsmodels.StatsModelsEncoder;
import statsmodels.Results;
import statsmodels.tsa.TimeSeriesModel;
import statsmodels.tsa.statespace.SmootherResults;

public class ARIMA
extends TimeSeriesModel {
    private static final String NAME_INDEX = "index";

    public ARIMA(String module, String name) {
        super(module, name);
    }

    @Override
    public Schema encodeSchema(StatsModelsEncoder encoder) {
        Schema schema = super.encodeSchema(encoder);
        DataField dataField = encoder.createDataField("horizon", OpType.CONTINUOUS, DataType.INTEGER);
        return schema;
    }

    public org.dmg.pmml.time_series.TimeSeriesModel encodeModel(Results results, Schema schema) {
        HasArray predictedState = results.getArray("predicted_state");
        SmootherResults smootherResults = (SmootherResults)((Object)results.get("smoother_results", SmootherResults.class));
        HasArray design = smootherResults.getDesign();
        HasArray obsIntercept = smootherResults.getObsIntercept();
        HasArray transition = smootherResults.getTransition();
        MiningSchema miningSchema = ModelUtil.createMiningSchema((Label)schema.getLabel()).addMiningFields(new MiningField[]{ModelUtil.createMiningField((String)"horizon", (MiningField.UsageType)MiningField.UsageType.SUPPLEMENTARY)});
        StateVector stateVector = new StateVector(ARIMA.createRealArray(predictedState, -1));
        MeasurementMatrix measurementMatrix = new MeasurementMatrix(ARIMA.createMatrix(design));
        TransitionMatrix transitionMatrix = new TransitionMatrix(ARIMA.createMatrix(transition));
        InterceptVector interceptVector = new InterceptVector(ARIMA.createRealArray(obsIntercept, -1)).setType(InterceptVector.Type.OBSERVATION);
        StateSpaceModel stateSpaceModel = new StateSpaceModel().setStateVector(stateVector).setMeasurementMatrix(measurementMatrix).setTransitionMatrix(transitionMatrix).setInterceptVector(interceptVector);
        org.dmg.pmml.time_series.TimeSeriesModel timeSeriesModel = new org.dmg.pmml.time_series.TimeSeriesModel(MiningFunction.TIME_SERIES, TimeSeriesModel.Algorithm.STATE_SPACE_MODEL, miningSchema).setStateSpaceModel(stateSpaceModel);
        return timeSeriesModel;
    }

    private static Array createRealArray(HasArray hasArray, int column) {
        Matrix<?> matrix = ARIMA.toMatrix(hasArray);
        List columnValues = column >= 0 ? matrix.getColumnValues(column) : matrix.getColumnValues(matrix.getColumns() + column);
        return PMMLUtil.createRealArray((List)columnValues);
    }

    private static org.dmg.pmml.Matrix createMatrix(HasArray hasArray) {
        Matrix<?> matrix = ARIMA.toMatrix(hasArray);
        org.dmg.pmml.Matrix result = new org.dmg.pmml.Matrix().setNbRows(Integer.valueOf(matrix.getRows())).setNbCols(Integer.valueOf(matrix.getColumns()));
        for (int row = 0; row < matrix.getRows(); ++row) {
            List rowValues = matrix.getRowValues(row);
            result.addArrays(new Array[]{PMMLUtil.createRealArray((List)rowValues)});
        }
        return result;
    }

    private static Matrix<?> toMatrix(HasArray hasArray) {
        int[] shape = hasArray.getArrayShape();
        List values = hasArray.getArrayContent();
        if (shape.length == 3 && shape[2] != 1) {
            throw new IllegalArgumentException();
        }
        return new CMatrix(values, shape[0], shape[1]);
    }
}

