I managed to display the total probabilities of my labels. For example, after displaying my decision tree, I have a table like this:
Total Predictions:
65% impressions
30% clicks
5% conversions
But my issue is finding the probabilities (or the counts) by feature (by node), for example:
if feature1 > 5
  if feature2 < 10
    Predict Impressions
    samples: 30 Impressions
  else feature2 >= 10
    Predict Clicks
    samples: 5 Clicks
Scikit-learn does this automatically; I am trying to find a way to do it with Spark.
Note: the following solution is for Scala only. I didn't find a way to do it in Python.
Assuming you just want a visual representation of the tree, as in your example, one option might be to adapt the subtreeToString method from Node.scala in Spark's GitHub repository so that it includes the probability at each node split, as in the following snippet:
import org.apache.spark.mllib.tree.configuration.FeatureType.{Categorical, Continuous}
import org.apache.spark.mllib.tree.model.{Node, Split}

def subtreeToString(rootNode: Node, indentFactor: Int = 0): String = {

  // Describe the test applied at a split, for either the left or the right child.
  def splitToString(split: Split, left: Boolean): String = {
    split.featureType match {
      case Continuous => if (left) {
        s"(feature ${split.feature} <= ${split.threshold})"
      } else {
        s"(feature ${split.feature} > ${split.threshold})"
      }
      case Categorical => if (left) {
        s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})"
      } else {
        s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})"
      }
    }
  }

  val prefix: String = " " * indentFactor
  if (rootNode.isLeaf) {
    prefix + s"Predict: ${rootNode.predict.predict}\n"
  } else {
    // Probability of this node's predicted class, as a percentage;
    // the Else branch is printed with its complement.
    val prob = rootNode.predict.prob * 100D
    prefix + s"If ${splitToString(rootNode.split.get, left = true)} " + f"(Prob: $prob%04.2f %%)" + "\n" +
      subtreeToString(rootNode.leftNode.get, indentFactor + 1) +
      prefix + s"Else ${splitToString(rootNode.split.get, left = false)} " + f"(Prob: ${100 - prob}%04.2f %%)" + "\n" +
      subtreeToString(rootNode.rightNode.get, indentFactor + 1)
  }
}
I've tested it on a model trained on the Iris dataset and got the following result:
scala> println(subtreeToString(model.topNode))
If (feature 2 <= -0.762712) (Prob: 35.35 %)
 Predict: 1.0
Else (feature 2 > -0.762712) (Prob: 64.65 %)
 If (feature 3 <= 0.333333) (Prob: 52.24 %)
  If (feature 0 <= -0.666667) (Prob: 92.11 %)
   Predict: 3.0
  Else (feature 0 > -0.666667) (Prob: 7.89 %)
   If (feature 2 <= 0.322034) (Prob: 94.59 %)
    Predict: 2.0
   Else (feature 2 > 0.322034) (Prob: 5.41 %)
    If (feature 3 <= 0.166667) (Prob: 50.00 %)
     Predict: 3.0
    Else (feature 3 > 0.166667) (Prob: 50.00 %)
     Predict: 2.0
 Else (feature 3 > 0.333333) (Prob: 47.76 %)
  Predict: 3.0
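For reference, a minimal training setup that would produce a comparable model might look like the sketch below. The file path, label encoding, and training parameters here are assumptions, not the exact ones used for the output above:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Hypothetical input: Iris features scaled to [-1, 1], stored in LIBSVM format.
val data = MLUtils.loadLibSVMFile(sc, "data/iris_scaled.libsvm")

val model = DecisionTree.trainClassifier(
  data,
  numClasses = 4,                            // enough classes to cover labels 1.0, 2.0 and 3.0
  categoricalFeaturesInfo = Map[Int, Int](), // all features treated as continuous
  impurity = "gini",
  maxDepth = 5,
  maxBins = 32)

println(subtreeToString(model.topNode))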
A similar approach could be used to create a tree structure with this information. The main difference would be to store the printed information (split.feature, split.threshold, predict.prob, and so on) as vals and use them to build the structure.
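A minimal sketch of that idea, reusing the MLlib imports from the snippet above (the case class names and the builder function here are hypothetical, just for illustration):

sealed trait InfoTree
case class InfoLeaf(prediction: Double) extends InfoTree
case class InfoSplit(feature: Int,
                     threshold: Double,
                     prob: Double,
                     left: InfoTree,
                     right: InfoTree) extends InfoTree

// Recursively copy the fields of interest out of the MLlib nodes.
def buildInfoTree(node: Node): InfoTree =
  if (node.isLeaf) {
    InfoLeaf(node.predict.predict)
  } else {
    val split = node.split.get
    InfoSplit(
      feature = split.feature,
      threshold = split.threshold,   // only meaningful for continuous splits
      prob = node.predict.prob,
      left = buildInfoTree(node.leftNode.get),
      right = buildInfoTree(node.rightNode.get))
  }

You could then call buildInfoTree(model.topNode) and traverse the resulting structure however you need, instead of only printing it.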