/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.LinearDecaySGD;
import org.tribuo.math.optimisers.SimpleSGD;
import org.tribuo.math.optimisers.SqrtDecaySGD;

/**
 * Abstract base class for stochastic gradient descent optimisers, with optional
 * standard or Nesterov momentum.
 * <p>
 * Subclasses supply the learning rate schedule by implementing {@link #learningRate()}.
 * {@link #step(Tensor[], double)} increments {@link #iteration} before querying it, so a
 * schedule may use the current iteration count and {@link #initialLearningRate}.
 * <p>
 * Instances carry mutable state (the iteration counter and, when momentum is enabled, one
 * momentum buffer per parameter tensor). No synchronisation is performed.
 */
public abstract class SGD
implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(SGD.class.getName());

    /** The learning rate supplied at construction; subclasses derive the per-step rate from it. */
    @Config(mandatory=true, description="Initial learning rate.")
    protected double initialLearningRate;

    /** Which momentum variant (if any) {@link #step(Tensor[], double)} applies. */
    @Config(mandatory=true, description="Momentum type to use.")
    protected Momentum useMomentum;

    /** Momentum decay factor; each step the momentum buffer is multiplied by this value. */
    @Config(description="Momentum scaling factor.")
    protected double rho = 0.0;

    /** Number of {@link #step(Tensor[], double)} calls since construction or the last {@link #reset()}. */
    protected int iteration = 0;

    /** Per-tensor momentum buffers; null until {@link #initialise(Parameters)} runs with momentum enabled. */
    private Tensor[] momentum;

    /**
     * Creates an optimiser without momentum.
     *
     * @param learningRate The initial learning rate.
     */
    SGD(double learningRate) {
        this(learningRate, 0.0, Momentum.NONE);
    }

    /**
     * Creates an optimiser with the supplied momentum configuration.
     *
     * @param learningRate The initial learning rate.
     * @param rho          The momentum decay factor.
     * @param useMomentum  The momentum variant to apply.
     */
    SGD(double learningRate, double rho, Momentum useMomentum) {
        this.initialLearningRate = learningRate;
        this.useMomentum = useMomentum;
        this.rho = rho;
    }

    /**
     * No-arg constructor; presumably for the configuration system, which populates the
     * {@code @Config} fields reflectively.
     */
    protected SGD() {
    }

    /**
     * Allocates momentum buffers shaped like {@code parameters} when momentum is enabled.
     * Must be called before {@link #step(Tensor[], double)}, and again after {@link #reset()}.
     *
     * @param parameters The parameters that will be optimised.
     */
    @Override
    public void initialise(Parameters parameters) {
        if (this.useMomentum != Momentum.NONE) {
            this.momentum = parameters.getEmptyCopy();
        }
    }

    /**
     * Converts gradient updates into parameter updates, modifying {@code updates} in place.
     * <p>
     * Increments the iteration counter, queries {@link #learningRate()} once, and then for
     * each tensor {@code g = updates[i]} with momentum buffer {@code m = momentum[i]}:
     * <ul>
     *   <li>{@code NONE}: {@code update = lr * weight * g}</li>
     *   <li>{@code STANDARD}: {@code m = rho * m + g; update = lr * weight * m}</li>
     *   <li>{@code NESTEROV}: {@code m = rho * m + g; update = lr * weight * g + lr * weight * rho * m}</li>
     * </ul>
     * The momentum variants combine tensors with {@code intersectAndAddInPlace}. For sparse
     * tensors this presumably restricts the sum to indices present in the receiver. TODO:
     * confirm against the {@link Tensor} implementations.
     *
     * @param updates The gradient updates; each element is overwritten with its scaled update.
     * @param weight  An additional multiplicative scale applied alongside the learning rate.
     * @return The same {@code updates} array, now holding the scaled updates.
     * @throws IllegalStateException If momentum is enabled but {@link #initialise(Parameters)}
     *                               has not been called since construction or the last {@link #reset()}.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        // Fail fast before mutating any state; otherwise this surfaces as an opaque NPE
        // after the iteration counter has already been advanced.
        if (this.useMomentum != Momentum.NONE && this.momentum == null) {
            throw new IllegalStateException("Momentum buffers are not initialised, call initialise(Parameters) before step(Tensor[], double).");
        }
        // Increment first so the learning rate schedule sees the current (1-based) step.
        this.iteration++;
        final double learningRate = this.learningRate();
        final DoubleUnaryOperator learningRateFunc = a -> a * learningRate * weight;
        final DoubleUnaryOperator nesterovFunc = a -> a * learningRate * weight * this.rho;
        for (int i = 0; i < updates.length; i++) {
            switch (this.useMomentum) {
                case STANDARD:
                    // m = rho * m + g
                    this.momentum[i].scaleInPlace(this.rho);
                    this.momentum[i].intersectAndAddInPlace(updates[i]);
                    // update = lr * weight * m (zero out, then add the scaled momentum)
                    updates[i].scaleInPlace(0.0);
                    updates[i].intersectAndAddInPlace(this.momentum[i], learningRateFunc);
                    break;
                case NESTEROV:
                    // m = rho * m + g, computed while updates[i] still holds the raw gradient
                    this.momentum[i].scaleInPlace(this.rho);
                    this.momentum[i].intersectAndAddInPlace(updates[i]);
                    // update = lr * weight * g + lr * weight * rho * m
                    updates[i].scaleInPlace(weight * learningRate);
                    updates[i].intersectAndAddInPlace(this.momentum[i], nesterovFunc);
                    break;
                case NONE:
                default:
                    updates[i].scaleInPlace(weight * learningRate);
                    break;
            }
        }
        return updates;
    }

    /**
     * Returns the learning rate to use for the current step.
     *
     * @return The learning rate.
     */
    public abstract double learningRate();

    /**
     * Returns a short name for the learning rate schedule, used by {@link #toString()}.
     *
     * @return The schedule name.
     */
    protected abstract String sgdType();

    @Override
    public String toString() {
        switch (this.useMomentum) {
            case STANDARD:
                return "SGD+Momentum(type=" + this.sgdType() + ",initialLearningRate=" + this.initialLearningRate + ",rho=" + this.rho + ")";
            case NESTEROV:
                return "SGD+NesterovMomentum(type=" + this.sgdType() + ",initialLearningRate=" + this.initialLearningRate + ",rho=" + this.rho + ")";
            default:
                return "SGD(type=" + this.sgdType() + ",initialLearningRate=" + this.initialLearningRate + ")";
        }
    }

    /**
     * Discards the momentum buffers and resets the iteration counter.
     * {@link #initialise(Parameters)} must be called again before the next
     * {@link #step(Tensor[], double)} when momentum is enabled.
     */
    @Override
    public void reset() {
        this.momentum = null;
        this.iteration = 0;
    }

    /**
     * Returns the configuration provenance of this optimiser.
     *
     * @return The provenance.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "StochasticGradientOptimiser");
    }

    /**
     * Creates a {@link SimpleSGD} without momentum.
     *
     * @param learningRate The learning rate.
     * @return The optimiser.
     */
    public static SGD getSimpleSGD(double learningRate) {
        return new SimpleSGD(learningRate);
    }

    /**
     * Creates a {@link SimpleSGD} with momentum.
     *
     * @param learningRate The learning rate.
     * @param rho          The momentum decay factor.
     * @param momentumType The momentum variant.
     * @return The optimiser.
     */
    public static SGD getSimpleSGD(double learningRate, double rho, Momentum momentumType) {
        return new SimpleSGD(learningRate, rho, momentumType);
    }

    /**
     * Creates a {@link LinearDecaySGD} without momentum.
     *
     * @param learningRate The initial learning rate.
     * @return The optimiser.
     */
    public static SGD getLinearDecaySGD(double learningRate) {
        return new LinearDecaySGD(learningRate);
    }

    /**
     * Creates a {@link LinearDecaySGD} with momentum.
     *
     * @param learningRate The initial learning rate.
     * @param rho          The momentum decay factor.
     * @param momentumType The momentum variant.
     * @return The optimiser.
     */
    public static SGD getLinearDecaySGD(double learningRate, double rho, Momentum momentumType) {
        return new LinearDecaySGD(learningRate, rho, momentumType);
    }

    /**
     * Creates a {@link SqrtDecaySGD} without momentum.
     *
     * @param learningRate The initial learning rate.
     * @return The optimiser.
     */
    public static SGD getSqrtDecaySGD(double learningRate) {
        return new SqrtDecaySGD(learningRate);
    }

    /**
     * Creates a {@link SqrtDecaySGD} with momentum.
     *
     * @param learningRate The initial learning rate.
     * @param rho          The momentum decay factor.
     * @param momentumType The momentum variant.
     * @return The optimiser.
     */
    public static SGD getSqrtDecaySGD(double learningRate, double rho, Momentum momentumType) {
        return new SqrtDecaySGD(learningRate, rho, momentumType);
    }

    /**
     * Momentum variants supported by {@link SGD#step(Tensor[], double)}.
     */
    public enum Momentum {
        /** No momentum; updates are scaled by the learning rate only. */
        NONE,
        /** Standard momentum: the update is the scaled momentum buffer. */
        STANDARD,
        /** Nesterov momentum: the scaled gradient plus the rho-scaled momentum buffer. */
        NESTEROV;

    }
}

