package opennlp.tools.ml.maxent.quasinewton;

import java.lang.reflect.Array;
import opennlp.tools.ml.maxent.quasinewton.LineSearch;

/* loaded from: classes2.dex */
/**
 * Limited-memory BFGS (L-BFGS) minimizer with optional L1 and L2 regularization.
 *
 * <p>The objective given to {@link #minimize(Function)} is wrapped in an {@link L2RegFunction},
 * which adds {@code l2Cost * ||x||^2}. When {@code l1Cost > 0}, the L1 term
 * {@code l1Cost * ||x||_1} is handled OWL-QN style: a pseudo-gradient replaces the gradient, the
 * search direction is sign-constrained against it, and {@code LineSearch.doConstrainedLineSearch}
 * performs the step.
 *
 * <p>{@link #minimize(Function)} stores per-run state (dimension, correction history) in fields,
 * so a single instance must not be used for concurrent minimizations.
 */
public class QNMinimizer {
    /** Training stops when the relative function change falls below this value. */
    public static final double CONVERGE_TOLERANCE = 1.0E-4d;
    /** Step size handed to the line search from the second iteration on. */
    public static final double INITIAL_STEP_SIZE = 1.0d;
    /** Default L1 cost (no L1 penalty). */
    public static final double L1COST_DEFAULT = 0.0d;
    /** Default L2 cost (no L2 penalty). */
    public static final double L2COST_DEFAULT = 0.0d;
    /** Default cap on the number of function evaluations. */
    public static final int MAX_FCT_EVAL_DEFAULT = 30000;
    /** Training stops when the line-search step size falls below this value. */
    public static final double MIN_STEP_SIZE = 1.0E-10d;
    /** Default number of correction pairs kept in the L-BFGS history. */
    public static final int M_DEFAULT = 15;
    /** Default maximum number of iterations. */
    public static final int NUM_ITERATIONS_DEFAULT = 100;
    /** Training stops when ||grad|| / max(1, ||x||) falls below this value. */
    public static final double REL_GRAD_NORM_TOL = 1.0E-4d;

    private final double l1Cost;
    private final double l2Cost;
    private final int iterations;
    /** Number of correction pairs kept in the L-BFGS history. */
    private final int m;
    private final int maxFctEval;
    private final boolean verbose;

    // Per-run state, (re)initialized by minimize().
    private int dimension;
    private UpdateInfo updateInfo;
    private Evaluator evaluator;

    /** Optional callback whose score is printed after each iteration when verbose output is on. */
    public interface Evaluator {
        /**
         * Scores a candidate parameter vector.
         *
         * @param parameters the current point of the optimization
         * @return a score that is displayed alongside the objective value
         */
        double evaluate(double[] parameters);
    }

    /**
     * Decorates a {@link Function} with the L2 penalty {@code l2Cost * ||x||^2}.
     * The penalty (and its gradient {@code 2 * l2Cost * x}) is applied only when
     * {@code l2Cost > 0}.
     */
    public static class L2RegFunction implements Function {
        private final Function f;
        private final double l2Cost;

        /**
         * @param f the unregularized objective
         * @param l2Cost weight of the squared L2 norm; values {@code <= 0} disable the penalty
         */
        public L2RegFunction(Function f, double l2Cost) {
            this.f = f;
            this.l2Cost = l2Cost;
        }

        private void checkDimension(double[] x) {
            if (x.length != getDimension()) {
                throw new IllegalArgumentException("x's dimension is not the same as function's dimension");
            }
        }

        @Override
        public int getDimension() {
            return f.getDimension();
        }

        /**
         * Returns the gradient of the wrapped function plus {@code 2 * l2Cost * x}.
         *
         * @throws IllegalArgumentException if {@code x} does not match the function's dimension
         */
        @Override
        public double[] gradientAt(double[] x) {
            checkDimension(x);
            double[] grad = f.gradientAt(x);
            if (l2Cost > 0.0d) {
                // Work on a copy: the wrapped function may hand back an array it owns and reuses.
                grad = grad.clone();
                for (int i = 0; i < x.length; i++) {
                    grad[i] += l2Cost * 2.0d * x[i];
                }
            }
            return grad;
        }

        /**
         * Returns the wrapped function's value plus {@code l2Cost * x·x}.
         *
         * @throws IllegalArgumentException if {@code x} does not match the function's dimension
         */
        @Override
        public double valueAt(double[] x) {
            checkDimension(x);
            double value = f.valueAt(x);
            if (l2Cost > 0.0d) {
                value += l2Cost * opennlp.tools.ml.ArrayMath.innerProduct(x, x);
            }
            return value;
        }
    }

    /**
     * Ring of the most recent {@code m} correction pairs
     * {@code s_k = x_{k+1} - x_k}, {@code y_k = g_{k+1} - g_k}, ordered oldest (row 0) to newest
     * (row {@code kCounter - 1}), together with {@code rho_k = 1 / (s_k · y_k)}.
     */
    private static final class UpdateInfo {
        private final double[][] S;
        private final double[][] Y;
        private final double[] rho;
        /** Scratch space for the two-loop recursion in computeDirection. */
        private final double[] alpha;
        private final int m;
        private final int dimension;
        /** Number of valid pairs stored (never exceeds m). */
        private int kCounter = 0;

        UpdateInfo(int m, int dimension) {
            this.m = m;
            this.dimension = dimension;
            this.S = new double[m][dimension];
            this.Y = new double[m][dimension];
            this.rho = new double[m];
            this.alpha = new double[m];
        }

        /**
         * Records the correction pair produced by the last line search. Pairs with non-positive
         * (or NaN) curvature {@code s·y} are skipped, since {@code 1/(s·y)} would then be
         * infinite or negative and would corrupt later search directions.
         */
        void update(LineSearch.LineSearchResult lsr) {
            double[] currPoint = lsr.getCurrPoint();
            double[] gradAtCurr = lsr.getGradAtCurr();
            double[] nextPoint = lsr.getNextPoint();
            double[] gradAtNext = lsr.getGradAtNext();

            double sy = 0.0d;
            for (int j = 0; j < dimension; j++) {
                sy += (nextPoint[j] - currPoint[j]) * (gradAtNext[j] - gradAtCurr[j]);
            }
            if (!(sy > 0.0d)) {
                return;
            }

            int slot;
            if (kCounter < m) {
                slot = kCounter;
                kCounter++;
            } else {
                // Drop the oldest pair and recycle its storage for the newest one. Shifting row
                // references without recycling would make the last two rows alias one array.
                double[] recycledS = S[0];
                double[] recycledY = Y[0];
                System.arraycopy(S, 1, S, 0, m - 1);
                System.arraycopy(Y, 1, Y, 0, m - 1);
                System.arraycopy(rho, 1, rho, 0, m - 1);
                S[m - 1] = recycledS;
                Y[m - 1] = recycledY;
                slot = m - 1;
            }

            double[] s = S[slot];
            double[] y = Y[slot];
            for (int j = 0; j < dimension; j++) {
                s[j] = nextPoint[j] - currPoint[j];
                y[j] = gradAtNext[j] - gradAtCurr[j];
            }
            rho[slot] = 1.0d / sy;
        }
    }

    /** Creates a minimizer with no regularization and default limits. */
    public QNMinimizer() {
        this(L1COST_DEFAULT, L2COST_DEFAULT);
    }

    /** Creates a minimizer with the given regularization costs and default limits. */
    public QNMinimizer(double l1Cost, double l2Cost) {
        this(l1Cost, l2Cost, NUM_ITERATIONS_DEFAULT);
    }

    /** Creates a minimizer with the given regularization costs and iteration cap. */
    public QNMinimizer(double l1Cost, double l2Cost, int iterations) {
        this(l1Cost, l2Cost, iterations, M_DEFAULT, MAX_FCT_EVAL_DEFAULT);
    }

    /** Creates a verbose minimizer with fully specified limits. */
    public QNMinimizer(double l1Cost, double l2Cost, int iterations, int m, int maxFctEval) {
        this(l1Cost, l2Cost, iterations, m, maxFctEval, true);
    }

    /**
     * @param l1Cost weight of the L1 penalty, must be {@code >= 0}
     * @param l2Cost weight of the L2 penalty, must be {@code >= 0}
     * @param iterations maximum number of iterations, must be {@code > 0}
     * @param m number of correction pairs kept, must be {@code > 0}
     * @param maxFctEval maximum number of function evaluations, must be {@code > 0}
     * @param verbose whether progress is printed to {@code System.out}
     * @throws IllegalArgumentException if any argument is out of range (NaN costs included)
     */
    public QNMinimizer(double l1Cost, double l2Cost, int iterations, int m, int maxFctEval, boolean verbose) {
        // Negated comparisons so that NaN costs are rejected as well.
        if (!(l1Cost >= 0.0d) || !(l2Cost >= 0.0d)) {
            throw new IllegalArgumentException("L1-cost and L2-cost must not be less than zero");
        }
        if (iterations <= 0) {
            throw new IllegalArgumentException("Number of iterations must be larger than zero");
        }
        if (m <= 0) {
            throw new IllegalArgumentException("Number of Hessian updates must be larger than zero");
        }
        if (maxFctEval <= 0) {
            throw new IllegalArgumentException("Maximum number of function evaluations must be larger than zero");
        }
        this.l1Cost = l1Cost;
        this.l2Cost = l2Cost;
        this.iterations = iterations;
        this.m = m;
        this.maxFctEval = maxFctEval;
        this.verbose = verbose;
    }

    /**
     * Turns {@code direction} (holding the current gradient on entry) into the L-BFGS search
     * direction {@code -H * grad}. This is the two-loop recursion, with the identity as the
     * initial inverse-Hessian approximation.
     */
    private void computeDirection(double[] direction) {
        int k = updateInfo.kCounter;
        double[] rho = updateInfo.rho;
        double[] alpha = updateInfo.alpha;
        double[][] s = updateInfo.S;
        double[][] y = updateInfo.Y;

        // First loop: newest to oldest pair.
        for (int i = k - 1; i >= 0; i--) {
            alpha[i] = rho[i] * opennlp.tools.ml.ArrayMath.innerProduct(s[i], direction);
            for (int j = 0; j < dimension; j++) {
                direction[j] -= alpha[i] * y[i][j];
            }
        }
        // Second loop: oldest to newest pair.
        for (int i = 0; i < k; i++) {
            double beta = rho[i] * opennlp.tools.ml.ArrayMath.innerProduct(y[i], direction);
            for (int j = 0; j < dimension; j++) {
                direction[j] += s[i][j] * (alpha[i] - beta);
            }
        }
        for (int j = 0; j < dimension; j++) {
            direction[j] = -direction[j];
        }
    }

    /**
     * Computes the pseudo-gradient of {@code f(x) + l1Cost * ||x||_1} into {@code pseudoGrad}.
     * Where {@code x[i] == 0}, the component is the smallest-magnitude subgradient, which is 0
     * when {@code |grad[i]| <= l1Cost}.
     */
    private void computePseudoGrad(double[] x, double[] grad, double[] pseudoGrad) {
        for (int i = 0; i < dimension; i++) {
            if (x[i] < 0.0d) {
                pseudoGrad[i] = grad[i] - l1Cost;
            } else if (x[i] > 0.0d) {
                pseudoGrad[i] = grad[i] + l1Cost;
            } else if (grad[i] < -l1Cost) {
                pseudoGrad[i] = grad[i] + l1Cost;
            } else if (grad[i] > l1Cost) {
                pseudoGrad[i] = grad[i] - l1Cost;
            } else {
                pseudoGrad[i] = 0.0d;
            }
        }
    }

    private void display(String str) {
        System.out.print(str);
    }

    /** Prints the progress line for one iteration, right-aligning the iteration number. */
    private void displayIteration(int iter, LineSearch.LineSearchResult lsr) {
        String label;
        if (iter < 10) {
            label = "  " + iter;
        } else if (iter < 100) {
            label = " " + iter;
        } else {
            label = String.valueOf(iter);
        }
        display(label + ":  ");
        if (evaluator != null) {
            display("\t" + lsr.getValueAtNext() + "\t" + lsr.getFuncChangeRate() + "\t"
                    + evaluator.evaluate(lsr.getNextPoint()) + "\n");
        } else {
            display("\t " + lsr.getValueAtNext() + "\t" + lsr.getFuncChangeRate() + "\n");
        }
    }

    /**
     * Checks the stopping criteria in order: function change rate, relative (pseudo-)gradient
     * norm, step size, and function-evaluation budget. When verbose, it prints the reason for
     * stopping.
     */
    private boolean isConverged(LineSearch.LineSearchResult lsr) {
        if (lsr.getFuncChangeRate() < CONVERGE_TOLERANCE) {
            if (verbose) {
                display("Function change rate is smaller than the threshold " + CONVERGE_TOLERANCE
                        + ".\nTraining will stop.\n\n");
            }
            return true;
        }

        double[] grad = l1Cost > 0.0d ? lsr.getPseudoGradAtNext() : lsr.getGradAtNext();
        double relGradNorm = opennlp.tools.ml.ArrayMath.l2norm(grad)
                / Math.max(1.0d, opennlp.tools.ml.ArrayMath.l2norm(lsr.getNextPoint()));
        if (relGradNorm < REL_GRAD_NORM_TOL) {
            if (verbose) {
                display("Relative L2-norm of the gradient is smaller than the threshold " + REL_GRAD_NORM_TOL
                        + ".\nTraining will stop.\n\n");
            }
            return true;
        }

        if (lsr.getStepSize() < MIN_STEP_SIZE) {
            if (verbose) {
                display("Step size is smaller than the minimum step size " + MIN_STEP_SIZE
                        + ".\nTraining will stop.\n\n");
            }
            return true;
        }

        if (lsr.getFctEvalCount() > maxFctEval) {
            if (verbose) {
                display("Maximum number of function evaluations has exceeded the threshold " + maxFctEval
                        + ".\nTraining will stop.\n\n");
            }
            return true;
        }
        return false;
    }

    public Evaluator getEvaluator() {
        return evaluator;
    }

    /**
     * Minimizes {@code function} plus the configured regularization, starting from the origin.
     *
     * @param function the objective to minimize
     * @return a fresh array holding the final point
     */
    public double[] minimize(Function function) {
        L2RegFunction l2RegFunction = new L2RegFunction(function, l2Cost);
        this.dimension = l2RegFunction.getDimension();
        this.updateInfo = new UpdateInfo(m, dimension);
        boolean useL1 = l1Cost > 0.0d;

        // Start from the origin.
        double[] startPoint = new double[dimension];
        double startValue = l2RegFunction.valueAt(startPoint);
        double[] startGrad = new double[dimension];
        System.arraycopy(l2RegFunction.gradientAt(startPoint), 0, startGrad, 0, dimension);

        double[] startPseudoGrad = null;
        if (useL1) {
            startValue += l1Cost * opennlp.tools.ml.ArrayMath.l1norm(startPoint);
            startPseudoGrad = new double[dimension];
            computePseudoGrad(startPoint, startGrad, startPseudoGrad);
        }

        LineSearch.LineSearchResult lsr = useL1
                ? LineSearch.LineSearchResult.getInitialObjectForL1(startValue, startGrad, startPseudoGrad, startPoint)
                : LineSearch.LineSearchResult.getInitialObject(startValue, startGrad, startPoint);

        if (verbose) {
            display("\nSolving convex optimization problem.");
            display("\nObjective function has " + dimension + " variable(s).");
            display("\n\nPerforming " + iterations + " iterations with L1Cost=" + l1Cost + " and L2Cost=" + l2Cost + "\n");
        }

        double[] initialDirGrad = useL1 ? lsr.getPseudoGradAtNext() : lsr.getGradAtNext();
        if (opennlp.tools.ml.ArrayMath.l2norm(initialDirGrad) == 0.0d) {
            // The starting point is already stationary. The first step size would be
            // 1 / ||grad|| = infinity, which would poison the result with NaNs.
            if (verbose) {
                display("Gradient at the starting point is zero.\nTraining will stop.\n\n");
            }
            this.updateInfo = null;
            return startPoint.clone();
        }

        double[] direction = new double[dimension];
        long startTime = System.currentTimeMillis();
        // The first step is scaled by 1 / ||grad||; afterwards the L-BFGS direction is trusted as is.
        double stepSize = opennlp.tools.ml.ArrayMath.invL2norm(initialDirGrad);

        for (int iter = 1; iter <= iterations; iter++) {
            double[] dirGrad = useL1 ? lsr.getPseudoGradAtNext() : lsr.getGradAtNext();
            System.arraycopy(dirGrad, 0, direction, 0, dimension);
            computeDirection(direction);

            if (useL1) {
                double[] pseudoGrad = lsr.getPseudoGradAtNext();
                // Keep only components that point against the pseudo-gradient.
                for (int j = 0; j < dimension; j++) {
                    if (direction[j] * pseudoGrad[j] >= 0.0d) {
                        direction[j] = 0.0d;
                    }
                }
                LineSearch.doConstrainedLineSearch(l2RegFunction, direction, lsr, l1Cost, stepSize);
                // NOTE(review): pseudoGrad is refreshed in place, as in the original code; confirm
                // LineSearch does not still need the previous values through a shared reference.
                computePseudoGrad(lsr.getNextPoint(), lsr.getGradAtNext(), pseudoGrad);
                lsr.setPseudoGradAtNext(pseudoGrad);
            } else {
                LineSearch.doLineSearch(l2RegFunction, direction, lsr, stepSize);
            }

            updateInfo.update(lsr);
            if (verbose) {
                displayIteration(iter, lsr);
            }
            if (isConverged(lsr)) {
                break;
            }
            stepSize = INITIAL_STEP_SIZE;
        }

        if (useL1 && l2Cost > 0.0d) {
            // NOTE(review): rescaling by sqrt(1 + l2Cost) when both penalties are active presumably
            // compensates for double shrinkage in the elastic-net setting; confirm the rationale.
            double[] nextPoint = lsr.getNextPoint();
            double scale = Math.sqrt(l2Cost + 1.0d);
            for (int j = 0; j < dimension; j++) {
                nextPoint[j] = scale * nextPoint[j];
            }
        }

        if (verbose) {
            long elapsedMillis = System.currentTimeMillis() - startTime;
            display("Running time: " + (elapsedMillis / 1000.0d) + "s\n");
        }

        // Drop the correction history so it can be reclaimed.
        this.updateInfo = null;

        double[] solution = new double[dimension];
        System.arraycopy(lsr.getNextPoint(), 0, solution, 0, dimension);
        return solution;
    }

    public void setEvaluator(Evaluator evaluator) {
        this.evaluator = evaluator;
    }
}
