I am computing my confusion matrix as shown below for image semantic segmentation, which is a pretty verbose approach:
def confusion_matrix(preds, labels, conf_m, sample_size):
    """Accumulate a normalized 2x2 confusion matrix for binary segmentation.

    Layout (as in the original code): [0,0]=TP, [0,1]=FP, [1,0]=TN, [1,1]=FN.
    Each pixel contributes 1/(num_pixels * sample_size), so one call adds
    1/sample_size in total — averaging the per-image matrix over a batch.

    Args:
        preds: raw network output; binarized by `normalize` to a {0,1} tensor.
        labels: ground-truth mask with values in {0,1}.
        conf_m: 2x2 float tensor accumulated in place across calls.
        sample_size: number of images being averaged over.

    Returns:
        The updated `conf_m` (same tensor, also mutated in place).
    """
    preds = normalize(preds, 0.9)  # project helper; returns a {0,1} tensor
    preds = preds.flatten()
    labels = labels.flatten()
    # Per-pixel weight; hoisted out of the counting expressions.
    weight = 1.0 / (preds.numel() * sample_size)
    # Vectorized mask counting replaces the original per-pixel Python loop —
    # identical arithmetic, but all four counts are computed at C/tensor speed.
    conf_m[0, 0] += ((preds == 1) & (labels == 1)).sum() * weight  # TP
    conf_m[0, 1] += ((preds == 1) & (labels == 0)).sum() * weight  # FP
    conf_m[1, 0] += ((preds == 0) & (labels == 0)).sum() * weight  # TN
    conf_m[1, 1] += ((preds == 0) & (labels == 1)).sum() * weight  # FN
    return conf_m
In the prediction loop:
conf_m = torch.zeros(2, 2)  # two classes (object or no-object)
# NOTE(review): the original `for img, label in enumerate(data)` bound the
# integer index to `img` — enumerate yields (index, item) pairs. Iterate the
# dataset directly to unpack (img, label).
for img, label in data:
    ...
    out = Net(img)
    # Pass the running matrix through; the original call dropped `conf_m`
    # even though confusion_matrix takes four arguments.
    conf_m = confusion_matrix(out, label, conf_m, len(data))
    ...
Is there a faster approach (in PyTorch) to efficiently calculate the confusion matrix for a sample of inputs for image semantic segmentation?
I use these 2 functions to calculate the confusion matrix (as it is defined in sklearn):
# rewrite sklearn method to torch
def confusion_matrix_1(y_true, y_pred):
    """Confusion matrix via a sparse COO tensor (sklearn layout: rows=true, cols=pred).

    Each (true, pred) pair contributes a single 1 at index [true, pred];
    `to_dense()` sums duplicate coordinates, producing the counts.
    """
    y_true = torch.as_tensor(y_true, dtype=torch.long)
    y_pred = torch.as_tensor(y_pred, dtype=torch.long)
    # Number of classes = largest label seen in either vector, plus one.
    N = int(max(y_true.max(), y_pred.max())) + 1
    # torch.sparse.LongTensor is a deprecated legacy constructor; the
    # supported factory is torch.sparse_coo_tensor.
    return torch.sparse_coo_tensor(
        torch.stack([y_true, y_pred]),
        torch.ones_like(y_true),
        (N, N)).to_dense()
# "weird trick" with bincount: encode each (true, pred) pair as the single
# integer N*true + pred, count occurrences, then reshape into an N x N matrix.
def confusion_matrix_2(y_true, y_pred):
    """Confusion matrix via bincount (sklearn layout: rows=true, cols=pred)."""
    y_true = torch.as_tensor(y_true, dtype=torch.long)
    y_pred = torch.as_tensor(y_pred, dtype=torch.long)
    N = int(max(y_true.max(), y_pred.max())) + 1
    # minlength pads the count vector to N*N directly; the original tried
    # torch.cat(y, zeros) — which passes the pad tensor as the `dim` argument
    # (torch.cat needs a tuple of tensors) and would raise a TypeError.
    y = torch.bincount(N * y_true + y_pred, minlength=N * N)
    # The original ended with `return T`, an undefined name (NameError);
    # the reshaped count matrix is what must be returned.
    return y.reshape(N, N)
# Example labels/predictions (same data as the sklearn confusion_matrix docs).
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
# Rows index the true class, columns the predicted class.
confusion_matrix_1(y_true, y_pred)
# tensor([[2, 0, 0],
# [0, 0, 1],
# [1, 0, 2]])
The second function is faster in the case of a small number of classes.
%%timeit
confusion_matrix_1(y_true, y_pred)
# 102 µs ± 30.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
confusion_matrix_2(y_true, y_pred)
# 25 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Thanks to Grigory Feldman for the answer! I had to change a few things to work with my implementation.
For future lookers, here is my final function that sums up the percentages of each confusion matrix over a batch of inputs (to be used in training or testing loops)
def confusion_matrix_2(y_true, y_pred, sample_sz, conf_m):
    """Accumulate row-normalized 2x2 confusion percentages over a batch.

    Row 0 (true no-object) is divided by the number of no-object pixels,
    row 1 (true object) by the number of object pixels, and both by
    `sample_sz` so the accumulation averages over a batch of inputs.

    Args:
        y_true: ground-truth mask with values in {0,1}.
        y_pred: raw network output; binarized by `normalize` to {0,1}.
        sample_sz: number of inputs being averaged over.
        conf_m: 2x2 float tensor accumulated in place across calls.

    Returns:
        The updated `conf_m` (same tensor, also mutated in place).
    """
    y_pred = normalize(y_pred, 0.9)  # project helper; binarizes to {0,1}
    # Pixel counts per true class, used as row normalizers below.
    # NOTE(review): if an image has no object (or no background) pixels the
    # corresponding row divides by zero and fills with inf/nan — same as the
    # original; confirm masks always contain both classes.
    n_obj = int((y_true == 1).sum())
    n_no_obj = int((y_true == 0).sum())
    # .long() avoids the torch.tensor(tensor) copy-construction warning.
    y_true = y_true.long()
    y_pred = y_pred.long()
    N = int(max(y_true.max(), y_pred.max())) + 1
    # Encode (true, pred) pairs as N*true + pred; minlength pads to N*N,
    # replacing the manual torch.cat with a zeros tensor.
    y = torch.bincount((N * y_true + y_pred).flatten(), minlength=N * N)
    y = y.reshape(N, N).float()
    conf_m[0, :] += y[0, :] / (n_no_obj * sample_sz)
    conf_m[1, :] += y[1, :] / (n_obj * sample_sz)
    return conf_m
...
conf_m = torch.zeros((2, 2), dtype=torch.float)  # two classes (object or no-object)
for _, data in enumerate(dataloader):
    # NOTE(review): the original inner `for img, label in enumerate(data)`
    # bound the integer index to `img`; iterate the batch directly.
    for img, label in data:
        ...
        out = Net(img)
        # Call the batch version defined above (the original called the old
        # `confusion_matrix` name and dropped the `conf_m` argument), keeping
        # the (y_true, y_pred, sample_sz, conf_m) argument order.
        conf_m = confusion_matrix_2(label, out, len(data), conf_m)
        ...
    ...