/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;

public class NonlinearGaussianDataSource
implements ConfigurableDataSource<Regressor> {
    @Config(mandatory=true, description="The number of samples to draw.")
    private int numSamples;
    @Config(description="The feature weights. Must be a 4 element array.")
    private float[] weights = new float[]{1.0f, 1.0f, 1.0f, 1.0f};
    @Config(description="The y-intercept of the line.")
    private float intercept = 0.0f;
    @Config(description="The variance of the noise gaussian.")
    private float variance = 1.0f;
    @Config(description="The minimum value of x_0.")
    private float xZeroMin = -2.0f;
    @Config(description="The maximum value of x_0.")
    private float xZeroMax = 2.0f;
    @Config(description="The minimum value of x_1.")
    private float xOneMin = -2.0f;
    @Config(description="The maximum value of x_1.")
    private float xOneMax = 2.0f;
    @Config(description="The RNG seed.")
    private long seed = 12345L;
    private List<Example<Regressor>> examples;
    private final RegressionFactory factory = new RegressionFactory();
    private static final String[] featureNames = new String[]{"X_0", "X_1"};

    private NonlinearGaussianDataSource() {
    }

    public NonlinearGaussianDataSource(int numSamples, long seed) {
        if (numSamples < 0) {
            throw new IllegalArgumentException("Invalid number of sample specified, must be a positive integer, found " + numSamples);
        }
        this.numSamples = numSamples;
        this.seed = seed;
        this.postConfig();
    }

    public NonlinearGaussianDataSource(int numSamples, float[] weights, float intercept, float variance, float xZeroMin, float xZeroMax, float xOneMin, float xOneMax, long seed) {
        this.numSamples = numSamples;
        this.weights = weights;
        this.intercept = intercept;
        this.variance = variance;
        this.xZeroMin = xZeroMin;
        this.xZeroMax = xZeroMax;
        this.xOneMin = xOneMin;
        this.xOneMax = xOneMax;
        this.seed = seed;
        this.postConfig();
    }

    public void postConfig() {
        Random rng = new Random(this.seed);
        if (this.weights.length != 4) {
            throw new PropertyException("", "weights", "Must supply 4 weights, found " + this.weights.length);
        }
        if (this.xZeroMax <= this.xZeroMin) {
            throw new PropertyException("", "xZeroMax", "xZeroMax must be greater than xZeroMin, found xZeroMax = " + this.xZeroMax + ", xZeroMin = " + this.xZeroMin);
        }
        if (this.xOneMax <= this.xOneMin) {
            throw new PropertyException("", "xOneMax", "xOneMax must be greater than xOneMin, found xOneMax = " + this.xOneMax + ", xOneMin = " + this.xOneMin);
        }
        if ((double)this.variance <= 0.0) {
            throw new PropertyException("", "variance", "Variance must be positive, found variance = " + this.variance);
        }
        ArrayList<ArrayExample> examples = new ArrayList<ArrayExample>(this.numSamples);
        double zeroRange = this.xZeroMax - this.xZeroMin;
        double oneRange = this.xOneMax - this.xOneMin;
        for (int i = 0; i < this.numSamples; ++i) {
            double xZero = rng.nextDouble() * zeroRange + (double)this.xZeroMin;
            double xOne = rng.nextDouble() * oneRange + (double)this.xOneMin;
            double outputValue = (double)this.weights[0] * xZero + (double)this.weights[1] * xOne + (double)this.weights[2] * xZero * xOne + (double)this.weights[3] * Math.pow(xOne, 3.0) + (double)this.intercept;
            Regressor output = new Regressor("Y", rng.nextGaussian() * (double)this.variance + outputValue);
            ArrayExample e = new ArrayExample((Output)output, featureNames, new double[]{xZero, xOne});
            examples.add(e);
        }
        this.examples = Collections.unmodifiableList(examples);
    }

    public OutputFactory<Regressor> getOutputFactory() {
        return this.factory;
    }

    public DataSourceProvenance getProvenance() {
        return new NonlinearGaussianDataSourceProvenance(this);
    }

    public Iterator<Example<Regressor>> iterator() {
        return this.examples.iterator();
    }

    public static MutableDataset<Regressor> generateDataset(int numSamples, float[] weights, float intercept, float variance, float xZeroMin, float xZeroMax, float xOneMin, float xOneMax, long seed) {
        NonlinearGaussianDataSource source = new NonlinearGaussianDataSource(numSamples, weights, intercept, variance, xZeroMin, xZeroMax, xOneMin, xOneMax, seed);
        return new MutableDataset((DataSource)source);
    }

    public static class NonlinearGaussianDataSourceProvenance
    extends SkeletalConfiguredObjectProvenance
    implements ConfiguredDataSourceProvenance {
        private static final long serialVersionUID = 1L;

        NonlinearGaussianDataSourceProvenance(NonlinearGaussianDataSource host) {
            super((Configurable)host, "DataSource");
        }

        public NonlinearGaussianDataSourceProvenance(Map<String, Provenance> map) {
            this(NonlinearGaussianDataSourceProvenance.extractProvenanceInfo(map));
        }

        private NonlinearGaussianDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
        }

        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            HashMap<String, Provenance> configuredParameters = new HashMap<String, Provenance>(map);
            String className = ((StringProvenance)ObjectProvenance.checkAndExtractProvenance(configuredParameters, (String)"class-name", StringProvenance.class, (String)NonlinearGaussianDataSourceProvenance.class.getSimpleName())).getValue();
            String hostTypeStringName = ((StringProvenance)ObjectProvenance.checkAndExtractProvenance(configuredParameters, (String)"host-short-name", StringProvenance.class, (String)NonlinearGaussianDataSourceProvenance.class.getSimpleName())).getValue();
            return new SkeletalConfiguredObjectProvenance.ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
        }
    }
}

