Get risk predictions in WEKA using own Java code

Posted 2019-08-04 00:31

Question:

I have already checked WEKA's "Making predictions" documentation, which contains explicit instructions for making predictions from the command line and the GUI.

I want to know how to get, in my own Java code, prediction values like the ones below, which I got from the GUI using the Agrawal dataset (weka.datagenerators.classifiers.classification.Agrawal):

inst#,  actual,     predicted,  error,  prediction
1,      1:0,        2:1,        +,      0.941
2,      1:0,        1:0,        ,       1
3,      1:0,        1:0,        ,       1
4,      1:0,        1:0,        ,       1
5,      1:0,        1:0,        ,       1
6,      1:0,        1:0,        ,       1
7,      1:0,        2:1,        +,      0.941
8,      2:1,        2:1,        ,       0.941
9,      2:1,        2:1,        ,       0.941
10,     2:1,        2:1,        ,       0.941
1,      1:0,        1:0,        ,       1
2,      1:0,        1:0,        ,       1
3,      1:0,        1:0,        ,       1

I can't replicate this result, even though the documentation says:

Java

If you want to perform the classification within your own code, see the classifying instances section of this article, explaining the Weka API in general.

I followed the link, and it says:

Classifying instances

In case you have an unlabeled dataset that you want to classify with your newly trained classifier, you can use the following code snippet. It loads the file /some/where/unlabeled.arff, uses the previously built classifier tree to label the instances, and saves the labeled data as /some/where/labeled.arff.

This is not what I need: I just want the k-fold cross-validation predictions for the dataset I am currently modeling.
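For context on the GUI table above: the inst# column restarts at 1 after instance 10 because each row is numbered within its own test fold. The sketch below assumes that cross-validation shuffles the data and then cuts it into contiguous test blocks using the arithmetic that WEKA's Instances.testCV(numFolds, numFold) is understood to use. The class and method names (FoldSketch, testFold, instLabels) are hypothetical, and nothing here calls WEKA; it only illustrates why the numbering restarts.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Illustration (not WEKA code): how n already-shuffled instances are split
 * into contiguous test folds, modeled on the arithmetic assumed for
 * Instances.testCV(numFolds, numFold).
 */
public class FoldSketch {

    /** Returns {first, size} of the test block for fold numFold (0-based). */
    static int[] testFold(int numInstances, int numFolds, int numFold) {
        int size = numInstances / numFolds;
        int offset;
        if (numFold < numInstances % numFolds) {
            size++;             // the first (n % k) folds get one extra instance
            offset = numFold;
        } else {
            offset = numInstances % numFolds;
        }
        int first = numFold * (numInstances / numFolds) + offset;
        return new int[] {first, size};
    }

    /** One label per test instance; inst# restarts at 1 in every fold. */
    static List<String> instLabels(int numInstances, int numFolds) {
        List<String> rows = new ArrayList<>();
        for (int fold = 0; fold < numFolds; fold++) {
            int[] block = testFold(numInstances, numFolds, fold);
            for (int i = 0; i < block[1]; i++) {
                // fold number, 1-based inst# within the fold, row in the shuffled data
                rows.add("fold " + fold + ", inst# " + (i + 1) + ", row " + (block[0] + i));
            }
        }
        return rows;
    }

    public static void main(String[] args) {
        for (String row : instLabels(10, 3)) {
            System.out.println(row);
        }
    }
}
```

With 10 instances and 3 folds, the first fold gets 4 test instances and the other two get 3 each, and every fold's inst# starts again at 1, matching the pattern in the GUI output.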


Update

predictions

public FastVector predictions()

Returns the predictions that have been collected.

Returns:

a reference to the FastVector containing the predictions that have been collected. This should be null if no predictions have been collected.

I found the predictions() method of the Evaluation class and used it with this code:

Object[] preds = evaluation.predictions().toArray();
for(Object pred : preds) {
    System.out.println(pred);
}

It resulted in:

...
NOM: 0.0 0.0 1.0 0.9466666666666667 0.05333333333333334
NOM: 0.0 0.0 1.0 0.8947368421052632 0.10526315789473684
NOM: 0.0 0.0 1.0 0.9934883720930232 0.0065116279069767444
NOM: 0.0 0.0 1.0 0.9466666666666667 0.05333333333333334
NOM: 0.0 0.0 1.0 0.9912575655682583 0.008742434431741762
NOM: 0.0 0.0 1.0 0.9934883720930232 0.0065116279069767444
...

Is this the same thing as the one above?

Answer 1:

After extensive Google searching (the documentation provides minimal help), I finally found the answer.

I hope this explicit answer helps others in the future.

  • For sample code, I looked at the question "How to print out the predicted class after cross-validation in WEKA", and I'm glad I managed to decode its incomplete and partly hard-to-understand answer.

    Here is my code, which produces output similar to the GUI's:

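    import java.util.Random;
    import weka.classifiers.Evaluation;
    import weka.classifiers.evaluation.output.prediction.PlainText;
    import weka.core.Range;

    // Buffer that receives the per-instance prediction lines
    StringBuffer predictionSB = new StringBuffer();
    Range attributesToShow = null;              // no extra attribute columns
    Boolean outputDistributions = Boolean.TRUE;

    // PlainText is a concrete AbstractOutput that writes into the buffer
    PlainText predictionOutput = new PlainText();
    predictionOutput.setBuffer(predictionSB);
    predictionOutput.setOutputDistribution(true);

    // j48Model: an unbuilt J48; data: Instances with the class index set;
    // randomNumber: a java.util.Random, e.g. new Random(1)
    Evaluation evaluation = new Evaluation(data);
    evaluation.crossValidateModel(j48Model, data, numberOfFolds,
            randomNumber, predictionOutput, attributesToShow,
            outputDistributions);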
    

    To explain: the StringBuffer has to be wrapped in an AbstractOutput object (here, a PlainText) so that crossValidateModel can recognize it.

    Passing only a StringBuffer causes a java.lang.ClassCastException similar to the one in that question, while passing a PlainText without a StringBuffer throws a java.lang.IllegalStateException.

    I would like to thank ManChon U (Kevin) and their question "How to identify the cross-evaluation result to its corresponding instance in the input data set?" for giving me a clue about what this meant:

    ... you just need a single additional argument that is a concrete subclass of weka.classifiers.evaluation.output.prediction.AbstractOutput. weka.classifiers.evaluation.output.prediction.PlainText is probably the one you want to use. Source

    and

    ... Try creating a PlainText object, which extends AbstractOutput (called output for example) instance and calling output.setBuffer(forPredictionsPrinting) and passing that in instead of the buffer. Source

    In practice, these hints mean: create a PlainText object, put a StringBuffer in it, and tweak the output with methods such as setOutputDistribution(boolean).

    Finally, to get our desired predictions, just use:

    System.out.println(predictionOutput.getBuffer());
    

    where predictionOutput is an object from the AbstractOutput family (PlainText, CSV, XML, etc.).

  • Additionally, the output of evaluation.predictions() differs from what the WEKA GUI shows. Fortunately, Mark Hall explained this in the question "Print out the predict class after cross-validation":

    Evaluation.predictions() returns a FastVector containing either NominalPrediction or NumericPrediction objects from the weka.classifiers.evaluation package. Calling Evaluation.crossValidateModel() with the additional AbstractOutput object results in the evaluation object printing the prediction/distribution information from Nominal/NumericPrediction objects to the StringBuffer in the format that you see in the Explorer or from the command line.
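To make the relationship concrete: the "NOM:" lines in the question appear to print, in order, the actual class index, the predicted class index, the instance weight, and then the class distribution. The sketch below parses such a line and formats it like a GUI row (1-based index:label for the classes, "+" on errors, and the probability of the predicted class rounded to at most three decimals). The class name NomRowSketch and the label array {"0", "1"} are hypothetical, and PlainText already does this formatting in real code, so this only illustrates what the numbers mean.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

/**
 * Illustration only: turns a "NOM:" prediction line
 * (actual, predicted, weight, distribution...) into a GUI-style row.
 * Not the WEKA implementation.
 */
public class NomRowSketch {

    /** Rounds to at most 3 decimals and drops trailing zeros (1.0 -> "1"). */
    static String fmt(double x) {
        return BigDecimal.valueOf(x).setScale(3, RoundingMode.HALF_UP)
                .stripTrailingZeros().toPlainString();
    }

    /**
     * @param instNo  1-based instance number within the test fold
     * @param nomLine e.g. "NOM: 0.0 0.0 1.0 0.9466666666666667 0.05333333333333334"
     * @param labels  class labels in attribute order (hypothetical: {"0", "1"})
     */
    static String toRow(int instNo, String nomLine, String[] labels) {
        String[] t = nomLine.trim().split("\\s+");     // t[0] is "NOM:"
        int actual = (int) Double.parseDouble(t[1]);
        int predicted = (int) Double.parseDouble(t[2]);
        // t[3] is the instance weight; the class distribution starts at t[4]
        double confidence = Double.parseDouble(t[4 + predicted]);
        String error = actual != predicted ? "+" : "";
        return instNo + ", " + (actual + 1) + ":" + labels[actual] + ", "
                + (predicted + 1) + ":" + labels[predicted] + ", "
                + error + ", " + fmt(confidence);
    }

    public static void main(String[] args) {
        String[] labels = {"0", "1"};
        System.out.println(toRow(1,
                "NOM: 0.0 0.0 1.0 0.9466666666666667 0.05333333333333334", labels));
    }
}
```

For example, the first NOM: line shown in the question becomes `1, 1:0, 1:0, , 0.947`, which has the same shape as the GUI rows. In real code, reading the fields directly from the NominalPrediction objects (actual(), predicted(), distribution()) would be cleaner than parsing toString().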

References:

  • "Print out the predict class after cross-validation"
  • "How to identify the cross-evaluation result to its corresponding instance in the input data set?"
  • "How to print out the predicted class after cross-validation in WEKA"