/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.crf;

import java.util.Arrays;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;

/**
 * Static inference routines for linear-chain models: forward-backward belief propagation,
 * a constrained forward pass, Viterbi decoding, and a numerically stable log-sum-exp.
 * <p>
 * Scores are combined additively and marginalised with {@link #sumLogProbs(double[])}
 * (log-sum-exp), i.e. they are treated as log-space, unnormalised potentials. Throughout,
 * {@code transitionValues.get(prev, cur)} is the score of label {@code prev} at position
 * {@code i - 1} followed by label {@code cur} at position {@code i}, and
 * {@code localValues[i].get(label)} is the score of {@code label} at position {@code i}.
 * <p>
 * The number of labels is read from {@code transitionValues.getDimension1Size()}.
 * NOTE(review): the code assumes the transition matrix is square and that every local score
 * vector has exactly that many entries; confirm with callers.
 */
public final class ChainHelper {

    /**
     * Terms more than this far below the maximum (i.e. contributing a relative weight below
     * {@code exp(-30)}, roughly {@code 9.4e-14}) are skipped by {@code sumLogProbs}.
     */
    private static final double LOG_TOLERANCE = 30.0;

    private ChainHelper() {
    }

    /**
     * Runs forward-backward (sum-product) belief propagation over the chain.
     * <p>
     * {@code alphas[i].get(v)} is the log-sum of the scores of all label prefixes that end in
     * {@code v} at position {@code i}, including the local score at {@code i}.
     * {@code betas[i].get(v)} is the log-sum of the scores of all label suffixes following
     * label {@code v} at position {@code i}, excluding the local score at {@code i}. The final
     * beta vector is all zeros.
     *
     * @param scores the local and transition scores of the chain.
     * @return the log partition function (log-sum-exp of the final alpha) together with the
     *         alphas, betas and the input scores.
     * @throws IllegalArgumentException if the chain has no positions.
     */
    public static ChainBPResults beliefPropagation(ChainCliqueValues scores) {
        DenseMatrix markovScores = scores.transitionValues;
        DenseVector[] localScores = scores.localValues;
        int numLabels = markovScores.getDimension1Size();
        int length = localScores.length;
        if (length == 0) {
            throw new IllegalArgumentException("Cannot run belief propagation on an empty sequence");
        }

        DenseVector[] alphas = new DenseVector[length];
        DenseVector[] betas = new DenseVector[length];
        for (int i = 0; i < length; ++i) {
            alphas[i] = localScores[i].copy();
            betas[i] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY);
        }

        forwardPass(markovScores, alphas, null, numLabels);

        // Backward pass:
        // beta[i][cur] = logsumexp_next(trans[cur][next] + beta[i+1][next] + local[i+1][next])
        double[] tmpArray = new double[numLabels];
        betas[length - 1].fill(0.0);
        for (int i = length - 2; i >= 0; --i) {
            DenseVector curBeta = betas[i];
            DenseVector nextBeta = betas[i + 1];
            DenseVector nextLocalScore = localScores[i + 1];
            for (int cur = 0; cur < numLabels; ++cur) {
                for (int next = 0; next < numLabels; ++next) {
                    tmpArray[next] = markovScores.get(cur, next) + nextBeta.get(next) + nextLocalScore.get(next);
                }
                curBeta.set(cur, sumLogProbs(tmpArray));
            }
        }

        double logZ = sumLogProbs(alphas[length - 1]);
        return new ChainBPResults(logZ, alphas, betas, scores);
    }

    /**
     * Computes the log-sum of the scores of all label sequences that satisfy the supplied
     * per-position constraints.
     * <p>
     * {@code constraints[i] == -1} leaves position {@code i} unconstrained. Any other value
     * restricts position {@code i} to that label index. A value that matches no label index
     * excludes every label at that position, so the result is negative infinity.
     *
     * @param scores      the local and transition scores of the chain.
     * @param constraints one entry per position: {@code -1} or the required label index.
     * @return the constrained log partition value.
     * @throws IllegalArgumentException if {@code constraints} and the chain differ in length,
     *                                  or the chain has no positions.
     */
    public static double constrainedBeliefPropagation(ChainCliqueValues scores, int[] constraints) {
        DenseMatrix markovScores = scores.transitionValues;
        DenseVector[] localScores = scores.localValues;
        int numLabels = markovScores.getDimension1Size();
        int length = localScores.length;
        if (length != constraints.length) {
            throw new IllegalArgumentException("Must have the same number of constraints as tokens");
        }
        if (length == 0) {
            throw new IllegalArgumentException("Cannot run belief propagation on an empty sequence");
        }

        DenseVector[] alphas = new DenseVector[length];
        for (int i = 0; i < length; ++i) {
            alphas[i] = localScores[i].copy();
        }

        // Position 0 has no incoming transition, so the forward recursion never visits it;
        // its constraint has to be applied to the initial alpha directly.
        int firstConstraint = constraints[0];
        if (firstConstraint != -1) {
            DenseVector firstAlpha = alphas[0];
            for (int label = 0; label < numLabels; ++label) {
                if (label != firstConstraint) {
                    firstAlpha.set(label, Double.NEGATIVE_INFINITY);
                }
            }
        }

        forwardPass(markovScores, alphas, constraints, numLabels);
        return sumLogProbs(alphas[length - 1]);
    }

    /**
     * Runs the log-space forward recursion in place over {@code alphas}, which must already
     * hold copies of the local scores:
     * {@code alpha[i][cur] = local[i][cur] + logsumexp_prev(trans[prev][cur] + alpha[i-1][prev])}.
     * When {@code constraints} is non-null, each position {@code i >= 1} whose constraint is
     * not {@code -1} keeps only the constrained label; all other labels become negative infinity.
     */
    private static void forwardPass(DenseMatrix markovScores, DenseVector[] alphas, int[] constraints, int numLabels) {
        double[] tmpArray = new double[numLabels];
        for (int i = 1; i < alphas.length; ++i) {
            DenseVector curAlpha = alphas[i];
            DenseVector prevAlpha = alphas[i - 1];
            int constraint = (constraints == null) ? -1 : constraints[i];
            for (int cur = 0; cur < numLabels; ++cur) {
                if (constraint != -1 && constraint != cur) {
                    curAlpha.set(cur, Double.NEGATIVE_INFINITY);
                    continue;
                }
                for (int prev = 0; prev < numLabels; ++prev) {
                    tmpArray[prev] = markovScores.get(prev, cur) + prevAlpha.get(prev);
                }
                curAlpha.add(cur, sumLogProbs(tmpArray));
            }
        }
    }

    /**
     * Finds the highest scoring label sequence with the Viterbi algorithm (max-sum).
     * <p>
     * If no predecessor of a label beats negative infinity, its back pointer defaults to label
     * 0, so the back-trace always follows a valid index.
     *
     * @param scores the local and transition scores of the chain.
     * @return the best score, the label index chosen at each position, and the input scores.
     * @throws IllegalArgumentException if the chain has no positions.
     */
    public static ChainViterbiResults viterbi(ChainCliqueValues scores) {
        DenseMatrix markovScores = scores.transitionValues;
        DenseVector[] localScores = scores.localValues;
        int numLabels = markovScores.getDimension1Size();
        int length = localScores.length;
        if (length == 0) {
            throw new IllegalArgumentException("Cannot run Viterbi on an empty sequence");
        }

        DenseVector[] costs = new DenseVector[length];
        int[][] backPointers = new int[length][];
        for (int i = 0; i < length; ++i) {
            costs[i] = new DenseVector(numLabels, Double.NEGATIVE_INFINITY);
            backPointers[i] = new int[numLabels];
            Arrays.fill(backPointers[i], -1);
        }
        costs[0].setElements(localScores[0]);

        for (int i = 1; i < length; ++i) {
            DenseVector curLocalScores = localScores[i];
            DenseVector curCost = costs[i];
            int[] curBackPointers = backPointers[i];
            DenseVector prevCost = costs[i - 1];
            for (int cur = 0; cur < numLabels; ++cur) {
                double maxScore = Double.NEGATIVE_INFINITY;
                int maxIndex = 0;
                double curLocalScore = curLocalScores.get(cur);
                for (int prev = 0; prev < numLabels; ++prev) {
                    double curScore = markovScores.get(prev, cur) + prevCost.get(prev) + curLocalScore;
                    if (curScore > maxScore) {
                        maxScore = curScore;
                        maxIndex = prev;
                    }
                }
                curCost.set(cur, maxScore);
                curBackPointers[cur] = maxIndex;
            }
        }

        // Back-trace from the best final label.
        DenseVector finalCost = costs[length - 1];
        int[] mapValues = new int[length];
        mapValues[length - 1] = finalCost.indexOfMax();
        for (int i = length - 2; i >= 0; --i) {
            mapValues[i] = backPointers[i + 1][mapValues[i + 1]];
        }
        return new ChainViterbiResults(finalCost.maxValue(), mapValues, scores);
    }

    /**
     * Computes {@code log(sum_i exp(input[i]))} stably by factoring out the maximum element.
     * <p>
     * Terms more than {@code LOG_TOLERANCE} below the maximum, and infinite terms other than
     * the maximum, are skipped. An empty input, or one whose maximum is negative infinity,
     * yields negative infinity.
     * NOTE(review): a NaN at index 0 makes the result NaN, while a NaN at any other index is
     * silently ignored.
     *
     * @param input the log-space values to combine.
     * @return the log of the sum of the exponentiated values.
     */
    public static double sumLogProbs(DenseVector input) {
        int size = input.size();
        if (size == 0) {
            return Double.NEGATIVE_INFINITY;
        }
        double maxValue = input.get(0);
        int maxIdx = 0;
        for (int i = 1; i < size; ++i) {
            double value = input.get(i);
            if (value > maxValue) {
                maxValue = value;
                maxIdx = i;
            }
        }
        if (maxValue == Double.NEGATIVE_INFINITY) {
            return maxValue;
        }
        double cutoff = maxValue - LOG_TOLERANCE;
        boolean anyAdded = false;
        double intermediate = 0.0;
        for (int i = 0; i < size; ++i) {
            double value = input.get(i);
            if (i != maxIdx && value >= cutoff && !Double.isInfinite(value)) {
                anyAdded = true;
                intermediate += Math.exp(value - maxValue);
            }
        }
        // log(exp(max) * (1 + sum)) == max + log1p(sum); log1p keeps precision for tiny sums.
        return anyAdded ? maxValue + Math.log1p(intermediate) : maxValue;
    }

    /**
     * Computes {@code log(sum_i exp(input[i]))} stably by factoring out the maximum element.
     * <p>
     * Terms more than {@code LOG_TOLERANCE} below the maximum, and infinite terms other than
     * the maximum, are skipped. An empty input, or one whose maximum is negative infinity,
     * yields negative infinity.
     * NOTE(review): a NaN at index 0 makes the result NaN, while a NaN at any other index is
     * silently ignored.
     *
     * @param input the log-space values to combine.
     * @return the log of the sum of the exponentiated values.
     */
    public static double sumLogProbs(double[] input) {
        if (input.length == 0) {
            return Double.NEGATIVE_INFINITY;
        }
        double maxValue = input[0];
        int maxIdx = 0;
        for (int i = 1; i < input.length; ++i) {
            double value = input[i];
            if (value > maxValue) {
                maxValue = value;
                maxIdx = i;
            }
        }
        if (maxValue == Double.NEGATIVE_INFINITY) {
            return maxValue;
        }
        double cutoff = maxValue - LOG_TOLERANCE;
        boolean anyAdded = false;
        double intermediate = 0.0;
        for (int i = 0; i < input.length; ++i) {
            double value = input[i];
            if (i != maxIdx && value >= cutoff && !Double.isInfinite(value)) {
                anyAdded = true;
                intermediate += Math.exp(value - maxValue);
            }
        }
        // log(exp(max) * (1 + sum)) == max + log1p(sum); log1p keeps precision for tiny sums.
        return anyAdded ? maxValue + Math.log1p(intermediate) : maxValue;
    }

    /**
     * Output of {@link ChainHelper#viterbi(ChainCliqueValues)}.
     */
    public static final class ChainViterbiResults {
        /** The maximum value in the final Viterbi cost vector (score of the best sequence). */
        public final double mapScore;
        /** The label index chosen at each position of the best sequence. */
        public final int[] mapValues;
        /** The scores the decoding was run on. */
        public final ChainCliqueValues scores;

        ChainViterbiResults(double mapScore, int[] mapValues, ChainCliqueValues scores) {
            this.mapScore = mapScore;
            this.mapValues = mapValues;
            this.scores = scores;
        }
    }

    /**
     * The per-position local scores and the label-to-label transition scores of a chain.
     */
    public static final class ChainCliqueValues {
        /** One score vector per position, indexed by label. */
        public final DenseVector[] localValues;
        /** Transition scores; {@code get(prev, cur)} scores label {@code prev} followed by {@code cur}. */
        public final DenseMatrix transitionValues;

        ChainCliqueValues(DenseVector[] localValues, DenseMatrix transitionValues) {
            this.localValues = localValues;
            this.transitionValues = transitionValues;
        }
    }

    /**
     * Output of {@link ChainHelper#beliefPropagation(ChainCliqueValues)}.
     */
    public static final class ChainBPResults {
        /** The log partition function: log-sum-exp of the final alpha vector. */
        public final double logZ;
        /** Forward messages, one vector per position (includes that position's local score). */
        public final DenseVector[] alphas;
        /** Backward messages, one vector per position (excludes that position's local score). */
        public final DenseVector[] betas;
        /** The scores belief propagation was run on. */
        public final ChainCliqueValues scores;

        ChainBPResults(double logZ, DenseVector[] alphas, DenseVector[] betas, ChainCliqueValues scores) {
            this.logZ = logZ;
            this.alphas = alphas;
            this.betas = betas;
            this.scores = scores;
        }
    }
}

