/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.common.tree;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.DecisionTreeTrainer;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.util.Util;

/**
 * Base class for trainers that grow a single CART-style decision tree.
 * <p>
 * Subclasses supply the concrete root training node via {@link #mkTrainingNode}. This class owns:
 * <ul>
 *     <li>the tree hyperparameters;</li>
 *     <li>the RNG used for per-split feature sub-sampling and random split points;</li>
 *     <li>the invocation counter that makes repeated training runs reproducible.</li>
 * </ul>
 * <p>
 * Within this class, {@link #rng} and {@link #trainInvocationCounter} are only touched while
 * holding this trainer's monitor. Tree growth in {@link #train(Dataset, Map, int)} runs outside
 * the lock on a per-call RNG obtained from {@link SplittableRandom#split()}.
 *
 * @param <T> the output type
 */
public abstract class AbstractCARTTrainer<T extends Output<T>>
implements DecisionTreeTrainer<T> {
    /**
     * Not referenced inside this class.
     * NOTE(review): presumably used by subclasses or callers — confirm.
     */
    public static final int MIN_EXAMPLES = 5;

    /** Nodes whose weight sum is below this value are not split further. */
    @Config(description="The minimum weight allowed in a child node.")
    protected float minChildWeight = 5.0f;

    /** Nodes at or beyond this depth are not split further. */
    @Config(description="The maximum depth of the tree.")
    protected int maxDepth = Integer.MAX_VALUE;

    /** Impurity decrease threshold, scaled by the total example weight at training time. */
    @Config(description="The decrease in impurity needed in order to split the node.")
    protected float minImpurityDecrease = 0.0f;

    /** Fraction of the feature space offered to each split, in (0, 1]. */
    @Config(description="The fraction of features to consider in each split. 1.0f indicates all features are considered.")
    protected float fractionFeaturesInSplit = 1.0f;

    /** Forwarded to {@code AbstractTrainingNode.buildTree} on every split. */
    @Config(description="Whether to choose split points for features at random.")
    protected boolean useRandomSplitPoints = false;

    /** Seed used to (re)create {@link #rng}. */
    @Config(description="The RNG seed to use when sampling features in a split.")
    protected long seed = 12345L;

    /**
     * Parent RNG. It is split once per training run.
     * It is only assigned by {@link #postConfig()} and {@link #setInvocationCount(int)}.
     */
    protected SplittableRandom rng;

    /** Number of training runs performed (or replayed) since {@link #rng} was last seeded. */
    protected int trainInvocationCounter;

    /**
     * Stores the supplied hyperparameters.
     * <p>
     * This constructor neither validates its arguments nor initialises {@link #rng}.
     * NOTE(review): presumably subclasses call {@link #postConfig()} afterwards — confirm.
     *
     * @param maxDepth maximum tree depth
     * @param minChildWeight minimum node weight required to split
     * @param minImpurityDecrease impurity decrease threshold (before weight scaling)
     * @param fractionFeaturesInSplit fraction of features considered per split
     * @param useRandomSplitPoints whether split points are chosen at random
     * @param seed RNG seed
     */
    protected AbstractCARTTrainer(int maxDepth, float minChildWeight, float minImpurityDecrease, float fractionFeaturesInSplit, boolean useRandomSplitPoints, long seed) {
        this.maxDepth = maxDepth;
        this.fractionFeaturesInSplit = fractionFeaturesInSplit;
        this.useRandomSplitPoints = useRandomSplitPoints;
        this.minChildWeight = minChildWeight;
        this.minImpurityDecrease = minImpurityDecrease;
        this.seed = seed;
    }

    /**
     * Seeds {@link #rng} from {@link #seed}, then validates the hyperparameters.
     *
     * @throws IllegalArgumentException if any of the following holds:
     *     {@code fractionFeaturesInSplit} is outside (0, 1];
     *     {@code minImpurityDecrease} is negative;
     *     {@code maxDepth} is below 1;
     *     {@code minChildWeight} is not positive.
     */
    @Override
    public synchronized void postConfig() {
        // The RNG is seeded before validation, as in the original ordering.
        this.rng = new SplittableRandom(this.seed);
        if (this.fractionFeaturesInSplit <= 0.0f || this.fractionFeaturesInSplit > 1.0f) {
            throw new IllegalArgumentException("fractionFeaturesInSplit must be greater than 0 and less than or equal to 1");
        }
        if (this.minImpurityDecrease < 0.0f) {
            throw new IllegalArgumentException("minImpurityDecrease must be greater than or equal to 0");
        }
        if (this.maxDepth < 1) {
            throw new IllegalArgumentException("maxDepth must be greater than or equal to 1");
        }
        if (this.minChildWeight <= 0.0f) {
            throw new IllegalArgumentException("minChildWeight must be greater than 0");
        }
    }

    /**
     * Returns the number of training runs recorded so far.
     * <p>
     * Synchronized so the read observes writes made under the monitor by
     * {@link #train(Dataset, Map, int)} and {@link #setInvocationCount(int)}.
     *
     * @return the invocation count
     */
    @Override
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Rewinds the RNG to the state it would have after {@code invocationCount} training runs.
     * <p>
     * It reseeds from {@link #seed} and performs one {@code split()} per run, mirroring the
     * single split that {@link #train(Dataset, Map, int)} makes per invocation.
     *
     * @param invocationCount the number of runs to replay; must be non-negative
     * @throws IllegalArgumentException if {@code invocationCount} is negative
     */
    @Override
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        this.trainInvocationCounter = 0;
        while (this.trainInvocationCounter < invocationCount) {
            // Only the state advance of the parent RNG matters; the child RNG is discarded.
            this.rng.split();
            this.trainInvocationCounter++;
        }
    }

    @Override
    public float getFractionFeaturesInSplit() {
        return this.fractionFeaturesInSplit;
    }

    @Override
    public boolean getUseRandomSplitPoints() {
        return this.useRandomSplitPoints;
    }

    @Override
    public float getMinImpurityDecrease() {
        return this.minImpurityDecrease;
    }

    /**
     * Trains a tree with empty run provenance, continuing from the current invocation count.
     *
     * @param examples the training data
     * @return the trained tree
     */
    @Override
    public TreeModel<T> train(Dataset<T> examples) {
        return this.train(examples, Collections.emptyMap());
    }

    /**
     * Trains a tree, continuing from the current invocation count.
     *
     * @param examples the training data
     * @param runProvenance extra provenance recorded in the model
     * @return the trained tree
     */
    @Override
    public TreeModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance) {
        // -1 is the sentinel meaning "do not rewind the RNG; use the current counter".
        return this.train(examples, runProvenance, -1);
    }

    /**
     * Grows a decision tree on {@code examples}.
     * <p>
     * Nodes are expanded depth-first. A node is split only when all of these hold:
     * <ul>
     *     <li>its impurity is positive;</li>
     *     <li>its depth is below {@link #maxDepth};</li>
     *     <li>its weight sum is at least {@link #minChildWeight}.</li>
     * </ul>
     * When {@link #fractionFeaturesInSplit} is below 1, each split is offered a fresh random
     * subset of the features. The subset always contains at least one feature when the feature
     * map is non-empty.
     *
     * @param examples the training data
     * @param runProvenance extra provenance recorded in the model
     * @param invocationCount the invocation count to rewind to before training, or -1 to
     *     continue from the current count
     * @return the trained tree
     * @throws IllegalArgumentException if the dataset contains unknown outputs, or if
     *     {@code invocationCount} is negative and not -1
     */
    @Override
    public TreeModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Under the lock: optionally rewind, take a private RNG, capture provenance, and bump the
        // counter. Provenance is captured before the increment.
        // NOTE(review): it presumably records the count for this run — confirm.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != -1) {
                this.setInvocationCount(invocationCount);
            }
            localRNG = this.rng.split();
            trainerProvenance = (TrainerProvenance) this.getProvenance();
            this.trainInvocationCounter++;
        }

        ImmutableFeatureMap featureIDMap = examples.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = examples.getOutputIDInfo();
        int numFeatures = featureIDMap.size();

        // Clamp to [1, numFeatures] (or 0 when there are no features). Rounding a small fraction
        // of a small feature space to zero would leave every split with no candidate features.
        int numFeaturesInSplit = Math.min(Math.max(1, Math.round(this.fractionFeaturesInSplit * (float) numFeatures)), numFeatures);
        int[] originalIndices = new int[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            originalIndices[i] = i;
        }
        boolean subsampleFeatures = numFeaturesInSplit != numFeatures;
        int[] indices = subsampleFeatures ? new int[numFeaturesInSplit] : originalIndices;

        // The impurity threshold is scaled by the total example weight.
        float weightSum = 0.0f;
        for (Example<T> e : examples) {
            weightSum += e.getWeight();
        }
        float scaledMinImpurityDecrease = this.getMinImpurityDecrease() * weightSum;
        AbstractTrainingNode.LeafDeterminer leafDeterminer =
                new AbstractTrainingNode.LeafDeterminer(this.maxDepth, this.minChildWeight, scaledMinImpurityDecrease);

        AbstractTrainingNode<T> root = this.mkTrainingNode(examples, leafDeterminer);
        // poll() removes from the head and addFirst() pushes onto the head. The deque therefore
        // acts as a stack, and the tree is grown depth-first.
        Deque<AbstractTrainingNode<T>> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            AbstractTrainingNode<T> node = queue.poll();
            boolean splittable = node.getImpurity() > 0.0
                    && node.getDepth() < this.maxDepth
                    && node.getWeightSum() >= this.minChildWeight;
            if (!splittable) {
                continue;
            }
            if (subsampleFeatures) {
                // Draw a fresh random feature subset for every split.
                Util.randpermInPlace(originalIndices, localRNG);
                System.arraycopy(originalIndices, 0, indices, 0, numFeaturesInSplit);
            }
            List<AbstractTrainingNode<T>> children = node.buildTree(indices, localRNG, this.getUseRandomSplitPoints());
            for (AbstractTrainingNode<T> child : children) {
                queue.addFirst(child);
            }
        }

        ModelProvenance provenance = new ModelProvenance(TreeModel.class.getName(), OffsetDateTime.now(), (DatasetProvenance) examples.getProvenance(), trainerProvenance, runProvenance);
        return new TreeModel<>("cart-tree", provenance, featureIDMap, outputIDInfo, false, root.convertTree());
    }

    /**
     * Creates the root training node for {@code var1}.
     *
     * @param var1 the training data
     * @param var2 the stopping criteria for node expansion
     * @return the root node from which the tree is grown
     */
    protected abstract AbstractTrainingNode<T> mkTrainingNode(Dataset<T> var1, AbstractTrainingNode.LeafDeterminer var2);

    /**
     * Provenance type for CART trainers.
     *
     * @deprecated Retained as in the original.
     *     NOTE(review): no replacement is named in this file — confirm the intended successor.
     */
    @Deprecated
    protected static abstract class AbstractCARTTrainerProvenance
    extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Builds the provenance from a trainer instance.
         *
         * @param host the trainer to describe
         * @param <T> the trainer's output type
         */
        protected <T extends Output<T>> AbstractCARTTrainerProvenance(AbstractCARTTrainer<T> host) {
            super(host);
        }

        /**
         * Builds the provenance from a provenance map.
         *
         * @param map the provenance entries
         */
        protected AbstractCARTTrainerProvenance(Map<String, Provenance> map) {
            super(map);
        }
    }
}

