/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.stats;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.AccuracyStats;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.Scorer;
import edu.stanford.nlp.util.BinaryHeapPriorityQueue;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.text.NumberFormat;
import java.util.List;

/**
 * A {@link Scorer} that evaluates a {@link ProbabilisticClassifier} on a multi-class dataset.
 *
 * <p>{@link #initMC} records the fraction of examples whose top-scoring label equals the gold
 * label, the sum of the classifier's {@code logProbabilityOf} score for the gold label, and the
 * top-guess score plus correctness of every example (ordered by a priority queue keyed on the
 * negated top-guess score). {@link #score()} reports either the accuracy or the summed
 * log-likelihood, depending on {@code scoreType}.
 *
 * <p>NOTE: {@code saveFile} and {@code saveIndex} are {@code static}, so they are shared by every
 * instance; constructors that take a file name overwrite the save target for all instances.
 *
 * @param <L> the label type
 */
public class MultiClassAccuracyStats<L>
implements Scorer<L> {
    /** Top-guess score per example, in the order produced by the priority queue's sorted list. */
    double[] scores;
    /** Whether the top guess was correct, parallel to {@link #scores}. */
    boolean[] isCorrect;
    /** Sum over examples of the classifier's {@code logProbabilityOf} score for the gold label. */
    double logLikelihood;
    /** {@code correct / total} from the last {@link #initMC} call (NaN if the data was empty). */
    double accuracy;
    /** Base name for accuracy/coverage files written by {@link #getDescription}; shared by all instances. */
    static String saveFile = null;
    /** Suffix appended to {@link #saveFile}; incremented after each file written, shared by all instances. */
    static int saveIndex = 1;
    public static final int USE_ACCURACY = 1;
    public static final int USE_LOGLIKELIHOOD = 2;
    private int scoreType = USE_ACCURACY;
    int correct = 0;
    int total = 0;

    /** Creates a scorer that reports accuracy and saves nothing. */
    public MultiClassAccuracyStats() {
    }

    /**
     * Creates a scorer that reports the given score type.
     *
     * @param scoreType {@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}
     */
    public MultiClassAccuracyStats(int scoreType) {
        this.scoreType = scoreType;
    }

    /**
     * Creates an accuracy scorer and sets the (static, shared) save file base name.
     *
     * @param file base name used by {@link #getDescription} when writing coverage files
     */
    public MultiClassAccuracyStats(String file) {
        this(file, USE_ACCURACY);
    }

    /**
     * Creates a scorer of the given type and sets the (static, shared) save file base name.
     *
     * @param file base name used by {@link #getDescription} when writing coverage files
     * @param scoreType {@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}
     */
    public MultiClassAccuracyStats(String file, int scoreType) {
        saveFile = file;
        this.scoreType = scoreType;
    }

    /**
     * Creates an accuracy scorer, sets the save file, and immediately evaluates {@code classifier}.
     *
     * @param classifier the classifier to evaluate
     * @param data the evaluation data
     * @param file base name used by {@link #getDescription} when writing coverage files
     */
    public <F> MultiClassAccuracyStats(ProbabilisticClassifier<L, F> classifier, GeneralDataset<L, F> data, String file) {
        this(classifier, data, file, USE_ACCURACY);
    }

    /**
     * Creates a scorer of the given type, sets the save file, and immediately evaluates
     * {@code classifier}.
     *
     * @param classifier the classifier to evaluate
     * @param data the evaluation data
     * @param file base name used by {@link #getDescription} when writing coverage files
     * @param scoreType {@link #USE_ACCURACY} or {@link #USE_LOGLIKELIHOOD}
     */
    public <F> MultiClassAccuracyStats(ProbabilisticClassifier<L, F> classifier, GeneralDataset<L, F> data, String file, int scoreType) {
        saveFile = file;
        this.scoreType = scoreType;
        this.initMC(classifier, data);
    }

    /**
     * Evaluates {@code classifier} on {@code data} (replacing any previous results) and returns
     * {@link #score()}.
     */
    @Override
    public <F> double score(ProbabilisticClassifier<L, F> classifier, GeneralDataset<L, F> data) {
        this.initMC(classifier, data);
        return this.score();
    }

    /**
     * Returns the score selected by {@code scoreType} from the last evaluation.
     *
     * @return accuracy for {@link #USE_ACCURACY}, summed log-likelihood for {@link #USE_LOGLIKELIHOOD}
     * @throws RuntimeException if {@code scoreType} is neither constant
     */
    public double score() {
        if (this.scoreType == USE_ACCURACY) {
            return this.accuracy;
        }
        if (this.scoreType == USE_LOGLIKELIHOOD) {
            return this.logLikelihood;
        }
        throw new RuntimeException("Unknown score type: " + this.scoreType);
    }

    /**
     * Returns the number of examples seen by the last {@link #initMC} call.
     * Throws NullPointerException if {@link #initMC} has never been called.
     */
    public int numSamples() {
        return this.scores.length;
    }

    /**
     * Returns the mean, over recall = 1..n, of the fraction of correct top guesses among the
     * {@code recall} entries at the end of the score arrays (see {@link #numCorrect}).
     * Returns NaN when there are no samples.
     */
    public double confidenceWeightedAccuracy() {
        int[] coverage = this.cumulativeCorrect();
        double acc = 0.0;
        // Same summation order as summing numCorrect(recall) / recall, but in O(n) total.
        for (int recall = 1; recall <= coverage.length; ++recall) {
            acc += (double) coverage[recall - 1] / (double) recall;
        }
        return acc / (double) this.numSamples();
    }

    /**
     * Runs {@code classifier} over every datum in {@code data} and stores accuracy, summed
     * log-likelihood, and the per-example (top-guess score, correct?) arrays.
     *
     * <p>If {@code data} is empty, {@code accuracy} becomes NaN (0 / 0).
     *
     * @param classifier the classifier to evaluate
     * @param data the evaluation data
     */
    public <F> void initMC(ProbabilisticClassifier<L, F> classifier, GeneralDataset<L, F> data) {
        // Queue priority is the negated top-guess score. numCorrect() reads from the END of the
        // sorted arrays, presumably the most confident guesses — NOTE(review): confirm
        // BinaryHeapPriorityQueue.toSortedList() returns highest priority first.
        BinaryHeapPriorityQueue<Pair<Integer, Pair<Double, Boolean>>> q = new BinaryHeapPriorityQueue<>();
        this.total = 0;
        this.correct = 0;
        this.logLikelihood = 0.0;
        for (int i = 0; i < data.size(); ++i) {
            RVFDatum<L, F> d = data.getRVFDatum(i);
            Counter<L> labelScores = classifier.logProbabilityOf(d);
            L guess = Counters.argmax(labelScores);
            L correctLab = d.label();
            double guessScore = labelScores.getCount(guess);
            double correctScore = labelScores.getCount(correctLab);
            int guessInd = data.labelIndex().indexOf(guess);
            int correctInd = data.labelIndex().indexOf(correctLab);
            boolean guessedRight = guessInd == correctInd;
            ++this.total;
            if (guessedRight) {
                ++this.correct;
            }
            this.logLikelihood += correctScore;
            q.add(new Pair<Integer, Pair<Double, Boolean>>(i, new Pair<Double, Boolean>(guessScore, guessedRight)), -guessScore);
        }
        this.accuracy = (double) this.correct / (double) this.total;
        List<Pair<Integer, Pair<Double, Boolean>>> sorted = q.toSortedList();
        this.scores = new double[sorted.size()];
        this.isCorrect = new boolean[sorted.size()];
        int pos = 0;
        for (Pair<Integer, Pair<Double, Boolean>> item : sorted) {
            Pair<Double, Boolean> scoreAndCorrect = item.second();
            this.scores[pos] = scoreAndCorrect.first();
            this.isCorrect[pos] = scoreAndCorrect.second();
            ++pos;
        }
    }

    /**
     * Counts correct top guesses among the last {@code recall} entries of the score arrays.
     *
     * @param recall how many entries (from the end) to inspect; must not exceed
     *     {@link #numSamples()} or an ArrayIndexOutOfBoundsException is thrown; values &lt;= 0 give 0
     * @return number of correct guesses among those entries
     */
    public int numCorrect(int recall) {
        int correct = 0;
        for (int j = this.scores.length - 1; j >= this.scores.length - recall; --j) {
            if (!this.isCorrect[j]) continue;
            ++correct;
        }
        return correct;
    }

    /**
     * Returns the accuracy/coverage counts: element {@code k} equals {@code numCorrect(k + 1)}.
     */
    public int[] getAccCoverage() {
        return this.cumulativeCorrect();
    }

    /**
     * Computes numCorrect(1..n) in a single O(n) pass from the end of {@link #isCorrect}.
     */
    private int[] cumulativeCorrect() {
        int n = this.numSamples();
        int[] counts = new int[n];
        int running = 0;
        for (int k = 0; k < n; ++k) {
            if (this.isCorrect[n - 1 - k]) {
                ++running;
            }
            counts[k] = running;
        }
        return counts;
    }

    /**
     * Returns a multi-line human-readable summary of the last evaluation.
     *
     * <p>Side effect: when the static {@code saveFile} is set, also writes the accuracy/coverage
     * counts to {@code saveFile + "-" + saveIndex + ".accuracy"} and increments the shared
     * {@code saveIndex}.
     *
     * @param numDigits maximum fraction digits for formatted accuracies
     */
    @Override
    public String getDescription(int numDigits) {
        NumberFormat nf = NumberFormat.getNumberInstance();
        nf.setMaximumFractionDigits(numDigits);
        StringBuilder sb = new StringBuilder();
        double confWeightedAccuracy = this.confidenceWeightedAccuracy();
        sb.append("--- Accuracy Stats ---").append("\n");
        sb.append("accuracy: ").append(nf.format(this.accuracy)).append(" (").append(this.correct).append("/").append(this.total).append(")\n");
        sb.append("confidence weighted accuracy :").append(nf.format(confWeightedAccuracy)).append("\n");
        sb.append("log-likelihood: ").append(this.logLikelihood).append("\n");
        if (saveFile != null) {
            String f = saveFile + "-" + saveIndex;
            sb.append("saving accuracy info to ").append(f).append(".accuracy\n");
            StringUtils.printToFile(f + ".accuracy", AccuracyStats.toStringArr(this.getAccCoverage()));
            ++saveIndex;
        }
        return sb.toString();
    }

    /** Returns {@code MultiClassAccuracyStats(<score type name>)}. */
    @Override
    public String toString() {
        String accuracyType;
        if (this.scoreType == USE_ACCURACY) {
            accuracyType = "classification_accuracy";
        } else if (this.scoreType == USE_LOGLIKELIHOOD) {
            accuracyType = "log_likelihood";
        } else {
            accuracyType = "unknown";
        }
        // The original also appended "scoreType + 1 + 2", which string concatenation turned into
        // junk digits (e.g. "...)112"); only the descriptive name belongs here.
        return "MultiClassAccuracyStats(" + accuracyType + ")";
    }
}

