I'm using DecisionTreeClassifier from scikit-learn to classify some multiclass data. I found many posts describing how to display the decision tree path, like here, here, and here. However, all of them describe how to display the tree for the training data. That makes sense, because export_graphviz only requires a fitted model.
My question is: how do I visualize the tree on the test samples (preferably with export_graphviz)? That is, after fitting the model with clf.fit(X[train], y[train]) and then predicting the results for the test data with clf.predict(X[test]), I want to visualize the decision path used for predicting the samples X[test]. Is there a way to do that?
Edit:
I see that the path can be printed using decision_path. If there's a way to get DOT output like that of export_graphviz to display it, that would be great.
To get the path taken by a particular sample in a decision tree, you can use decision_path. It returns a sparse matrix with the decision paths for the provided samples: one row per sample and one column per node, where a nonzero entry means the sample passes through that node.
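As a minimal sketch of how this applies to held-out data (the train_test_split call and variable names such as X_test are assumptions, not part of the question), the node ids visited by a test sample can be read off the sparse matrix like this:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
# Assumed split; any held-out samples would work the same way.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0
)

clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Sparse matrix of shape (n_test_samples, n_nodes).
node_indicator = clf.decision_path(X_test)

# Node ids visited by test sample 0, taken from the CSR index arrays.
start, stop = node_indicator.indptr[0], node_indicator.indptr[1]
first_path = sorted(node_indicator.indices[start:stop].tolist())

# apply() reports the leaf each sample ends up in; it should close the path.
leaf_ids = clf.apply(X_test)
print(first_path, leaf_ids[0])
```

The slice between indptr[i] and indptr[i + 1] holds the columns (node ids) for sample i. These ids match the node names in the DOT output of export_graphviz, which is what makes it possible to find the corresponding graph nodes.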
Those decision paths can then be used to color and label the tree generated via pydotplus. This requires overwriting the color and the label of each node (which results in a bit of ugly code).
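The label rewrite is plain string handling, so it can be isolated in a small function. The sketch below uses a hypothetical helper, set_sample_count, and a made-up label in the shape that export_graphviz produces with special_characters=True:

```python
def set_sample_count(label, count):
    """Return `label` with its 'samples = N' field replaced by `count`."""
    parts = label.split('<br/>')
    for i, part in enumerate(parts):
        if part.startswith('samples = '):
            parts[i] = 'samples = {}'.format(count)
    return '<br/>'.join(parts)


# Illustrative label only; the numbers are not taken from a real tree.
example_label = ('<petal width (cm) &le; 0.8<br/>gini = 0.667<br/>'
                 'samples = 150<br/>value = [50, 50, 50]<br/>class = setosa>')
relabelled = set_sample_count(example_label, 0)
print(relabelled)
```

Note that the startswith test would miss a samples field that comes first in the label, because that field would carry the leading '<'. With the default export_graphviz options the impurity or split condition comes first, so this does not come up.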
Notes
- decision_path can take samples from the training set or new values
- you can go wild with the colors and change the color according to the number of samples or whatever other visualization might be needed
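As one possible take on the second note, the sketch below counts how many test samples visit each node and maps the count to a shade of green. The train/test split and the visit_color helper are assumptions made for illustration; the hex strings are ordinary "#rrggbb" colors, which Graphviz accepts.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0
)
clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Column sums of the node indicator matrix: test samples visiting each node.
visits = np.asarray(clf.decision_path(X_test).sum(axis=0)).ravel().astype(int)


def visit_color(count, max_count):
    """Hypothetical helper: white for 0 visits, full green for max_count."""
    if max_count == 0:
        return '#ffffff'
    fade = int(round(255 * (1 - count / max_count)))
    return '#{0:02x}ff{0:02x}'.format(fade)


# Keys are node ids as strings, matching the node names in the DOT output.
fill_colors = {str(n): visit_color(c, visits.max())
               for n, c in enumerate(visits)}
```

Each entry could be passed to set_fillcolor on the node with the matching name, in place of the fixed 'green' used in the example.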
Example
In the example below, visited nodes are colored green and all other nodes are white.
import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)

# empty all nodes, i.e. set color to white and number of samples to zero
for node in graph.get_node_list():
    # skip entries without a label (e.g. the default node/edge settings)
    if node.get_attributes().get('label') is None:
        continue
    if 'samples = ' in node.get_attributes()['label']:
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = 0'
        node.set('label', '<br/>'.join(labels))
        node.set_fillcolor('white')

# the samples to trace through the tree (here a single row)
samples = iris.data[129:130]
decision_paths = clf.decision_path(samples)

for decision_path in decision_paths:
    for n, node_value in enumerate(decision_path.toarray()[0]):
        # a zero entry means the sample does not pass through node n
        if node_value == 0:
            continue
        node = graph.get_node(str(n))[0]
        node.set_fillcolor('green')
        # increment the sample counter shown in the node label
        labels = node.get_attributes()['label'].split('<br/>')
        for i, label in enumerate(labels):
            if label.startswith('samples = '):
                labels[i] = 'samples = {}'.format(int(label.split('=')[1]) + 1)

        node.set('label', '<br/>'.join(labels))

filename = 'tree.png'
graph.write_png(filename)
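If an image is more than is needed, the same path can also be written out as text. The following is only a sketch: describe_path is a hypothetical helper, the train/test split is an assumption, and it reads the split feature and threshold from the fitted estimator's tree_ attribute.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=0
)
clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)


def describe_path(clf, sample, feature_names):
    """Hypothetical helper: list the nodes one sample visits, root to leaf."""
    t = clf.tree_
    node_indicator = clf.decision_path(sample.reshape(1, -1))
    # Child ids are larger than parent ids, so sorting gives root-to-leaf order.
    node_ids = sorted(node_indicator.indices.tolist())
    lines = []
    for node_id in node_ids:
        if t.children_left[node_id] == -1:  # -1 marks a leaf
            lines.append('node {}: leaf'.format(node_id))
            continue
        feature = t.feature[node_id]
        threshold = t.threshold[node_id]
        # Samples with value <= threshold go to the left child.
        sign = '<=' if sample[feature] <= threshold else '>'
        lines.append('node {}: {} = {} {} {:.3f}'.format(
            node_id, feature_names[feature], sample[feature], sign, threshold))
    return lines


lines = describe_path(clf, X_test[0], iris.feature_names)
print('\n'.join(lines))
```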