package org.encog.ml.hmm.alog;

import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Iterator;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ForwardBackwardCalculator;

/* loaded from: classes.dex */
/**
 * Forward-backward calculator that normalizes every alpha row so it sums to
 * one and remembers the normalization constant of each time step.
 *
 * <p>The logarithm of the sequence probability is obtained as the sum of the
 * logarithms of those constants, so {@link #lnProbability()} stays usable even
 * when the plain probability ({@code Math.exp(lnProbability)}) underflows to
 * zero.
 */
public class ForwardBackwardScaledCalculator extends ForwardBackwardCalculator {
    /** Scaling factor of each time step: the sum of the alpha row before it was normalized. */
    private final double[] ctFactors;

    /** Sum of {@code Math.log(ctFactors[t])} over all time steps. */
    private double lnProbability;

    /**
     * Builds a calculator that only computes the alpha table.
     *
     * @param mLDataSet the observation sequence; must contain at least one element
     * @param hiddenMarkovModel the model the sequence is evaluated against
     * @throws IllegalArgumentException if the sequence is empty
     */
    public ForwardBackwardScaledCalculator(MLDataSet mLDataSet, HiddenMarkovModel hiddenMarkovModel) {
        this(mLDataSet, hiddenMarkovModel, EnumSet.of(ForwardBackwardCalculator.Computation.ALPHA));
    }

    /**
     * Builds a calculator. The alpha table is always computed; the beta table
     * is computed only when {@code enumSet} contains
     * {@link ForwardBackwardCalculator.Computation#BETA}.
     *
     * @param mLDataSet the observation sequence; must contain at least one element
     * @param hiddenMarkovModel the model the sequence is evaluated against
     * @param enumSet the requested computations
     * @throws IllegalArgumentException if the sequence is empty
     */
    public ForwardBackwardScaledCalculator(MLDataSet mLDataSet, HiddenMarkovModel hiddenMarkovModel, EnumSet<ForwardBackwardCalculator.Computation> enumSet) {
        if (mLDataSet.size() < 1) {
            throw new IllegalArgumentException("The observation sequence must contain at least one element");
        }
        // A freshly allocated double[] is already zero-filled, so no explicit fill is needed.
        // The array has to exist before computeAlpha runs because scale() writes into it.
        this.ctFactors = new double[mLDataSet.size()];

        // NOTE(review): computeAlpha/computeBeta are overridable and are invoked from the
        // constructor; a subclass override would run before that subclass is initialized.
        computeAlpha(hiddenMarkovModel, mLDataSet);
        if (enumSet.contains(ForwardBackwardCalculator.Computation.BETA)) {
            // Beta relies on the scaling factors produced by the alpha pass above.
            computeBeta(hiddenMarkovModel, mLDataSet);
        }
        computeProbability();
    }

    /**
     * Derives {@code lnProbability} as the sum of the logs of the scaling
     * factors and stores {@code exp(lnProbability)} in the inherited
     * {@code probability} field.
     */
    private void computeProbability() {
        double logSum = 0.0d;
        for (final double factor : this.ctFactors) {
            // Math.log(0) is -Infinity, so a zero factor yields lnProbability = -Infinity
            // and probability = 0.
            logSum += Math.log(factor);
        }
        this.lnProbability = logSum;
        this.probability = Math.exp(logSum);
    }

    /**
     * Normalizes row {@code t} of {@code table} so that it sums to one and
     * stores the row's original sum in {@code factors[t]}.
     *
     * <p>When the row sums to zero the sum is still recorded but the row is
     * left untouched: dividing would turn every entry into NaN (0/0), and that
     * NaN would then propagate into all later rows and scaling factors.
     *
     * @param factors destination for the scaling factor of step {@code t}
     * @param table the table whose row is normalized in place
     * @param t the time step (row index)
     */
    private void scale(double[] factors, double[][] table, int t) {
        final double[] row = table[t];
        double sum = 0.0d;
        for (final double value : row) {
            sum += value;
        }
        factors[t] = sum;
        if (sum == 0.0d) {
            return;
        }
        for (int state = 0; state < row.length; state++) {
            row[state] /= sum;
        }
    }

    /**
     * Fills the alpha table one time step at a time, normalizing each row
     * right after it has been computed so the next step works on scaled values.
     *
     * @param hiddenMarkovModel the model the sequence is evaluated against
     * @param mLDataSet the observation sequence
     */
    @Override // org.encog.ml.hmm.alog.ForwardBackwardCalculator
    protected void computeAlpha(HiddenMarkovModel hiddenMarkovModel, MLDataSet mLDataSet) {
        final int length = mLDataSet.size();
        final int stateCount = hiddenMarkovModel.getStateCount();
        this.alpha = new double[length][stateCount];

        final MLDataPair first = mLDataSet.get(0);
        for (int state = 0; state < stateCount; state++) {
            computeAlphaInit(hiddenMarkovModel, first, state);
        }
        scale(this.ctFactors, this.alpha, 0);

        // Skip the first observation: it was already consumed by the initialization above.
        final Iterator<MLDataPair> observations = mLDataSet.iterator();
        if (observations.hasNext()) {
            observations.next();
        }
        for (int t = 1; t < length; t++) {
            final MLDataPair observation = observations.next();
            for (int state = 0; state < stateCount; state++) {
                computeAlphaStep(hiddenMarkovModel, observation, t, state);
            }
            scale(this.ctFactors, this.alpha, t);
        }
    }

    /**
     * Fills the beta table backwards from the last time step, dividing each
     * row by the scaling factor that the alpha pass stored for the same step.
     * Must run after {@link #computeAlpha}, which populates those factors.
     *
     * <p>If a scaling factor is zero the divisions here follow ordinary
     * floating-point rules and produce Infinity or NaN entries.
     *
     * @param hiddenMarkovModel the model the sequence is evaluated against
     * @param mLDataSet the observation sequence
     */
    @Override // org.encog.ml.hmm.alog.ForwardBackwardCalculator
    protected void computeBeta(HiddenMarkovModel hiddenMarkovModel, MLDataSet mLDataSet) {
        final int length = mLDataSet.size();
        final int stateCount = hiddenMarkovModel.getStateCount();
        final int last = length - 1;
        this.beta = new double[length][stateCount];

        // Every state of the last row starts at 1, scaled by the last factor.
        Arrays.fill(this.beta[last], 1.0d / this.ctFactors[last]);

        for (int t = last - 1; t >= 0; t--) {
            final MLDataPair nextObservation = mLDataSet.get(t + 1);
            final double factor = this.ctFactors[t];
            final double[] row = this.beta[t];
            for (int state = 0; state < stateCount; state++) {
                computeBetaStep(hiddenMarkovModel, nextObservation, t, state);
                row[state] /= factor;
            }
        }
    }

    /**
     * Returns the natural logarithm of the probability of the observation
     * sequence, computed as the sum of the logs of the scaling factors.
     *
     * @return the log-probability; {@code -Infinity} when a scaling factor is zero
     */
    public double lnProbability() {
        return this.lnProbability;
    }
}
