/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.BoundedInferenceModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Inference-time form of an ensemble model: a list of sub-models that are all evaluated on the same feature
 * vector, whose raw outputs are combined by an {@link OutputAggregator} and then turned into regression or
 * classification results according to the ensemble's {@link TargetType}.
 *
 * <p>Instances are created with {@link #fromXContent(XContentParser)} and must be prepared with
 * {@link #rewriteFeatureIndices(Map)} before any {@code infer} call; otherwise inference fails with a server error.
 */
public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel {
    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
    private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);

    /** Lenient parser (unknown fields are ignored) for the inference representation of an ensemble. */
    @SuppressWarnings("unchecked") // constructor args are produced by the typed declarations in the static block below
    private static final ConstructingObjectParser<EnsembleInferenceModel, Void> PARSER = new ConstructingObjectParser<>(
        "ensemble_inference_model",
        true,
        a -> new EnsembleInferenceModel(
            (List<InferenceModel>) a[0],
            (OutputAggregator) a[1],
            TargetType.fromString((String) a[2]),
            (List<String>) a[3],
            (List<Double>) a[4]
        )
    );

    static {
        PARSER.declareNamedObjects(
            ConstructingObjectParser.constructorArg(),
            (p, c, n) -> (InferenceModel) p.namedObject(InferenceModel.class, n, null),
            ensembleBuilder -> {},
            Ensemble.TRAINED_MODELS
        );
        PARSER.declareNamedObject(
            ConstructingObjectParser.constructorArg(),
            (p, c, n) -> (LenientlyParsedOutputAggregator) p.namedObject(LenientlyParsedOutputAggregator.class, n, null),
            Ensemble.AGGREGATE_OUTPUT
        );
        PARSER.declareString(ConstructingObjectParser.constructorArg(), TargetType.TARGET_TYPE);
        PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), Ensemble.CLASSIFICATION_LABELS);
        PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), Ensemble.CLASSIFICATION_WEIGHTS);
    }

    // Empty until rewriteFeatureIndices assigns the feature layout (and stays empty for nested ensembles).
    private String[] featureNames = new String[0];
    private final List<InferenceModel> models;
    private final OutputAggregator outputAggregator;
    private final TargetType targetType;
    private final List<String> classificationLabels;
    private final double[] classificationWeights;
    private volatile boolean preparedForInference = false;
    // Holds {min, max}; computed by initModelBoundaries and memoised by CachedSupplier.
    private final Supplier<double[]> predictedValuesBoundariesSupplier;

    /**
     * Parses an ensemble inference model.
     *
     * @param parser parser positioned on the ensemble object
     * @return the parsed, not yet prepared, ensemble
     */
    public static EnsembleInferenceModel fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * @param models                sub-models; required
     * @param outputAggregator      combines the sub-model outputs; required
     * @param targetType            regression or classification; required
     * @param classificationLabels  optional class labels, may be {@code null}
     * @param classificationWeights optional per-class weights, may be {@code null} (stored as {@code null})
     */
    private EnsembleInferenceModel(
        List<InferenceModel> models,
        OutputAggregator outputAggregator,
        TargetType targetType,
        @Nullable List<String> classificationLabels,
        List<Double> classificationWeights
    ) {
        this.models = ExceptionsHelper.requireNonNull(models, Ensemble.TRAINED_MODELS);
        this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, Ensemble.AGGREGATE_OUTPUT);
        this.targetType = ExceptionsHelper.requireNonNull(targetType, TargetType.TARGET_TYPE);
        this.classificationLabels = classificationLabels;
        this.classificationWeights = classificationWeights == null
            ? null
            : classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
        this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries);
    }

    @Override
    public String[] getFeatureNames() {
        return this.featureNames;
    }

    @Override
    public TargetType targetType() {
        return this.targetType;
    }

    /**
     * Extracts this ensemble's features from {@code fields} by name and runs inference.
     */
    @Override
    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
        return this.innerInfer(InferenceModel.extractFeatures(this.featureNames, fields), config, featureDecoderMap);
    }

    /**
     * Runs inference on a feature vector already laid out according to the mapping applied in
     * {@link #rewriteFeatureIndices(Map)}. No feature decoding is applied.
     */
    @Override
    public InferenceResults infer(double[] features, InferenceConfig config) {
        return this.innerInfer(features, config, Collections.emptyMap());
    }

    /**
     * Evaluates every sub-model, aggregates their raw outputs and builds the final result.
     *
     * @throws org.elasticsearch.ElasticsearchException bad request when {@code config} does not support this model's
     *         target type; server error when the model has not been prepared for inference
     */
    private InferenceResults innerInfer(double[] features, InferenceConfig config, Map<String, String> featureDecoderMap) {
        if (config.isTargetTypeSupported(this.targetType) == false) {
            throw ExceptionsHelper.badRequestException(
                "Cannot infer using configuration for [{}] when model target_type is [{}]",
                config.getName(),
                this.targetType.toString()
            );
        }
        if (this.preparedForInference == false) {
            throw ExceptionsHelper.serverError("model is not prepared for inference");
        }
        LOGGER.debug(
            () -> "Inference called with feature names ["
                + org.elasticsearch.common.Strings.arrayToCommaDelimitedString(this.featureNames)
                + "] values "
                + Arrays.toString(features)
        );
        double[][] inferenceResults = new double[this.models.size()][];
        double[][] featureInfluence = new double[features.length][];
        // Sub-models must return raw values so that the aggregator can combine them.
        NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(config.requestingImportance());
        int i = 0;
        for (InferenceModel model : this.models) {
            RawInferenceResults inferenceResult = inferRaw(model, features, subModelInferenceConfig);
            inferenceResults[i++] = inferenceResult.getValue();
            if (config.requestingImportance()) {
                addFeatureImportance(featureInfluence, inferenceResult);
            }
        }
        double[] processed = this.outputAggregator.processValues(inferenceResults);
        return this.buildResults(processed, featureInfluence, featureDecoderMap, config);
    }

    /**
     * Sums per-feature importance over all sub-models for the given feature vector, without aggregating outputs.
     * Note: unlike {@code infer}, this does not check that the model has been prepared.
     */
    double[][] featureImportance(double[] features) {
        double[][] featureInfluence = new double[features.length][];
        NullInferenceConfig subModelInferenceConfig = new NullInferenceConfig(true);
        for (InferenceModel model : this.models) {
            addFeatureImportance(featureInfluence, inferRaw(model, features, subModelInferenceConfig));
        }
        return featureInfluence;
    }

    /** Runs a sub-model with a {@link NullInferenceConfig}, which is expected to yield raw results. */
    private static RawInferenceResults inferRaw(InferenceModel model, double[] features, NullInferenceConfig config) {
        InferenceResults result = model.infer(features, config);
        assert result instanceof RawInferenceResults;
        return (RawInferenceResults) result;
    }

    /** Adds a sub-model's per-feature importance into the running totals in {@code featureInfluence}. */
    private static void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
        double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
        assert modelFeatureImportance.length == featureInfluence.length;
        for (int j = 0; j < modelFeatureImportance.length; ++j) {
            if (featureInfluence[j] == null) {
                featureInfluence[j] = new double[modelFeatureImportance[j].length];
            }
            featureInfluence[j] = InferenceHelpers.sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
        }
    }

    /**
     * Converts the aggregator-processed values into results. A {@link NullInferenceConfig} (used when this ensemble is
     * itself a sub-model) yields raw aggregated output; otherwise a regression or classification result is built.
     *
     * @throws UnsupportedOperationException for target types other than regression and classification
     */
    private InferenceResults buildResults(
        double[] processedInferences,
        double[][] featureImportance,
        Map<String, String> featureDecoderMap,
        InferenceConfig config
    ) {
        if (config instanceof NullInferenceConfig) {
            return new RawInferenceResults(new double[] { this.outputAggregator.aggregate(processedInferences) }, featureImportance);
        }
        Map<String, double[]> decodedFeatureImportance = config.requestingImportance()
            ? InferenceHelpers.decodeFeatureImportances(
                featureDecoderMap,
                IntStream.range(0, featureImportance.length)
                    .boxed()
                    .collect(Collectors.toMap(i -> this.featureNames[i], i -> featureImportance[i]))
            )
            : Collections.emptyMap();
        switch (this.targetType) {
            case REGRESSION:
                return new RegressionInferenceResults(
                    this.outputAggregator.aggregate(processedInferences),
                    config,
                    InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance)
                );
            case CLASSIFICATION: {
                ClassificationConfig classificationConfig = (ClassificationConfig) config;
                assert this.classificationWeights == null || processedInferences.length == this.classificationWeights.length;
                Tuple<InferenceHelpers.TopClassificationValue, List<TopClassEntry>> topClasses = InferenceHelpers.topClasses(
                    processedInferences,
                    this.classificationLabels,
                    this.classificationWeights,
                    classificationConfig.getNumTopClasses(),
                    classificationConfig.getPredictionFieldType()
                );
                InferenceHelpers.TopClassificationValue value = topClasses.v1();
                return new ClassificationInferenceResults(
                    (double) value.getValue(),
                    InferenceHelpers.classificationLabel(value.getValue(), this.classificationLabels),
                    topClasses.v2(),
                    InferenceHelpers.transformFeatureImportanceClassification(
                        decodedFeatureImportance,
                        this.classificationLabels,
                        classificationConfig.getPredictionFieldType()
                    ),
                    config,
                    (Double) value.getProbability(),
                    (Double) value.getScore()
                );
            }
            default:
                throw new UnsupportedOperationException(
                    "unsupported target_type [" + this.targetType + "] for inference on ensemble model"
                );
        }
    }

    /** Feature importance is supported only if every sub-model supports it. */
    @Override
    public boolean supportsFeatureImportance() {
        return this.models.stream().allMatch(InferenceModel::supportsFeatureImportance);
    }

    @Override
    public String getName() {
        return "ensemble";
    }

    /**
     * Prepares this ensemble (and, recursively, its sub-models) for inference by fixing the layout of the feature
     * vector. Calls after the first one are no-ops.
     *
     * <p>When {@code newFeatureIndexMapping} is {@code null} or empty, this ensemble owns the feature vector. It
     * collects every feature referenced by its sub-models, descending into nested ensembles. It then assigns those
     * features consecutive indices in first-seen order and exposes them through {@link #getFeatureNames()}.
     *
     * <p>Otherwise the layout is owned by an enclosing model. This ensemble exposes no feature names of its own and
     * forwards the given mapping unchanged. That way its sub-models index into the same vector the enclosing model
     * passes to {@link #infer(double[], InferenceConfig)}.
     *
     * @param newFeatureIndexMapping feature name to index in the caller's feature vector; {@code null}/empty for a
     *                               top-level ensemble
     */
    @Override
    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        LOGGER.debug(() -> Strings.format("rewriting features %s", newFeatureIndexMapping));
        if (this.preparedForInference) {
            return;
        }
        this.preparedForInference = true;
        final Map<String, Integer> featureIndexMapping;
        if (newFeatureIndexMapping == null || newFeatureIndexMapping.isEmpty()) {
            Set<String> referencedFeatures = this.subModelFeatures();
            LOGGER.debug(() -> Strings.format("detected submodel feature names %s", referencedFeatures));
            Map<String, Integer> ownMapping = new HashMap<>();
            this.featureNames = new String[referencedFeatures.size()];
            int newFeatureIndex = 0;
            for (String featureName : referencedFeatures) {
                ownMapping.put(featureName, newFeatureIndex);
                this.featureNames[newFeatureIndex++] = featureName;
            }
            featureIndexMapping = ownMapping;
        } else {
            // Nested ensemble: the enclosing model owns the vector layout, so sub-models must use its mapping
            // rather than falling back to their own local layout.
            featureIndexMapping = newFeatureIndexMapping;
            this.featureNames = new String[0];
        }
        for (InferenceModel model : this.models) {
            model.rewriteFeatureIndices(featureIndexMapping);
        }
    }

    /**
     * Collects, in first-seen order and without duplicates, the feature names of all sub-models, recursing into
     * nested ensembles (whose own feature names are empty until prepared).
     */
    private Set<String> subModelFeatures() {
        Set<String> referencedFeatures = new LinkedHashSet<>();
        for (InferenceModel model : this.models) {
            if (model instanceof EnsembleInferenceModel) {
                referencedFeatures.addAll(((EnsembleInferenceModel) model).subModelFeatures());
            } else {
                Collections.addAll(referencedFeatures, model.getFeatureNames());
            }
        }
        return referencedFeatures;
    }

    /**
     * Estimates heap usage. This covers the shallow size plus feature names, classification labels, sub-models,
     * optional weights and the aggregator.
     */
    public long ramBytesUsed() {
        long size = SHALLOW_SIZE;
        size += RamUsageEstimator.sizeOf(this.featureNames);
        size += RamUsageEstimator.sizeOfCollection(this.classificationLabels);
        size += RamUsageEstimator.sizeOfCollection(this.models);
        if (this.classificationWeights != null) {
            size += RamUsageEstimator.sizeOf(this.classificationWeights);
        }
        size += this.outputAggregator.ramBytesUsed();
        return size;
    }

    /** @return the sub-models (internal list, not copied) */
    public List<InferenceModel> getModels() {
        return this.models;
    }

    public OutputAggregator getOutputAggregator() {
        return this.outputAggregator;
    }

    public TargetType getTargetType() {
        return this.targetType;
    }

    /** @return the per-class weights (internal array, not copied), or {@code null} if none were configured */
    public double[] getClassificationWeights() {
        return this.classificationWeights;
    }

    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder("EnsembleInferenceModel{");
        builder.append("featureNames=")
            .append(Arrays.toString(this.featureNames))
            .append(", models=")
            .append(this.models)
            .append(", outputAggregator=")
            .append(this.outputAggregator)
            .append(", targetType=")
            .append(this.targetType);
        if (this.targetType == TargetType.CLASSIFICATION) {
            builder.append(", classificationLabels=")
                .append(this.classificationLabels)
                .append(", classificationWeights=")
                .append(Arrays.toString(this.classificationWeights));
        } else if (this.targetType == TargetType.REGRESSION) {
            builder.append(", minPredictedValue=")
                .append(this.getMinPredictedValue())
                .append(", maxPredictedValue=")
                .append(this.getMaxPredictedValue());
        }
        builder.append(", preparedForInference=").append(this.preparedForInference);
        return builder.append('}').toString();
    }

    @Override
    public double getMinPredictedValue() {
        return this.predictedValuesBoundariesSupplier.get()[0];
    }

    @Override
    public double getMaxPredictedValue() {
        return this.predictedValuesBoundariesSupplier.get()[1];
    }

    /**
     * Aggregates the sub-models' min and max predicted values with the output aggregator.
     *
     * @return {@code {min, max}}
     * @throws IllegalStateException if any sub-model is not a {@link BoundedInferenceModel}
     */
    private double[] initModelBoundaries() {
        double[] modelsMinBoundaries = new double[this.models.size()];
        double[] modelsMaxBoundaries = new double[this.models.size()];
        int i = 0;
        for (InferenceModel model : this.models) {
            if ((model instanceof BoundedInferenceModel) == false) {
                throw new IllegalStateException("All submodels have to be bounded");
            }
            BoundedInferenceModel boundedInferenceModel = (BoundedInferenceModel) model;
            modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue();
            modelsMaxBoundaries[i] = boundedInferenceModel.getMaxPredictedValue();
            i++;
        }
        return new double[] { this.outputAggregator.aggregate(modelsMinBoundaries), this.outputAggregator.aggregate(modelsMaxBoundaries) };
    }
}

