/*
 * Decompiled with CFR 0.152.
 */
package sklearn.tree.visitors;

import java.util.List;
import java.util.Objects;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.visitors.AbstractTreeModelTransformer;
import org.jpmml.model.UnsupportedAttributeException;

public class TreeModelPruner
extends AbstractTreeModelTransformer {
    private MiningFunction miningFunction = null;

    public VisitorAction visit(Node node) {
        if (node.hasScoreDistributions()) {
            return VisitorAction.SKIP;
        }
        return super.visit(node);
    }

    public void enterNode(Node node) {
        super.enterNode(node);
    }

    public void exitNode(Node node) {
        List children;
        Object childScore;
        Object score = node.getScore();
        if (node.hasScoreDistributions()) {
            return;
        }
        if (node.hasNodes() && (childScore = TreeModelPruner.getConstantScore(children = node.getNodes())) != null) {
            if (score == null) {
                node.setScore(childScore);
            }
            if (Objects.equals(score, childScore)) {
                children.clear();
            }
        }
    }

    public void enterTreeModel(TreeModel treeModel) {
        super.enterTreeModel(treeModel);
        MiningFunction miningFunction = treeModel.requireMiningFunction();
        switch (miningFunction) {
            case CLASSIFICATION: 
            case REGRESSION: {
                break;
            }
            default: {
                throw new UnsupportedAttributeException((PMMLObject)treeModel, (Enum)miningFunction);
            }
        }
        this.miningFunction = miningFunction;
    }

    public void exitTreeModel(TreeModel treeModel) {
        super.exitTreeModel(treeModel);
        this.miningFunction = null;
    }

    private static Object getConstantScore(List<Node> nodes) {
        Object result = null;
        for (Node node : nodes) {
            Object score = node.getScore();
            if (score == null || node.hasScoreDistributions()) {
                return null;
            }
            if (result == null) {
                result = score;
                continue;
            }
            if (Objects.equals(score, result)) continue;
            return null;
        }
        return result;
    }
}

