/*
 * Decompiled with CFR 0.152.
 */
package pycaret.pipeline;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.ScalarLabelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.sklearn.SkLearnEncoder;
import pycaret.preprocess.TransformerWrapper;
import sklearn.Estimator;
import sklearn.pipeline.SkLearnPipeline;

public class PyCaretPipeline
extends SkLearnPipeline {
    public PyCaretPipeline(String module, String name) {
        super(module, name);
    }

    public int getNumberOfFeatures() {
        return -1;
    }

    public List<? extends TransformerWrapper> getTransformers() {
        List transformers = super.getTransformers();
        return Lists.transform((List)transformers, TransformerWrapper.class::cast);
    }

    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        ArrayList result = super.encodeFeatures(features, encoder);
        Label label = encoder.getLabel();
        if (label != null) {
            result = new ArrayList(result);
            List scalarLabels = ScalarLabelUtil.toScalarLabels((Label)label);
            for (ScalarLabel scalarLabel : scalarLabels) {
                Feature labelFeature = FeatureUtil.findLabelFeature(result, (ScalarLabel)scalarLabel);
                if (labelFeature == null) continue;
                result.remove(labelFeature);
            }
        }
        return result;
    }

    public Model encodeModel(Schema schema) {
        return super.encodeModel(schema);
    }

    public PMML encodePMML() {
        SkLearnEncoder encoder = new SkLearnEncoder();
        List<? extends TransformerWrapper> transformers = this.getTransformers();
        Estimator estimator = this.getFinalEstimator();
        TransformerWrapper transformer = transformers.get(0);
        String targetName = transformer.getTargetName();
        if (targetName != null) {
            encoder.initLabel(estimator, Collections.singletonList(targetName));
        }
        Schema schema = encoder.createSchema();
        Model model = this.encodeModel(schema);
        encoder.setModel(model);
        return encoder.encodePMML(model);
    }

    public Label refreshLabel(Label label, SkLearnEncoder encoder) {
        ScalarLabel scalarLabel;
        if (label instanceof ScalarLabel && !(scalarLabel = (ScalarLabel)label).isAnonymous()) {
            DataField dataField = (DataField)encoder.getField(scalarLabel.getName());
            return ScalarLabelUtil.createScalarLabel((DataField)dataField);
        }
        return label;
    }
}

