/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.h2o;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.h2o.Converter;
import org.jpmml.h2o.XGBoostMojoModel;
import org.jpmml.xgboost.Learner;
import org.jpmml.xgboost.XGBoostUtil;

/**
 * Converts an {@link XGBoostMojoModel} into a PMML {@link MiningModel}.
 *
 * <p>The conversion itself is delegated to a {@link Learner} that is loaded
 * from the booster bytes carried by the model; this class only adapts the
 * schema and supplies the encoding options.
 */
public class XGBoostMojoModelConverter
extends Converter<XGBoostMojoModel> {
    /**
     * Creates a converter for the given model.
     *
     * @param model the model to convert
     */
    public XGBoostMojoModelConverter(XGBoostMojoModel model) {
        super(model);
    }

    /**
     * Rewrites the feature list of the given schema so that every
     * {@link CategoricalFeature} is replaced by one {@link BinaryFeature} per
     * category value, followed by one {@link MissingValueFeature}. All other
     * features are kept as they are, and the relative feature order is
     * preserved.
     *
     * @param schema the schema to adapt
     * @return a new schema with the same encoder and label, and the expanded
     *         feature list
     */
    @Override
    public Schema toMojoModelSchema(Schema schema) {
        ModelEncoder encoder = schema.getEncoder();
        Label label = schema.getLabel();
        List<? extends Feature> features = schema.getFeatures();

        Function<Feature, Stream<Feature>> expand = feature -> expandFeature(encoder, feature);

        // Collect into a separate, properly typed list instead of reassigning
        // the (wildcard-typed) input list variable.
        List<Feature> expandedFeatures = features.stream()
            .flatMap(expand)
            .collect(Collectors.toList());

        return new Schema(encoder, label, expandedFeatures);
    }

    /**
     * Expands a single feature into the features that replace it.
     *
     * @param encoder the encoder passed to the newly created features
     * @param feature the feature to expand
     * @return for a {@link CategoricalFeature}, one {@link BinaryFeature} per
     *         category value followed by a {@link MissingValueFeature};
     *         otherwise a single-element stream holding {@code feature}
     */
    private static Stream<Feature> expandFeature(PMMLEncoder encoder, Feature feature) {
        if (!(feature instanceof CategoricalFeature)) {
            return Stream.of(feature);
        }

        CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
        List<?> values = categoricalFeature.getValues();

        Stream<Feature> valueIndicators = values.stream()
            .map(value -> new BinaryFeature(encoder, categoricalFeature, value));
        // The missing-value indicator always comes after the per-value indicators.
        Stream<Feature> missingIndicator = Stream.of(new MissingValueFeature(encoder, categoricalFeature));

        return Stream.concat(valueIndicators, missingIndicator);
    }

    /**
     * Loads the XGBoost learner from the model's booster bytes and encodes it
     * as a {@link MiningModel}, with the {@code "compact"} and
     * {@code "numeric"} options both set to {@code true}.
     *
     * @param schema the schema to encode against; it is first passed through
     *        {@link Learner#toXGBoostSchema(Schema)}
     * @return the encoded mining model
     * @throws IllegalArgumentException if reading the booster bytes fails with
     *         an {@link IOException} (kept as the cause)
     */
    public MiningModel encodeModel(Schema schema) {
        XGBoostMojoModel model = this.getModel();
        byte[] boosterBytes = model.getBoosterBytes();

        Learner learner;
        try (InputStream is = new ByteArrayInputStream(boosterBytes)) {
            learner = XGBoostUtil.loadLearner(is);
        } catch (IOException ioe) {
            throw new IllegalArgumentException(ioe);
        }

        // Insertion-ordered map, as in the original.
        LinkedHashMap<String, Boolean> options = new LinkedHashMap<>();
        options.put("compact", Boolean.TRUE);
        options.put("numeric", Boolean.TRUE);

        Schema xgbSchema = learner.toXGBoostSchema(schema);
        return learner.encodeModel(options, xgbSchema);
    }
}

