/*
 * Decompiled with CFR 0.152.
 */
package sktree.tree;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.OpType;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;

/**
 * Caches the PMML encodings of linear feature projections, so that an identical projection
 * requested more than once is encoded (and registered with the encoder) only once.
 *
 * <p>A projection is the weighted sum {@code sum(weights[i] * features[i])}. Terms with a zero
 * weight are dropped; every remaining weight must be exactly {@code 1} or {@code -1}.
 *
 * <p>NOTE(review): the cache is a plain {@link LinkedHashMap}; instances are not synchronized.
 */
public class ProjectionManager {

    /** Encoded projections, keyed by their non-zero (feature, weight) terms in input order. */
    private final Map<List<Vector>, Feature> projections = new LinkedHashMap<>();

    /**
     * Returns the feature that encodes the projection {@code sum(weights[i] * features[i])},
     * encoding and caching it on first request.
     *
     * <p>{@code features} and {@code weights} are first checked for equal size via
     * {@code ClassDictUtil.checkSize}.
     *
     * @param name the name given to a newly created derived field (unused on a cache hit, or
     *     when the projection reduces to a single unit-weight feature)
     * @param features the projected features
     * @param weights the projection weights, aligned index-by-index with {@code features}
     * @param encoder the encoder that receives any newly created derived field
     * @return the encoded projection, or {@code null} if every weight is zero
     * @throws IllegalArgumentException if a non-zero weight is neither {@code 1} nor {@code -1}
     */
    public Feature getOrCreateFeature(String name, List<Feature> features, List<Number> weights, SkLearnEncoder encoder) {
        ClassDictUtil.checkSize(features, weights);

        List<Vector> key = createKey(features, weights);
        if (key.isEmpty()) {
            return null;
        }

        // Cached values are never null (encodeFeature returns a feature or throws),
        // so a single lookup distinguishes a hit from a miss.
        Feature feature = this.projections.get(key);
        if (feature == null) {
            feature = encodeFeature(name, key, encoder);
            this.projections.put(key, feature);
        }
        return feature;
    }

    /**
     * Encodes a non-empty projection key as a feature.
     *
     * <p>A lone unit-weight term is returned as-is. Otherwise the {@code +1} terms are summed,
     * the {@code -1} terms are summed, and the result ({@code plus - minus}, {@code plus}, or
     * {@code -minus}) becomes a continuous float derived field named {@code name}.
     */
    private static Feature encodeFeature(String name, List<Vector> key, SkLearnEncoder encoder) {
        // A single term with weight 1 is the feature itself; no derived field is needed.
        if (key.size() == 1) {
            Vector only = key.get(0);
            if (only.getWeight().doubleValue() == 1.0) {
                return only.getFeature();
            }
        }

        List<Expression> plusExpressions = new ArrayList<>();
        List<Expression> minusExpressions = new ArrayList<>();
        for (Vector vector : key) {
            Number weight = vector.getWeight();
            Expression term = toExpression(vector.getFeature());

            double value = weight.doubleValue();
            if (value == 1.0) {
                plusExpressions.add(term);
            } else if (value == -1.0) {
                minusExpressions.add(term);
            } else {
                throw new IllegalArgumentException("Expected a projection weight of 1 or -1, got " + weight);
            }
        }

        Expression plusExpression = aggregate(plusExpressions);
        Expression minusExpression = aggregate(minusExpressions);

        Expression expression;
        if (plusExpression != null && minusExpression != null) {
            expression = ExpressionUtil.createApply("-", plusExpression, minusExpression);
        } else if (plusExpression != null) {
            expression = plusExpression;
        } else if (minusExpression != null) {
            expression = ExpressionUtil.toNegative(minusExpression);
        } else {
            // Only reachable if called with an empty key; getOrCreateFeature never does that.
            throw new IllegalArgumentException("Cannot encode an empty projection");
        }

        DerivedField derivedField = encoder.createDerivedField(name, OpType.CONTINUOUS, DataType.FLOAT, expression);
        return new ContinuousFeature(encoder, derivedField);
    }

    /**
     * Encodes a single projection term: a {@link NormDiscrete} indicator for binary features,
     * otherwise a reference to the feature's continuous representation.
     */
    private static Expression toExpression(Feature feature) {
        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;
            return new NormDiscrete(binaryFeature.getName(), binaryFeature.getValue());
        }
        ContinuousFeature continuousFeature = feature.toContinuousFeature();
        return continuousFeature.ref();
    }

    /**
     * Combines expressions into one: {@code null} if there are none, the sole expression if there
     * is one, otherwise an {@code Apply} of the {@code "sum"} function over all of them.
     */
    private static Expression aggregate(List<Expression> expressions) {
        if (expressions.isEmpty()) {
            return null;
        }
        if (expressions.size() == 1) {
            return expressions.get(0);
        }
        Apply apply = ExpressionUtil.createApply("sum");
        apply.getExpressions().addAll(expressions);
        return apply;
    }

    /**
     * Pairs each feature with its weight, skipping zero weights. The returned list doubles as the
     * cache key, so two calls with the same non-zero terms (in the same order) share one entry.
     */
    private static List<Vector> createKey(List<Feature> features, List<Number> weights) {
        List<Vector> result = new ArrayList<>(features.size());
        for (int i = 0; i < features.size(); i++) {
            Number weight = weights.get(i);
            if (ValueUtil.isZero(weight)) {
                continue;
            }
            result.add(new Vector(features.get(i), weight));
        }
        return result;
    }

    /**
     * One (feature, weight) term of a projection. Immutable, because instances are part of
     * {@link #projections} map keys and must not change hash code while stored.
     */
    private static class Vector {
        private final Feature feature;
        private final Number weight;

        private Vector(Feature feature, Number weight) {
            this.feature = Objects.requireNonNull(feature);
            this.weight = Objects.requireNonNull(weight);
        }

        public Feature getFeature() {
            return this.feature;
        }

        public Number getWeight() {
            return this.weight;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.feature, this.weight);
        }

        @Override
        public boolean equals(Object object) {
            if (this == object) {
                return true;
            }
            if (!(object instanceof Vector)) {
                return false;
            }
            Vector that = (Vector) object;
            // NOTE(review): Number.equals is type-sensitive (Integer 1 != Double 1.0), so weights
            // of different boxed types produce distinct keys.
            return Objects.equals(this.feature, that.feature) && Objects.equals(this.weight, that.weight);
        }
    }
}

