/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.MaxEnt;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

/**
 * A {@link MaxEnt} variant whose instances carry a {@link FeatureVectorSequence}
 * rather than a single feature vector.
 *
 * <p>Every {@link FeatureVector} in the sequence is scored against the same weight
 * row (row 0 of {@code parameters}, including the default-feature weight), so the
 * "classes" of a classification are the positions within the sequence. Labels are
 * therefore the decimal strings {@code "0"}, {@code "1"}, ... naming those positions.
 */
public class RankMaxEnt extends MaxEnt {
    private static final long serialVersionUID = 1L;

    /** Version stamp appended to the serialized form by {@link #writeObject}. */
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * Creates a ranker with both a global and per-class feature selection.
     *
     * @param dataPipe pipe handed to the {@link MaxEnt} superclass
     * @param parameters weight array; only row 0 is read by this class
     * @param featureSelection global feature selection, may be {@code null}
     * @param perClassFeatureSelection per-class selections, may be {@code null};
     *        when non-null only element 0 is consulted by this class
     */
    public RankMaxEnt(Pipe dataPipe, double[] parameters, FeatureSelection featureSelection, FeatureSelection[] perClassFeatureSelection) {
        super(dataPipe, parameters, featureSelection, perClassFeatureSelection);
    }

    /** Creates a ranker with a global feature selection and no per-class selection. */
    public RankMaxEnt(Pipe dataPipe, double[] parameters, FeatureSelection featureSelection) {
        this(dataPipe, parameters, featureSelection, null);
    }

    /** Creates a ranker with per-class feature selections and no global selection. */
    public RankMaxEnt(Pipe dataPipe, double[] parameters, FeatureSelection[] perClassFeatureSelection) {
        this(dataPipe, parameters, null, perClassFeatureSelection);
    }

    /** Creates a ranker without any feature selection. */
    public RankMaxEnt(Pipe dataPipe, double[] parameters) {
        this(dataPipe, parameters, null, null);
    }

    /**
     * Writes the raw (un-exponentiated, un-normalized) score of every feature vector
     * in the instance's {@link FeatureVectorSequence} into {@code scores}.
     *
     * @param instance instance whose data is a {@link FeatureVectorSequence}
     * @param scores output array; expected to have one slot per sequence element
     */
    public void getUnnormalizedClassificationScores(Instance instance, double[] scores) {
        FeatureVectorSequence fvs = (FeatureVectorSequence)instance.getData();
        assert (scores.length == fvs.size());
        this.fillUnnormalizedScores(instance, fvs, scores);
    }

    /**
     * Writes a probability distribution over the sequence positions into
     * {@code scores}: raw scores are shifted by their maximum, exponentiated and
     * divided by their sum.
     *
     * @param instance instance whose data is a {@link FeatureVectorSequence}
     * @param scores output array; expected to have one slot per sequence element
     */
    public void getClassificationScores(Instance instance, double[] scores) {
        FeatureVectorSequence fvs = (FeatureVectorSequence)instance.getData();
        int numLabels = fvs.size();
        assert (scores.length == fvs.size());
        this.fillUnnormalizedScores(instance, fvs, scores);

        // Subtracting the maximum before exp() keeps the largest term at exp(0) = 1
        // and so avoids overflow for large raw scores.
        double max = MatrixOps.max(scores);
        double sum = 0.0;
        for (int li = 0; li < numLabels; ++li) {
            scores[li] = Math.exp(scores[li] - max);
            sum += scores[li];
        }
        for (int li = 0; li < numLabels; ++li) {
            scores[li] /= sum;
        }
    }

    /**
     * Computes the normalized scores, then zeroes the entries named by
     * {@code bestLabels[1..]} and renormalizes what is left.
     *
     * <p>The entry for {@code bestLabels[0]} is deliberately kept: the loop starts
     * at index 1, so exactly one of the listed positions retains its mass.
     *
     * @param instance instance whose data is a {@link FeatureVectorSequence}
     * @param scores output array, one slot per sequence element
     * @param bestLabels positions in {@code scores}; all but the first are zeroed
     */
    public void getClassificationScoresForTies(Instance instance, double[] scores, int[] bestLabels) {
        this.getClassificationScores(instance, scores);
        for (int i = 1; i < bestLabels.length; ++i) {
            scores[bestLabels[i]] = 0.0;
        }
        double sum = 0.0;
        for (double score : scores) {
            sum += score;
        }
        for (int li = 0; li < scores.length; ++li) {
            scores[li] /= sum;
        }
    }

    /**
     * Scores every element of the instance's sequence and wraps the normalized
     * scores in a {@link Classification}.
     *
     * @param instance instance whose data is a {@link FeatureVectorSequence}
     * @return classification whose labels are the sequence positions as strings
     */
    public Classification classify(Instance instance) {
        FeatureVectorSequence fvs = (FeatureVectorSequence)instance.getData();
        double[] scores = new double[fvs.size()];
        this.getClassificationScores(instance, scores);
        return new Classification(instance, this, this.createLabelVector(this.getLabelAlphabet(), scores));
    }

    /**
     * Prints the default-feature weight and then every feature weight of row 0,
     * one per line, to standard output.
     */
    public void print() {
        Alphabet dict = this.getAlphabet();
        LabelAlphabet labelDict = this.getLabelAlphabet();
        System.out.println("FEATURES FOR CLASS " + labelDict.lookupObject(0));
        System.out.println(" <default> " + this.parameters[this.defaultFeatureIndex]);
        for (int i = 0; i < this.defaultFeatureIndex; ++i) {
            Object name = dict.lookupObject(i);
            double weight = this.parameters[i];
            System.out.println(" " + name + " " + weight);
        }
    }

    /**
     * Shared scoring loop behind both public score methods: stores the raw score of
     * each sequence element at the matching position of {@code scores}.
     */
    private void fillUnnormalizedScores(Instance instance, FeatureVectorSequence fvs, double[] scores) {
        // One extra column per row holds the default-feature weight.
        int numFeatures = instance.getDataAlphabet().size() + 1;
        for (int position = 0; position < fvs.size(); ++position) {
            FeatureVector fv = fvs.get(position);
            assert (fv.getAlphabet() == this.instancePipe.getDataAlphabet());
            scores[position] = this.scoreVector(fv, numFeatures);
        }
    }

    /**
     * Returns the default-feature weight of row 0 plus the dot product of row 0
     * with {@code fv}, restricted by the per-class selection for class 0 when one
     * is set and by the global selection otherwise.
     */
    private double scoreVector(FeatureVector fv, int numFeatures) {
        FeatureSelection selection = this.perClassFeatureSelection == null
                ? this.featureSelection
                : this.perClassFeatureSelection[0];
        return this.parameters[this.defaultFeatureIndex]
                + MatrixOps.rowDotProduct(this.parameters, numFeatures, 0, fv, this.defaultFeatureIndex, selection);
    }

    /**
     * Builds a {@link LabelVector} over {@code labelAlphabet} in which the label
     * {@code String.valueOf(i)} carries {@code scores[i]} and every other label
     * already present in the alphabet carries 0.
     *
     * <p>Side effect: growth of {@code labelAlphabet} is re-enabled if it was
     * stopped, and any missing position labels are added to it. Growth is not
     * stopped again afterwards.
     */
    private LabelVector createLabelVector(LabelAlphabet labelAlphabet, double[] scores) {
        if (labelAlphabet.growthStopped()) {
            labelAlphabet.startGrowth();
        }
        // Register every position label first so the alphabet has its final size
        // before the score array is allocated; remember where each label landed.
        int[] labelIndices = new int[scores.length];
        for (int i = 0; i < scores.length; ++i) {
            labelIndices[i] = labelAlphabet.lookupIndex(String.valueOf(i), true);
        }
        // A new double[] is zero-filled, so labels without a score stay at 0.0.
        double[] allScores = new double[labelAlphabet.size()];
        for (int i = 0; i < scores.length; ++i) {
            allScores[labelIndices[i]] = scores[i];
        }
        return new LabelVector(labelAlphabet, allScores);
    }

    /** Serializes the default fields followed by the format version stamp. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        out.writeInt(CURRENT_SERIAL_VERSION);
    }

    /** Restores the default fields and consumes the trailing version stamp. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        // The stamp must be read to keep the stream aligned with writeObject; its
        // value is not inspected because this class does not branch on it.
        in.readInt();
    }
}

