package org.encog.neural.networks.training.strategy;

import org.encog.EncogError;
import org.encog.ml.MLEncodable;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.util.EngineArray;

/* loaded from: classes.dex */
/**
 * Training strategy that shrinks the trained method's weights after every iteration.
 *
 * <p>Before each iteration the current weights are captured. After the iteration the
 * freshly trained weights are reduced by {@code lambda} times the captured (pre-iteration)
 * weights and written back to the method.
 *
 * <p>The trained method must implement {@link MLEncodable}, so its weights can be
 * flattened to and restored from a {@code double[]}.
 */
public class RegularizationStrategy implements Strategy {
    /** The trained method, viewed as an encodable weight vector; set by {@link #init}. */
    private MLEncodable encodable;

    /** Scale factor applied to the pre-iteration weights when shrinking. */
    private double lambda;

    /** Scratch buffer holding the weights produced by the most recent iteration. */
    private double[] newWeights;

    /** The trainer this strategy is attached to; set by {@link #init}. */
    private MLTrain train;

    /** Weights captured in {@link #preIteration} (and refreshed at the end of {@link #postIteration}). */
    private double[] weights;

    /**
     * Creates the strategy.
     *
     * @param lambda the factor by which pre-iteration weights are subtracted after each iteration
     */
    public RegularizationStrategy(double lambda) {
        this.lambda = lambda;
    }

    /**
     * Binds this strategy to a trainer and allocates the weight buffers.
     *
     * @param mLTrain the trainer whose method will be regularized
     * @throws EncogError if the trainer's method does not implement {@link MLEncodable}
     */
    @Override // org.encog.ml.train.strategy.Strategy
    public void init(MLTrain mLTrain) {
        this.train = mLTrain;
        final boolean supportsEncoding = mLTrain.getMethod() instanceof MLEncodable;
        if (!supportsEncoding) {
            throw new EncogError("Method must implement MLEncodable to be used with regularization.");
        }
        this.encodable = (MLEncodable) mLTrain.getMethod();
        final int length = this.encodable.encodedArrayLength();
        this.weights = new double[length];
        this.newWeights = new double[length];
    }

    /**
     * Applies the shrink step: {@code new[i] -= lambda * old[i]}, where {@code old} holds the
     * weights captured before the iteration. The adjusted weights are then pushed back into
     * the method and also copied into the captured-weights buffer.
     */
    @Override // org.encog.ml.train.strategy.Strategy
    public void postIteration() {
        final double[] updated = this.newWeights;
        this.encodable.encodeToArray(updated);
        for (int idx = 0; idx < updated.length; idx++) {
            updated[idx] -= this.lambda * this.weights[idx];
        }
        this.encodable.decodeFromArray(updated);
        // Keep the captured buffer in sync with what was just written to the method.
        EngineArray.arrayCopy(updated, this.weights);
    }

    /** Captures the method's current weights so they can be used by {@link #postIteration}. */
    @Override // org.encog.ml.train.strategy.Strategy
    public void preIteration() {
        final MLEncodable current = (MLEncodable) this.train.getMethod();
        current.encodeToArray(this.weights);
    }
}
