package opennlp.tools.ml.perceptron;

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import opennlp.tools.ml.AbstractEventModelSequenceTrainer;
import opennlp.tools.ml.model.AbstractDataIndexer;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Event;
import opennlp.tools.ml.model.MutableContext;
import opennlp.tools.ml.model.OnePassDataIndexer;
import opennlp.tools.ml.model.Sequence;
import opennlp.tools.ml.model.SequenceStream;
import opennlp.tools.ml.model.SequenceStreamEventStream;

/* loaded from: classes2.dex */
/**
 * Trains a {@link PerceptronModel} from a {@link SequenceStream}.
 *
 * <p>Each iteration re-tags every training sequence with the current model
 * (via {@link SequenceStream#updateContext}). When at least one predicted
 * outcome in a sequence differs from the reference outcome, the parameters are
 * adjusted by the difference between the feature counts of the reference events
 * and the feature counts of the predicted events. When averaging is enabled the
 * returned model uses the averaged parameters instead of the final ones.
 *
 * <p>All training state lives in instance fields that are overwritten by
 * {@link #trainModel}; NOTE(review): this suggests an instance must not run
 * two trainings concurrently.
 */
public class SimplePerceptronSequenceTrainer extends AbstractEventModelSequenceTrainer {

    /** Algorithm name accepted by {@link #validate()}. */
    public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";

    /** Slot of {@code updates[pred][outcome]}: parameter value (cast to int) at the last update. */
    private static final int VALUE = 0;

    /** Slot of {@code updates[pred][outcome]}: zero-based iteration of the last update. */
    private static final int ITER = 1;

    /** Slot of {@code updates[pred][outcome]}: sequence index (within its iteration) of the last update. */
    private static final int EVENT = 2;

    /** Total number of iterations requested for the current training run. */
    private int iterations;

    /** Training data; reset and re-read once per iteration. */
    private SequenceStream sequenceStream;

    /** Number of events reported by the data indexer. */
    private int numEvents;

    /** Number of predicates ({@code predLabels.length}). */
    private int numPreds;

    /** Number of outcomes ({@code outcomeLabels.length}). */
    private int numOutcomes;

    /** Number of sequences counted in one full pass over {@link #sequenceStream}. */
    private int numSequences;

    /** Outcome index per indexed event, as returned by the data indexer. */
    private int[] outcomeList;

    /** Outcome labels; position is the outcome index. */
    private String[] outcomeLabels;

    /** Predicate labels; position is the predicate index. */
    private String[] predLabels;

    /** Outcome label to outcome index. */
    private Map<String, Integer> omap;

    /** Predicate label to predicate index. */
    private Map<String, Integer> pmap;

    /** Current parameters, one context per predicate. */
    private MutableContext[] params;

    /** Accumulated (and finally averaged) parameters; only allocated when averaging. */
    private MutableContext[] averageParams;

    /** Whether the averaged parameters are maintained and returned. */
    private boolean useAverage;

    /** Bookkeeping for lazy averaging, indexed {@code [pred][outcome][VALUE|ITER|EVENT]}. */
    private int[][][] updates;

    /**
     * Reads the iteration count, cutoff and the {@code "UseAverage"} flag
     * (default {@code true}) from the trainer configuration and trains a model.
     *
     * @param sequenceStream the training sequences
     * @return the trained model
     * @throws IOException if reading the stream fails
     */
    @Override // opennlp.tools.ml.AbstractEventModelSequenceTrainer
    public AbstractModel doTrain(SequenceStream sequenceStream) throws IOException {
        return trainModel(
                getIterations(),
                sequenceStream,
                getCutoff(),
                this.trainingParameters.getBooleanParameter("UseAverage", true));
    }

    /**
     * Boolean wrapper around {@link #validate()}.
     *
     * @return {@code true} if {@link #validate()} completes, {@code false} if it
     *     throws {@link IllegalArgumentException}
     * @deprecated call {@link #validate()} directly; this method only converts
     *     its exception into a boolean.
     */
    @Override // opennlp.tools.ml.AbstractTrainer
    @Deprecated
    public boolean isValid() {
        try {
            validate();
            return true;
        } catch (IllegalArgumentException ignored) {
            // An invalid configuration is reported through the return value.
            return false;
        }
    }

    /**
     * Runs the superclass validation and additionally requires the configured
     * algorithm, when one is set, to be {@link #PERCEPTRON_SEQUENCE_VALUE}.
     *
     * @throws IllegalArgumentException if another algorithm name is configured
     */
    @Override // opennlp.tools.ml.AbstractTrainer
    public void validate() {
        super.validate();
        String algorithm = getAlgorithm();
        if (algorithm != null && !PERCEPTRON_SEQUENCE_VALUE.equals(algorithm)) {
            throw new IllegalArgumentException("algorithmName must be PERCEPTRON_SEQUENCE");
        }
    }

    /**
     * Indexes the training data, initialises all parameters to zero, runs the
     * requested number of iterations and wraps the result in a model.
     *
     * @param iterations number of passes over the training sequences
     * @param sequenceStream the training sequences; it is reset and re-read several times
     * @param cutoff value stored under the {@code "Cutoff"} training parameter for the indexer
     * @param useAverage whether to maintain and return averaged parameters
     * @return a {@link PerceptronModel} over the averaged parameters when
     *     {@code useAverage} is set, otherwise over the final parameters
     * @throws IOException if reading the stream fails
     */
    public AbstractModel trainModel(int iterations, SequenceStream sequenceStream, int cutoff,
            boolean useAverage) throws IOException {
        this.iterations = iterations;
        this.sequenceStream = sequenceStream;
        this.useAverage = useAverage;

        this.trainingParameters.put("Cutoff", cutoff);
        this.trainingParameters.put(AbstractDataIndexer.SORT_PARAM, false);
        OnePassDataIndexer indexer = new OnePassDataIndexer();
        indexer.init(this.trainingParameters, this.reportMap);
        indexer.index(new SequenceStreamEventStream(sequenceStream));

        // Count the sequences once; the averaging arithmetic needs the total per iteration.
        this.numSequences = 0;
        sequenceStream.reset();
        while (sequenceStream.read() != null) {
            this.numSequences++;
        }

        this.predLabels = indexer.getPredLabels();
        this.pmap = new HashMap<>();
        for (int pi = 0; pi < this.predLabels.length; pi++) {
            this.pmap.put(this.predLabels[pi], pi);
        }

        display("Incorporating indexed data for training...  \n");
        this.numEvents = indexer.getNumEvents();
        this.outcomeLabels = indexer.getOutcomeLabels();
        this.omap = new HashMap<>();
        for (int oi = 0; oi < this.outcomeLabels.length; oi++) {
            this.omap.put(this.outcomeLabels[oi], oi);
        }
        this.outcomeList = indexer.getOutcomeList();

        this.numPreds = this.predLabels.length;
        this.numOutcomes = this.outcomeLabels.length;
        if (useAverage) {
            this.updates = new int[this.numPreds][this.numOutcomes][3];
        }
        display("done.\n");

        display("\tNumber of Event Tokens: " + this.numEvents + "\n");
        display("\t    Number of Outcomes: " + this.numOutcomes + "\n");
        display("\t  Number of Predicates: " + this.numPreds + "\n");

        this.params = new MutableContext[this.numPreds];
        if (useAverage) {
            this.averageParams = new MutableContext[this.numPreds];
        }

        // Every predicate is associated with every outcome; the pattern array is shared.
        int[] allOutcomesPattern = new int[this.numOutcomes];
        for (int oi = 0; oi < this.numOutcomes; oi++) {
            allOutcomesPattern[oi] = oi;
        }

        for (int pi = 0; pi < this.numPreds; pi++) {
            this.params[pi] = new MutableContext(allOutcomesPattern, new double[this.numOutcomes]);
            if (useAverage) {
                this.averageParams[pi] =
                        new MutableContext(allOutcomesPattern, new double[this.numOutcomes]);
            }
            for (int oi = 0; oi < this.numOutcomes; oi++) {
                this.params[pi].setParameter(oi, 0.0d);
                if (useAverage) {
                    this.averageParams[pi].setParameter(oi, 0.0d);
                }
            }
        }

        display("Computing model parameters...\n");
        findParameters(iterations);
        display("...done.\n");

        MutableContext[] finalParams = useAverage ? this.averageParams : this.params;
        return new PerceptronModel(finalParams, this.predLabels, this.outcomeLabels);
    }

    /**
     * Runs iterations {@code 1..iterations}, printing a right-aligned iteration
     * label before each one, then reports the accuracy of the resulting
     * (averaged or plain) parameters.
     */
    private void findParameters(int iterations) throws IOException {
        display("Performing " + iterations + " iterations.\n");
        for (int iteration = 1; iteration <= iterations; iteration++) {
            String padding;
            if (iteration < 10) {
                padding = "  ";
            } else if (iteration < 100) {
                padding = " ";
            } else {
                padding = "";
            }
            display(padding + iteration + ":  ");
            nextIteration(iteration);
        }
        trainingStats(this.useAverage ? this.averageParams : this.params);
    }

    /**
     * Performs one pass over the training sequences and updates the parameters
     * for every sequence in which the current model predicts at least one wrong
     * outcome. On the last iteration, when averaging is enabled, the averaged
     * parameters are finalised.
     *
     * @param iteration one-based iteration number
     * @throws IOException if reading the stream fails
     */
    public void nextIteration(int iteration) throws IOException {
        int iterationIndex = iteration - 1; // zero-based from here on
        int numCorrect = 0;
        int sequenceIndex = 0;

        // One feature -> count map per outcome, reused for every updated sequence.
        List<Map<String, Float>> featureCounts = new ArrayList<>(this.numOutcomes);
        for (int oi = 0; oi < this.numOutcomes; oi++) {
            featureCounts.add(new HashMap<String, Float>());
        }

        PerceptronModel model = new PerceptronModel(this.params, this.predLabels, this.outcomeLabels);
        this.sequenceStream.reset();

        Sequence sequence;
        while ((sequence = this.sequenceStream.read()) != null) {
            Event[] predicted = this.sequenceStream.updateContext(sequence, model);
            Event[] reference = sequence.getEvents();

            boolean needsUpdate = false;
            for (int ei = 0; ei < reference.length; ei++) {
                if (predicted[ei].getOutcome().equals(reference[ei].getOutcome())) {
                    numCorrect++;
                } else {
                    needsUpdate = true;
                }
            }

            if (needsUpdate) {
                for (Map<String, Float> counts : featureCounts) {
                    counts.clear();
                }
                addReferenceCounts(featureCounts, reference);
                subtractPredictedCounts(featureCounts, predicted);
                applyFeatureCounts(featureCounts, iterationIndex, sequenceIndex);
                // The parameters changed, so subsequent sequences are tagged with a fresh model.
                model = new PerceptronModel(this.params, this.predLabels, this.outcomeLabels);
            }
            sequenceIndex++;
        }

        if (this.useAverage && iterationIndex == this.iterations - 1) {
            finishAverages((double) this.iterations * (double) sequenceIndex);
        }
        displayAccuracy(numCorrect);
    }

    /** Adds the feature values of the reference events to the per-outcome counts. */
    private void addReferenceCounts(List<Map<String, Float>> featureCounts, Event[] events) {
        for (Event event : events) {
            String[] context = event.getContext();
            float[] values = event.getValues();
            Map<String, Float> counts = featureCounts.get(this.omap.get(event.getOutcome()));
            for (int ci = 0; ci < context.length; ci++) {
                float value = values != null ? values[ci] : 1.0f;
                Float current = counts.get(context[ci]);
                counts.put(context[ci], current == null ? value : current + value);
            }
        }
    }

    /**
     * Subtracts the feature values of the predicted events from the per-outcome
     * counts, dropping entries whose count becomes exactly zero.
     */
    private void subtractPredictedCounts(List<Map<String, Float>> featureCounts, Event[] events) {
        for (Event event : events) {
            String[] context = event.getContext();
            float[] values = event.getValues();
            Map<String, Float> counts = featureCounts.get(this.omap.get(event.getOutcome()));
            for (int ci = 0; ci < context.length; ci++) {
                float value = values != null ? values[ci] : 1.0f;
                Float current = counts.get(context[ci]);
                float updated = current == null ? value * -1.0f : current - value;
                if (updated == 0.0f) {
                    counts.remove(context[ci]);
                } else {
                    counts.put(context[ci], updated);
                }
            }
        }
    }

    /**
     * Adds every remaining feature count to the matching parameter and, when
     * averaging, credits the previous parameter value for the number of steps
     * it was in effect before recording the new value.
     */
    private void applyFeatureCounts(List<Map<String, Float>> featureCounts, int iterationIndex,
            int sequenceIndex) {
        for (int oi = 0; oi < this.numOutcomes; oi++) {
            for (Map.Entry<String, Float> entry : featureCounts.get(oi).entrySet()) {
                Integer boxedIndex = this.pmap.get(entry.getKey());
                // Map.get yields null for a feature that is not a known predicate; skip it
                // instead of failing on unboxing. The -1 check is kept from the original guard.
                if (boxedIndex == null || boxedIndex.intValue() == -1) {
                    continue;
                }
                int pi = boxedIndex.intValue();
                this.params[pi].updateParameter(oi, entry.getValue().floatValue());
                if (this.useAverage) {
                    int[] lastUpdate = this.updates[pi][oi];
                    if (lastUpdate[VALUE] != 0) {
                        this.averageParams[pi].updateParameter(oi,
                                lastUpdate[VALUE]
                                        * ((this.numSequences * (iterationIndex - lastUpdate[ITER]))
                                                + (sequenceIndex - lastUpdate[EVENT])));
                    }
                    // NOTE(review): the cast drops any fractional part of the parameter;
                    // confirm this is intended when events carry non-integer values.
                    lastUpdate[VALUE] = (int) this.params[pi].getParameters()[oi];
                    lastUpdate[ITER] = iterationIndex;
                    lastUpdate[EVENT] = sequenceIndex;
                }
            }
        }
    }

    /**
     * Credits every parameter for the steps since its last recorded update up
     * to the end of training, then divides the non-zero sums by
     * {@code totalSteps}.
     */
    private void finishAverages(double totalSteps) {
        for (int pi = 0; pi < this.numPreds; pi++) {
            double[] predParams = this.averageParams[pi].getParameters();
            for (int oi = 0; oi < this.numOutcomes; oi++) {
                int[] lastUpdate = this.updates[pi][oi];
                if (lastUpdate[VALUE] != 0) {
                    predParams[oi] += lastUpdate[VALUE]
                            * ((this.numSequences * (this.iterations - lastUpdate[ITER]))
                                    - lastUpdate[EVENT]);
                }
                if (predParams[oi] != 0.0d) {
                    predParams[oi] /= totalSteps;
                    this.averageParams[pi].setParameter(oi, predParams[oi]);
                }
            }
        }
    }

    /**
     * Tags every training sequence with a model built from {@code parameters}
     * and reports how many predicted outcomes equal the indexed outcome list.
     */
    private void trainingStats(MutableContext[] parameters) throws IOException {
        int numCorrect = 0;
        int eventIndex = 0; // running position in outcomeList across all sequences
        this.sequenceStream.reset();

        Sequence sequence;
        while ((sequence = this.sequenceStream.read()) != null) {
            Event[] predicted = this.sequenceStream.updateContext(
                    sequence, new PerceptronModel(parameters, this.predLabels, this.outcomeLabels));
            for (Event event : predicted) {
                int outcomeIndex = this.omap.get(event.getOutcome());
                if (outcomeIndex == this.outcomeList[eventIndex]) {
                    numCorrect++;
                }
                eventIndex++;
            }
        }
        displayAccuracy(numCorrect);
    }

    /** Prints {@code ". (correct/numEvents) ratio"} followed by a newline. */
    private void displayAccuracy(int numCorrect) {
        display(". (" + numCorrect + "/" + this.numEvents + ") "
                + ((double) numCorrect / (double) this.numEvents) + "\n");
    }
}
