Variation on “How to plot decision boundary of a k-nearest neighbor classifier from Elements of Statistical Learning”

2019-01-24 11:48发布

问题:

This is a question related to https://stats.stackexchange.com/questions/21572/how-to-plot-decision-boundary-of-a-k-nearest-neighbor-classifier-from-elements-o

For completeness, here's the original example from that link:

# Reproduce the 15-nearest-neighbour decision boundary for the two-class
# mixture example.
library(ElemStatLearn)
# library() stops immediately if the package is missing; require() would only
# warn and let the script fail later at the knn() call.
library(class)

# Training inputs, their class labels, and the points to be classified.
x <- mixture.example$x
g <- mixture.example$y
xnew <- mixture.example$xnew

# Classify every point of xnew; prob = TRUE attaches the proportion of votes
# for the winning class as the "prob" attribute.
mod15 <- knn(x, xnew, g, k = 15, prob = TRUE)
prob <- attr(mod15, "prob")
# Re-express the winning-class share as the share for class "1".
# The 1 - prob shortcut is only valid when there are exactly two classes.
prob <- ifelse(mod15 == "1", prob, 1 - prob)

# Reshape the per-point shares into a length(px1) x length(px2) matrix so
# that contour() can use px1 / px2 as the axis coordinates.
px1 <- mixture.example$px1
px2 <- mixture.example$px2
prob15 <- matrix(prob, length(px1), length(px2))

par(mar = rep(2, 4))
# Only the 0.5 level is drawn: that iso-line is the decision boundary.
contour(px1, px2, prob15, levels = 0.5, labels = "", xlab = "", ylab = "",
        main = "15-nearest neighbour", axes = FALSE)
# Training points, coloured by their true class.
points(x, col = ifelse(g == 1, "coral", "cornflowerblue"))
# Background dots over the px1 x px2 grid, coloured by the predicted class.
gd <- expand.grid(x = px1, y = px2)
points(gd, pch = ".", cex = 1.2,
       col = ifelse(prob15 > 0.5, "coral", "cornflowerblue"))
box()

I've been playing with that example, and would like to try to make it work with three classes. I can change some values of g with something like

# Relabel training samples 8 through 16 as a made-up third class (label 2).
g <- replace(g, 8:16, 2)

just to pretend that there are some samples which are from a third class. I can't make the plot work, though. I guess I need to change the lines that deal with the proportion of votes for winning class:

# Proportion of votes for class "1" at each classified point.
prob <- local({
  # knn(..., prob = TRUE) stores the winning class's vote share here.
  winner_share <- attr(mod15, "prob")
  # Where class "1" did not win, its share is taken as the complement.
  ifelse(mod15 == "1", winner_share, 1 - winner_share)
})

and also the levels on the contour:

# Draw only the 0.5 iso-line of the vote-share surface, without axes or labels.
contour(
  x = px1, y = px2, z = prob15,
  levels = 0.5, labels = "",
  xlab = "", ylab = "",
  main = "15-nearest neighbour",
  axes = FALSE
)

I am also not sure contour is the right tool for this. One alternative that works is to create a matrix of data that covers the region I'm interested in, classify each point of this matrix, and plot the points with a large marker and different colors, similar to what is being done with the points(gd...) bit.

The final purpose is to be able to show different decision boundaries generated by different classifiers. Can someone point me in the right direction?

thanks Rafael

回答1:

Separating the main parts of the code helps outline how to achieve this:

Training data with 3 classes

 # Training set: the first 25 rows and first two columns of each of the three
 # slices of iris3, stacked in slice order into a 75 x 2 matrix.
 train <- do.call(rbind, lapply(1:3, function(k) iris3[1:25, 1:2, k]))
 # One label per training row, in the same slice order: 25 "s", 25 "c", 25 "v".
 cl <- factor(rep(c("s", "c", "v"), each = 25))

Test data covering a grid

 # NOTE(review): nothing in the visible code appears to use MASS; the call is
 # kept unchanged in case code outside this excerpt relies on it.
 require(MASS)

 # Regular grid with 0.1 spacing, extending one unit beyond the training
 # data's range on both axes; every grid point is classified later.
 test <- expand.grid(
   x = seq(from = min(train[, 1]) - 1, to = max(train[, 1]) + 1, by = 0.1),
   y = seq(from = min(train[, 2]) - 1, to = max(train[, 2]) + 1, by = 0.1)
 )

Classification for that grid

3 classes obviously

 # library() stops immediately if the package is missing; require() would only
 # warn and let the script fail later at the knn() call.
 library(class)

 # Predict a class for every grid point from its 3 nearest training points.
 # prob = TRUE attaches the winning class's vote share as the "prob" attribute.
 classif <- knn(train, test, cl, k = 3, prob = TRUE)
 prob <- attr(classif, "prob")

Data structure for plotting

 # library() stops immediately if dplyr is missing; require() would only warn
 # and let the script fail later at mutate() / bind_rows().
 library(dplyr)

 # Long-format plotting data: one full copy of the grid per class, stacked in
 # the order "c", "v", "s". For each copy:
 #   prob     - vote share of the winning class at that grid point
 #   cls      - the class this copy stands for
 #   prob_cls - 1 where the grid point was classified as `cls`, otherwise 0
 # The 0/1 indicator is what the contour layer uses to trace each class's
 # region boundary.
 dataf <- bind_rows(lapply(c("c", "v", "s"), function(class_label) {
   mutate(test,
          prob = prob,
          cls = class_label,
          prob_cls = as.numeric(classif == class_label))
 }))

Plot

 # library() stops immediately if ggplot2 is missing; require() would only
 # warn and let the script fail later at ggplot().
 library(ggplot2)

 # Layers, drawn bottom to top:
 #   1. small dots on every grid point, coloured by predicted class;
 #   2. one contour per class over its 0/1 indicator, i.e. the region boundary;
 #   3. large dots for the training points, coloured by true class.
 # The colour aesthetic is spelled `colour` throughout (the original mixed
 # `col` and `color`, which ggplot2 treats as the same aesthetic).
 ggplot(dataf) +
   geom_point(aes(x = x, y = y, colour = cls),
              data = mutate(test, cls = classif),
              size = 1.2) +
   geom_contour(aes(x = x, y = y, z = prob_cls, group = cls, colour = cls),
                bins = 2,
                data = dataf) +
   geom_point(aes(x = x, y = y, colour = cls),
              size = 3,
              data = data.frame(x = train[, 1], y = train[, 2], cls = cl))

We can also be a little fancier and plot the probability of class membership as an indication of the "confidence".

 # Same plot as before, with the grid dots sized by the winning class's vote
 # share so that low-confidence regions show up as smaller dots.
 # Layers, drawn bottom to top:
 #   1. grid dots, coloured by predicted class and sized by `prob`;
 #   2. one contour per class over its 0/1 indicator (the region boundary);
 #   3. training points, coloured by true class;
 #   4. open circles around the training points so they stand out.
 ggplot(dataf) +
   # `prob` is added to this layer's data explicitly. Without the column,
   # aes(size = prob) only works because ggplot2 falls back to the variable
   # `prob` in the calling environment.
   geom_point(aes(x = x, y = y, colour = cls, size = prob),
              data = mutate(test, cls = classif, prob = prob)) +
   scale_size(range = c(0.8, 2)) +
   geom_contour(aes(x = x, y = y, z = prob_cls, group = cls, colour = cls),
                bins = 2,
                data = dataf) +
   geom_point(aes(x = x, y = y, colour = cls),
              size = 3,
              data = data.frame(x = train[, 1], y = train[, 2], cls = cl)) +
   geom_point(aes(x = x, y = y),
              size = 3, shape = 1,
              data = data.frame(x = train[, 1], y = train[, 2], cls = cl))