Extracting Class Probabilities from SparkR ML Clas

Posted 2020-07-17 05:18

I'm wondering whether it's possible (using built-in SparkR features or any other workaround) to extract the class probabilities from some of the classification algorithms included in SparkR. The ones of particular interest are:

spark.gbt()
spark.mlp()
spark.randomForest()

Currently, when I use the predict function on these models I am able to extract the predictions, but not the actual probabilities or "confidence."

I've seen several other questions on similar topics, but none that are specific to SparkR, and many have not been answered with regard to Spark's most recent updates.

1 answer
迷人小祖宗
#2 · 2020-07-17 05:30

I ran into the same problem and, following this answer, now use SparkR:::callJMethod to convert the probability DenseVector (which R cannot deserialize) to an Array (which R reads as a list). It's not very elegant or fast, but it does the job:

  denseVectorToArray <- function(dv) {
    SparkR:::callJMethod(dv, "toArray")
  }

For example, start your Spark session:

#library(SparkR)
#sparkR.session(master = "local") 

Generate toy data:

data <- data.frame(clicked = base::sample(c(0,1),100,replace=TRUE),
                  someString = base::sample(c("this", "that"),
                                           100, replace=TRUE), 
                  stringsAsFactors=FALSE)

trainidxs <- base::sample(nrow(data), nrow(data)*0.7)
traindf <- as.DataFrame(data[trainidxs,])
testdf <- as.DataFrame(data[-trainidxs,])

Train a random forest and run predictions:

rf <- spark.randomForest(traindf, 
                        clicked~., 
                        type = "classification", 
                        maxDepth = 2, 
                        maxBins = 2,
                        numTrees = 100)

predictions <- predict(rf, testdf)

Collect your predictions:

collected = SparkR::collect(predictions)    

Now extract the probabilities:

collected$probabilities <- lapply(collected$probability, function(x) denseVectorToArray(x))
# inspect the result: one list of class probabilities per row
str(collected$probabilities)
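
If you would rather have plain numeric columns than a list column, the list can be flattened in base R. This is only a minimal sketch: `collected_demo` and its values are made-up stand-ins for the collected data frame above, and it assumes each row holds a list with one number per class. The `prob0`/`prob1` names follow the model's internal class index, which may not match your label values one-to-one.

```r
# Hypothetical stand-in for `collected` after the lapply step
# (values are invented; real ones come from the model).
collected_demo <- data.frame(prediction = c("0", "1", "0"),
                             stringsAsFactors = FALSE)
collected_demo$probabilities <- list(list(0.7, 0.3),
                                     list(0.4, 0.6),
                                     list(0.55, 0.45))

# Turn each per-row list into a numeric vector and stack the rows into a matrix.
prob_matrix <- do.call(rbind, lapply(collected_demo$probabilities, unlist))

# Name the columns by class index: prob0, prob1, ...
colnames(prob_matrix) <- paste0("prob", seq_len(ncol(prob_matrix)) - 1)

# Combine the prediction column with the numeric probability columns.
flat <- cbind(collected_demo["prediction"], as.data.frame(prob_matrix))
```

With the real data, replace `collected_demo` with `collected`; the same lines work for more than two classes because the column count is taken from the matrix.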

Of course, the function wrapper around SparkR:::callJMethod is a bit of overkill. You can also call it directly, e.g. with dplyr:

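library(dplyr)  # provides %>%, rowwise() and mutate()

# convert each row's DenseVector to a list, then pull out the two class probabilities
withprobs = collected %>%
            rowwise() %>%
            mutate("probabilities" = list(SparkR:::callJMethod(probability,"toArray"))) %>%
            mutate("prob0" = probabilities[[1]], "prob1" = probabilities[[2]])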