/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.kernel;

import com.oracle.labs.mlrg.olcut.config.ArgumentException;
import com.oracle.labs.mlrg.olcut.config.Option;
import java.util.logging.Logger;
import org.tribuo.classification.ClassificationOptions;
import org.tribuo.classification.sgd.kernel.KernelSVMTrainer;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.kernel.Linear;
import org.tribuo.math.kernel.Polynomial;
import org.tribuo.math.kernel.RBF;
import org.tribuo.math.kernel.Sigmoid;

/**
 * Command line options for configuring a {@link KernelSVMTrainer}.
 * <p>
 * Each public field is bound to a command line flag through the OLCUT {@link Option}
 * annotation; {@link #getTrainer()} turns the configured values into a trainer.
 */
public class KernelSVMOptions implements ClassificationOptions<KernelSVMTrainer> {
    private static final Logger logger = Logger.getLogger(KernelSVMOptions.class.getName());

    /** Intercept passed to the polynomial and sigmoid kernels. */
    @Option(longName = "kernel-intercept", usage = "Intercept in kernel function. Defaults to 1.0.")
    public double kernelIntercept = 1.0;

    /** Degree passed to the polynomial kernel. */
    @Option(longName = "kernel-degree", usage = "Degree in polynomial kernel function. Defaults to 1.0.")
    public double kernelDegree = 1.0;

    /** Gamma passed to the polynomial, RBF and sigmoid kernels. */
    @Option(longName = "kernel-gamma", usage = "Gamma value in kernel function. Defaults to 1.0.")
    public double kernelGamma = 1.0;

    /** Number of SGD epochs, passed to the trainer. */
    @Option(longName = "kernel-epochs", usage = "Number of SGD epochs. Defaults to 5.")
    public int kernelEpochs = 5;

    /** Which kernel function to build in {@link #getTrainer()}. */
    @Option(longName = "kernel-kernel", usage = "Kernel function. Defaults to LINEAR.")
    public KernelEnum kernelKernel = KernelEnum.LINEAR;

    /** Lambda value passed to the trainer's gradient optimisation. */
    @Option(longName = "kernel-lambda", usage = "Lambda value in gradient optimisation. Defaults to 0.01.")
    public double kernelLambda = 0.01;

    /** Number of examples between objective log messages, passed to the trainer. */
    @Option(longName = "kernel-logging-interval", usage = "Log the objective after <int> examples. Defaults to 100.")
    public int kernelLoggingInterval = 100;

    /** Random seed passed to the trainer. */
    @Option(longName = "kernel-seed", usage = "Sets the random seed for the Kernel SVM. Defaults to 12345.")
    public long kernelSeed = 12345L;

    /**
     * Builds a {@link KernelSVMTrainer} from the configured options.
     *
     * @return a trainer using the kernel selected by {@link #kernelKernel} together with the
     *     configured lambda, epoch count, logging interval and seed.
     * @throws ArgumentException if {@link #kernelKernel} is not one of the handled kernels.
     * @throws NullPointerException if {@link #kernelKernel} is {@code null} (switch on a null enum).
     */
    public KernelSVMTrainer getTrainer() {
        logger.info("Configuring Kernel SVM Trainer");
        // Typed as the Kernel interface because each case assigns a different implementation.
        // Left unassigned so the compiler verifies every non-throwing branch sets it.
        Kernel kernelObj;
        switch (kernelKernel) {
            case LINEAR:
                logger.info("Using a linear kernel");
                kernelObj = new Linear();
                break;
            case POLYNOMIAL:
                logger.info("Using a Polynomial kernel with gamma " + kernelGamma + ", intercept " + kernelIntercept + ", and degree " + kernelDegree);
                kernelObj = new Polynomial(kernelGamma, kernelIntercept, kernelDegree);
                break;
            case RBF:
                logger.info("Using an RBF kernel with gamma " + kernelGamma);
                kernelObj = new RBF(kernelGamma);
                break;
            case SIGMOID:
                logger.info("Using a tanh kernel with gamma " + kernelGamma + ", and intercept " + kernelIntercept);
                kernelObj = new Sigmoid(kernelGamma, kernelIntercept);
                break;
            default:
                // Only reachable if a new enum constant is added without a matching case.
                logger.warning("Unknown kernel function " + kernelKernel);
                throw new ArgumentException("kernel-kernel", "Unknown kernel function " + kernelKernel);
        }
        logger.info(String.format("Set logging interval to %d", kernelLoggingInterval));
        return new KernelSVMTrainer(kernelObj, kernelLambda, kernelEpochs, kernelLoggingInterval, kernelSeed);
    }

    /** The kernel functions that can be selected with the {@code kernel-kernel} option. */
    public enum KernelEnum {
        /** Linear kernel; constructed without parameters. */
        LINEAR,
        /** Polynomial kernel, configured by gamma, intercept and degree. */
        POLYNOMIAL,
        /** Sigmoid (tanh) kernel, configured by gamma and intercept. */
        SIGMOID,
        /** RBF kernel, configured by gamma. */
        RBF
    }
}

