/*
 * Decompiled with CFR 0.152.
 */
package sklego.meta;

import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.ClassifierUtil;
import sklearn.Estimator;
import sklearn.HasEstimator;
import sklearn.Transformer;

public class EstimatorTransformer
extends Transformer
implements HasEstimator<Estimator> {
    public EstimatorTransformer(String module, String name) {
        super(module, name);
    }

    /**
     * Embeds the wrapped supervised estimator as a PMML transformer model and exposes its
     * prediction as a single new feature.
     *
     * <p>Only the {@code "predict"} prediction function is supported. Classifiers yield one
     * categorical feature over the estimator's classes; regressors yield one continuous
     * (double) feature.
     *
     * @param features the input features passed to the wrapped estimator
     * @param encoder the encoder that receives the embedded model and the derived output field
     * @return a singleton list holding the feature that carries the estimator's prediction
     * @throws IllegalArgumentException if the estimator is not supervised, the predict function
     *     is not {@code "predict"}, or the estimator's mining function is neither classification
     *     nor regression
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        Estimator estimator = this.getEstimator();
        String predictFunc = this.getPredictFunc();
        if (!estimator.isSupervised()) {
            throw new IllegalArgumentException("Expected a supervised estimator");
        }
        // equals() instead of switch: a switch on a null String would throw NullPointerException
        if (!"predict".equals(predictFunc)) {
            throw new IllegalArgumentException("Predict function " + predictFunc + " is not supported");
        }
        Label label = createLabel(estimator);
        Schema schema = new Schema((PMMLEncoder)encoder, label, features);
        Model model = estimator.encode(schema);
        Output output = model.getOutput();
        if (output != null && output.hasOutputFields()) {
            // Drop the embedded model's own output fields; the only field exposed from it is the
            // derived prediction field created below. NOTE(review): intent inferred — confirm.
            List<OutputField> outputFields = output.getOutputFields();
            outputFields.clear();
        }
        encoder.addTransformer(model);
        FieldName name = this.createFieldName("estimator", new Object[0]);
        return Collections.singletonList(encodePredictedFeature(model, name, label, encoder));
    }

    /**
     * Builds the target label that the wrapped estimator is encoded against.
     *
     * @param estimator the supervised estimator being embedded
     * @return a {@link CategoricalLabel} over the estimator's classes for classification, or a
     *     double-typed {@link ContinuousLabel} for regression
     * @throws IllegalArgumentException for any other mining function
     */
    private static Label createLabel(Estimator estimator) {
        MiningFunction miningFunction = estimator.getMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: {
                List<?> categories = ClassifierUtil.getClasses(estimator);
                // STRING is passed as the fallback data type to TypeUtil — presumably used when
                // it cannot be inferred from the class values; confirm against jpmml-converter.
                DataType dataType = TypeUtil.getDataType(categories, DataType.STRING);
                return new CategoricalLabel(null, dataType, categories);
            }
            case REGRESSION: {
                return new ContinuousLabel(null, DataType.DOUBLE);
            }
            default: {
                throw new IllegalArgumentException("Mining function " + miningFunction + " is not supported");
            }
        }
    }

    /**
     * Registers a derived field for the embedded model's predicted value and wraps it as a
     * feature whose kind matches the label.
     *
     * @param model the embedded model whose prediction is exposed
     * @param name the name of the derived prediction field
     * @param label the label the model was encoded against (from {@link #createLabel})
     * @param encoder the encoder that creates the derived field
     * @return a categorical or continuous feature backed by the derived prediction field
     * @throws IllegalArgumentException if the label is neither categorical nor continuous
     */
    private static Feature encodePredictedFeature(Model model, FieldName name, Label label, SkLearnEncoder encoder) {
        if (label instanceof CategoricalLabel) {
            CategoricalLabel categoricalLabel = (CategoricalLabel)label;
            OutputField predictedOutputField = ModelUtil.createPredictedField(name, OpType.CATEGORICAL, categoricalLabel.getDataType());
            DerivedOutputField predictedField = encoder.createDerivedField(model, predictedOutputField, false);
            return new CategoricalFeature((PMMLEncoder)encoder, (Field)predictedField, categoricalLabel.getValues());
        }
        if (label instanceof ContinuousLabel) {
            ContinuousLabel continuousLabel = (ContinuousLabel)label;
            OutputField predictedOutputField = ModelUtil.createPredictedField(name, OpType.CONTINUOUS, continuousLabel.getDataType());
            DerivedOutputField predictedField = encoder.createDerivedField(model, predictedOutputField, false);
            return new ContinuousFeature((PMMLEncoder)encoder, (Field)predictedField);
        }
        throw new IllegalArgumentException("Label type " + label.getClass().getName() + " is not supported");
    }

    /**
     * Returns the fitted estimator stored under the {@code estimator_} attribute.
     *
     * @return the wrapped estimator
     */
    @Override
    public Estimator getEstimator() {
        return (Estimator)this.get("estimator_", Estimator.class);
    }

    /**
     * Returns the name of the prediction method, read from the {@code predict_func} attribute.
     * {@link #encodeFeatures} accepts only {@code "predict"}.
     *
     * @return the configured predict function name
     */
    public String getPredictFunc() {
        return this.getString("predict_func");
    }
}

