package org.encog.ml.model;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.encog.EncogError;
import org.encog.NullStatusReportable;
import org.encog.StatusReportable;
import org.encog.mathutil.randomize.generate.MersenneTwisterGenerateRandom;
import org.encog.ml.MLClassification;
import org.encog.ml.MLMethod;
import org.encog.ml.MLRegression;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.cross.DataFold;
import org.encog.ml.data.cross.KFoldCrossvalidation;
import org.encog.ml.data.versatile.MatrixMLDataSet;
import org.encog.ml.data.versatile.VersatileMLDataSet;
import org.encog.ml.data.versatile.columns.ColumnDefinition;
import org.encog.ml.data.versatile.columns.ColumnType;
import org.encog.ml.data.versatile.division.DataDivision;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.factory.MLTrainFactory;
import org.encog.ml.model.config.FeedforwardConfig;
import org.encog.ml.model.config.MethodConfig;
import org.encog.ml.model.config.NEATConfig;
import org.encog.ml.model.config.PNNConfig;
import org.encog.ml.model.config.RBFNetworkConfig;
import org.encog.ml.model.config.SVMConfig;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.end.SimpleEarlyStoppingStrategy;
import org.encog.util.Format;
import org.encog.util.simple.EncogUtility;

/* loaded from: classes.dex */
/**
 * High-level wrapper that configures, trains and cross-validates a machine-learning
 * method on top of a {@link VersatileMLDataSet}.
 *
 * <p>The methods below enforce this call order:
 * <ol>
 *   <li>{@link #selectMethod} chooses the method type and its architecture. It also
 *       installs a normalization strategy on the dataset's normalization helper.</li>
 *   <li>{@link #selectTrainingType} or {@link #selectTraining} chooses the trainer.</li>
 *   <li>{@link #holdBackValidation} or {@link #setTrainingDataset} supplies training data.</li>
 *   <li>{@link #crossvalidate} trains one method per fold and returns the best one.</li>
 * </ol>
 *
 * <p>Progress messages go to the {@link StatusReportable} configured with
 * {@link #setReport}. Initially this is a {@link NullStatusReportable}.
 */
public class EncogModel {
    /** Configuration of the currently selected method type; set by {@code selectMethod}. */
    private MethodConfig config;
    /** The full dataset this model is built on. */
    private final VersatileMLDataSet dataset;
    /** Architecture string passed to {@link MLMethodFactory}. */
    private String methodArgs;
    /** Method type key (e.g. {@link MLMethodFactory#TYPE_FEEDFORWARD}). */
    private String methodType;
    /** Argument string passed to {@link MLTrainFactory}. */
    private String trainingArgs;
    /** Training portion produced by {@link #holdBackValidation} or set explicitly. */
    private MatrixMLDataSet trainingDataset;
    /** Training type key passed to {@link MLTrainFactory}. */
    private String trainingType;
    /** Validation portion produced by {@link #holdBackValidation} or set explicitly. */
    private MatrixMLDataSet validationDataset;
    private final List<ColumnDefinition> inputFeatures = new ArrayList<>();
    private final List<ColumnDefinition> predictedFeatures = new ArrayList<>();
    /** Known method types and the configuration used to auto-configure each one. */
    private final Map<String, MethodConfig> methodConfigurations = new HashMap<>();
    private StatusReportable report = new NullStatusReportable();

    /**
     * Creates a model over {@code versatileMLDataSet} and registers the built-in
     * configurations for feedforward, SVM, RBF network, NEAT and PNN methods.
     *
     * @param versatileMLDataSet the dataset to model
     */
    public EncogModel(VersatileMLDataSet versatileMLDataSet) {
        this.dataset = versatileMLDataSet;
        this.methodConfigurations.put(MLMethodFactory.TYPE_FEEDFORWARD, new FeedforwardConfig());
        this.methodConfigurations.put(MLMethodFactory.TYPE_SVM, new SVMConfig());
        this.methodConfigurations.put(MLMethodFactory.TYPE_RBFNETWORK, new RBFNetworkConfig());
        this.methodConfigurations.put(MLMethodFactory.TYPE_NEAT, new NEATConfig());
        this.methodConfigurations.put("pnn", new PNNConfig());
    }

    /**
     * Creates a trainer for {@code method} from the previously selected training type and args.
     *
     * @param method the method to train
     * @param trainingData the data to train on
     * @return a new trainer
     * @throws EncogError if no training type has been selected yet
     */
    private MLTrain createTrainer(MLMethod method, MLDataSet trainingData) {
        if (this.trainingType == null) {
            throw new EncogError("Please call selectTraining first to choose how to train.");
        }
        return new MLTrainFactory().create(method, trainingData, this.trainingType, this.trainingArgs);
    }

    /**
     * Trains a fresh method on one fold and records its validation score and the method.
     *
     * @param k total number of folds, used for progress reporting
     * @param foldNum 1-based index of this fold
     * @param fold the fold to train on; receives the score and the trained method
     * @throws EncogError if the trainer is neither iterative nor one-pass
     */
    private void fitFold(int k, int foldNum, DataFold fold) {
        MLMethod method = createMethod();
        MLTrain train = createTrainer(method, fold.getTraining());
        TrainingImplementationType implType = train.getImplementationType();

        if (implType == TrainingImplementationType.Iterative) {
            // Iterate until the trainer reports done. The early-stopping strategy
            // attached here monitors this fold's validation set.
            SimpleEarlyStoppingStrategy earlyStop = new SimpleEarlyStoppingStrategy(fold.getValidation());
            train.addStrategy(earlyStop);
            StringBuilder line = new StringBuilder();
            while (!train.isTrainingDone()) {
                train.iteration();
                line.setLength(0);
                line.append("Fold #")
                        .append(foldNum)
                        .append("/")
                        .append(k)
                        .append(": Iteration #")
                        .append(train.getIteration())
                        .append(", Training Error: ")
                        .append(Format.formatDouble(train.getError(), 8))
                        .append(", Validation Error: ")
                        .append(Format.formatDouble(earlyStop.getValidationError(), 8));
                this.report.report(k, foldNum, line.toString());
            }
            fold.setScore(earlyStop.getValidationError());
            fold.setMethod(method);
        } else if (implType == TrainingImplementationType.OnePass) {
            train.iteration();
            double validationError = calculateError(method, fold.getValidation());
            // Report progress as this fold (the iterative branch does the same),
            // not as the final fold.
            this.report.report(k, foldNum,
                    "Trained, Training Error: " + train.getError() + ", Validation Error: " + validationError);
            fold.setScore(validationError);
            fold.setMethod(method);
        } else {
            throw new EncogError("Unsupported training type for EncogModel: " + implType);
        }
    }

    /**
     * Calculates the error of {@code method} on {@code data}.
     *
     * <p>When the dataset has exactly one output column and it is nominal, this is the
     * classification error and the method must implement {@link MLClassification}.
     * Otherwise it is the regression error and the method must implement
     * {@link MLRegression}. A method of the wrong kind causes a {@link ClassCastException}.
     *
     * @param method the trained method
     * @param data the data to evaluate against
     * @return the error value
     */
    public double calculateError(MLMethod method, MLDataSet data) {
        List<ColumnDefinition> outputs = this.dataset.getNormHelper().getOutputColumns();
        boolean singleNominalOutput = outputs.size() == 1
                && outputs.get(0).getDataType() == ColumnType.nominal;
        if (singleNominalOutput) {
            return EncogUtility.calculateClassificationError((MLClassification) method, data);
        }
        return EncogUtility.calculateRegressionError((MLRegression) method, data);
    }

    /**
     * Creates a new, untrained method of the selected type.
     *
     * <p>The input count comes from the normalization helper. The output count comes
     * from the selected method configuration.
     *
     * @return a new method instance
     * @throws EncogError if no method has been selected yet
     */
    public MLMethod createMethod() {
        if (this.methodType == null) {
            throw new EncogError("Please call selectMethod first to choose what type of method you wish to use.");
        }
        int inputCount = this.dataset.getNormHelper().calculateNormalizedInputCount();
        int outputCount = this.config.determineOutputCount(this.dataset);
        return new MLMethodFactory().create(this.methodType, this.methodArgs, inputCount, outputCount);
    }

    /**
     * Performs k-fold cross-validation on the training dataset.
     *
     * <p>Each fold trains a new method. The cross-validated score (the average fold
     * score) is reported, and the method with the lowest fold score is returned.
     *
     * @param k number of folds
     * @param shuffle whether the folds are shuffled before processing
     * @return the best fold's method. This may be {@code null} if no fold score is
     *     below positive infinity (e.g. every score is NaN).
     * @throws EncogError if no training dataset is available
     */
    public MLMethod crossvalidate(int k, boolean shuffle) {
        if (this.trainingDataset == null) {
            throw new EncogError("Please call holdBackValidation (or setTrainingDataset) before crossvalidate.");
        }
        KFoldCrossvalidation cross = new KFoldCrossvalidation(this.trainingDataset, k);
        cross.process(shuffle);

        int foldNum = 0;
        for (DataFold fold : cross.getFolds()) {
            foldNum++;
            this.report.report(k, foldNum, "Fold #" + foldNum);
            fitFold(k, foldNum, fold);
        }

        double scoreSum = 0.0d;
        double bestScore = Double.POSITIVE_INFINITY;
        MLMethod best = null;
        for (DataFold fold : cross.getFolds()) {
            double score = fold.getScore();
            scoreSum += score;
            if (score < bestScore) {
                bestScore = score;
                best = fold.getMethod();
            }
        }
        double averageScore = scoreSum / cross.getFolds().size();
        this.report.report(k, k, "Cross-validated score:" + averageScore);
        return best;
    }

    /** @return the dataset this model was created with */
    public VersatileMLDataSet getDataset() {
        return this.dataset;
    }

    /** @return the live (mutable) list of input feature columns */
    public List<ColumnDefinition> getInputFeatures() {
        return this.inputFeatures;
    }

    /** @return the live (mutable) map of method type to method configuration */
    public Map<String, MethodConfig> getMethodConfigurations() {
        return this.methodConfigurations;
    }

    /** @return the live (mutable) list of predicted feature columns */
    public List<ColumnDefinition> getPredictedFeatures() {
        return this.predictedFeatures;
    }

    /** @return the status reporter receiving progress messages */
    public StatusReportable getReport() {
        return this.report;
    }

    /** @return the training dataset, or {@code null} if none has been set or split */
    public MatrixMLDataSet getTrainingDataset() {
        return this.trainingDataset;
    }

    /** @return the validation dataset, or {@code null} if none has been set or split */
    public MatrixMLDataSet getValidationDataset() {
        return this.validationDataset;
    }

    /**
     * Splits the dataset into a training part and a held-back validation part.
     *
     * @param validationPercent fraction of the data, in [0, 1], held back for validation
     * @param shuffle whether to shuffle before dividing
     * @param seed seed for the random generator used when dividing
     * @throws EncogError if {@code validationPercent} is outside [0, 1] or NaN
     */
    public void holdBackValidation(double validationPercent, boolean shuffle, int seed) {
        // The negated form also rejects NaN.
        if (!(validationPercent >= 0.0d && validationPercent <= 1.0d)) {
            throw new EncogError("Validation percent must be between 0 and 1, got: " + validationPercent);
        }
        List<DataDivision> divisions = new ArrayList<>();
        divisions.add(new DataDivision(1.0d - validationPercent));
        divisions.add(new DataDivision(validationPercent));
        this.dataset.divide(divisions, shuffle, new MersenneTwisterGenerateRandom(seed));
        this.trainingDataset = divisions.get(0).getDataset();
        this.validationDataset = divisions.get(1).getDataset();
    }

    /**
     * Selects a method type and lets its configuration suggest the architecture and the
     * normalization strategy.
     *
     * @param versatileMLDataSet dataset used for the suggestions; its normalization
     *     strategy is replaced
     * @param methodType a key of {@link #getMethodConfigurations()}
     * @throws EncogError if {@code methodType} is unknown
     */
    public void selectMethod(VersatileMLDataSet versatileMLDataSet, String methodType) {
        if (!this.methodConfigurations.containsKey(methodType)) {
            throw new EncogError("Don't know how to autoconfig method: " + methodType);
        }
        MethodConfig methodConfig = this.methodConfigurations.get(methodType);
        this.config = methodConfig;
        this.methodType = methodType;
        this.methodArgs = methodConfig.suggestModelArchitecture(versatileMLDataSet);
        versatileMLDataSet.getNormHelper()
                .setStrategy(methodConfig.suggestNormalizationStrategy(versatileMLDataSet, this.methodArgs));
    }

    /**
     * Selects a method type with an explicit architecture and, optionally, an explicit
     * training type.
     *
     * @param versatileMLDataSet dataset whose normalization strategy is replaced
     * @param methodType a key of {@link #getMethodConfigurations()}
     * @param methodArgs architecture string for the method factory
     * @param trainingType training type key. If {@code null}, the current training
     *     selection is left unchanged.
     * @param trainingArgs training argument string, used only when {@code trainingType}
     *     is non-null
     * @throws EncogError if {@code methodType} is unknown
     */
    public void selectMethod(VersatileMLDataSet versatileMLDataSet, String methodType, String methodArgs,
            String trainingType, String trainingArgs) {
        if (!this.methodConfigurations.containsKey(methodType)) {
            throw new EncogError("Don't know how to autoconfig method: " + methodType);
        }
        MethodConfig methodConfig = this.methodConfigurations.get(methodType);
        // createMethod() dereferences config, so it must be set here too.
        this.config = methodConfig;
        this.methodType = methodType;
        this.methodArgs = methodArgs;
        versatileMLDataSet.getNormHelper()
                .setStrategy(methodConfig.suggestNormalizationStrategy(versatileMLDataSet, methodArgs));
        if (trainingType != null) {
            this.trainingType = trainingType;
            this.trainingArgs = trainingArgs;
        }
    }

    /**
     * Sets the training type and arguments explicitly.
     *
     * @param versatileMLDataSet currently unused; kept for API compatibility
     * @param trainingType training type key for {@link MLTrainFactory}
     * @param trainingArgs training argument string
     * @throws EncogError if no method has been selected yet
     */
    public void selectTraining(VersatileMLDataSet versatileMLDataSet, String trainingType, String trainingArgs) {
        if (this.methodType == null) {
            throw new EncogError("Please select your training method, before your training type.");
        }
        this.trainingType = trainingType;
        this.trainingArgs = trainingArgs;
    }

    /**
     * Selects the training type and arguments suggested by the selected method's
     * configuration.
     *
     * @param versatileMLDataSet passed through to {@link #selectTraining}
     * @throws EncogError if no method has been selected yet
     */
    public void selectTrainingType(VersatileMLDataSet versatileMLDataSet) {
        if (this.methodType == null) {
            throw new EncogError("Please select your training method, before your training type.");
        }
        MethodConfig methodConfig = this.methodConfigurations.get(this.methodType);
        // Ask for arguments matching the type being selected, not the previous
        // (possibly null) selection.
        String suggestedType = methodConfig.suggestTrainingType();
        selectTraining(versatileMLDataSet, suggestedType, methodConfig.suggestTrainingArgs(suggestedType));
    }

    /** @param statusReportable receiver for progress messages */
    public void setReport(StatusReportable statusReportable) {
        this.report = statusReportable;
    }

    /** @param matrixMLDataSet the dataset used for cross-validation training */
    public void setTrainingDataset(MatrixMLDataSet matrixMLDataSet) {
        this.trainingDataset = matrixMLDataSet;
    }

    /** @param matrixMLDataSet the held-back validation dataset */
    public void setValidationDataset(MatrixMLDataSet matrixMLDataSet) {
        this.validationDataset = matrixMLDataSet;
    }
}
