How to view Random Forest statistics in Spark (scala)

Posted 2019-03-04 12:57

Question:

I have a RandomForestClassificationModel in Spark. Calling .toDebugString on it outputs the following:

Tree 0 (weight 1.0):
  If (feature 0 in {1.0,2.0,3.0})
   If (feature 3 in {2.0,3.0})
    If (feature 8 <= 55.3) 
.
.
  Else (feature 0 not in {1.0,2.0,3.0})
.
.
Tree 1 (weight 1.0):
.
.
...etc

I'd like to view how the actual data flows through the model, something like:

Tree 0 (weight 1.0):
  If (feature 0 in {1.0,2.0,3.0}) 60%
   If (feature 3 in {2.0,3.0}) 57%
    If (feature 8 <= 55.3) 22%
.
.
  Else (feature 0 not in {1.0,2.0,3.0}) 40%
.
.
Tree 1 (weight 1.0):
.
...etc

By seeing the probability of labels at each node, I could see which paths through the trees the data (thousands of records) is most likely to follow, which would be really good insight!

I found an awesome answer here: Spark MLib Decision Trees: Probability of labels by features?

Unfortunately, the method in that answer uses the old MLlib API, and after lots of trying I have failed to replicate it with the DataFrame-based API, which has different implementations of the Node and Split classes :(
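For what it's worth, the DataFrame-based API does expose enough public structure to walk each tree by hand. A minimal sketch, assuming `model` is an already-fitted spark.ml RandomForestClassificationModel: the public `trees`, `rootNode`, `InternalNode`, `LeafNode`, `CategoricalSplit`, and `ContinuousSplit` members let you print each node's split condition together with its impurity and gain. Note this shows impurity statistics, not the percentage of your records reaching each node; the per-node label counts (`impurityStats`) are private to spark.ml, so getting actual data percentages would still require replaying your DataFrame through the split conditions.

```scala
import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, InternalNode, LeafNode, Node}

// Recursively walk one tree of the forest, printing each split condition
// together with the node's impurity and information gain.
def printTree(node: Node, indent: String = "  "): Unit = node match {
  case n: InternalNode =>
    val cond = n.split match {
      case s: CategoricalSplit =>
        s"feature ${s.featureIndex} in ${s.leftCategories.mkString("{", ",", "}")}"
      case s: ContinuousSplit =>
        s"feature ${s.featureIndex} <= ${s.threshold}"
    }
    println(s"${indent}If ($cond) impurity=${n.impurity} gain=${n.gain}")
    printTree(n.leftChild, indent + " ")
    println(s"${indent}Else (not: $cond)")
    printTree(n.rightChild, indent + " ")
  case leaf: LeafNode =>
    println(s"${indent}Predict ${leaf.prediction} (impurity=${leaf.impurity})")
}

// `model` is assumed to be a fitted RandomForestClassificationModel.
model.trees.zipWithIndex.foreach { case (tree, i) =>
  println(s"Tree $i (weight ${model.treeWeights(i)}):")
  printTree(tree.rootNode)
}
```

This reproduces roughly the toDebugString layout, but because you control the traversal you can annotate each line with whatever statistics you can compute.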

Answer 1:

One approach I found useful is reading the model's persisted data files back in with spark.read.parquet(). That way, all the information about every node can be retrieved as a regular DataFrame:

import org.apache.spark.sql.DataFrame

val modelPath = "some/path/to/your/model"
val dataPath = modelPath + "/data"
val nodeData: DataFrame = spark.read.parquet(dataPath)
nodeData.show(500, false)
nodeData.printSchema()

Then you can rebuild the trees from that information. Hope it helps.
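To illustrate what you can do with that DataFrame, here is a hedged sketch. It assumes the layout a saved spark.ml random forest appears to use (observed by inspecting the parquet output; field names such as `treeID` and the nested `nodeData` struct may differ across Spark versions), and that `spark` is an active SparkSession:

```scala
import org.apache.spark.sql.functions.col

// Assumed schema: one row per node, with a treeID column for ensemble
// models and the per-node fields nested under a `nodeData` struct.
val nodeData = spark.read.parquet("some/path/to/your/model/data")

// Look at all nodes of a single tree, in node-id order:
nodeData
  .filter(col("treeID") === 0)
  .orderBy(col("nodeData.id"))
  .show(false)

// How many nodes does each tree have?
nodeData.groupBy(col("treeID")).count().orderBy(col("treeID")).show()
```

Because it is an ordinary DataFrame, you can join it against statistics you compute yourself (e.g. record counts per node obtained by replaying your data through the split conditions) to produce the annotated view from the question.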