Maximum likelihood pixel classification in python

2019-09-21 01:14发布

I would like to perform pixel classification on RGB images, based on training samples supplied for a given number of classes. For example, I have 4 classes of (r, g, b) pixels, so the goal is to segment the image into four phases.

I found that OpenCV's Python bindings (cv2) include the Expectation-Maximization algorithm, which could do the job. Unfortunately, I did not find any tutorial or material that explains (I am a beginner) how to work with the algorithm.

Could you please suggest a tutorial that could be used as a starting point?

Update...another approach for the code below:

  **def getsamples(img):
    """Flatten ``img`` into an array with one row per pixel.

    ``img.shape`` is unpacked as three dimensions (x, y, z); the result has
    shape (x * y, z) and is filled in row-major pixel order, so row
    ``i * y + j`` holds ``img[i, j]``.
    """
    x, y, z = img.shape
    # np.empty leaves the buffer uninitialised; every row is written below.
    samples = np.empty([x * y, z])
    index = 0
    for i in range(x):
        for j in range(y):
             samples[index] = img[i, j]
             index += 1
    return samples
def EMSegmentation(img, no_of_clusters=2):
    """Segment ``img`` by fitting a Gaussian mixture to its pixel values.

    Uses the legacy OpenCV 2.4 ``cv2.EM`` class.
    NOTE: on OpenCV 3+ the equivalent calls are ``cv2.ml.EM_create()``,
    ``em.setClustersNumber(n)`` and ``em.trainEM(samples)``.

    Args:
        img: Image array of shape (rows, cols, 3).
        no_of_clusters: Number of mixture components to fit.

    Returns:
        A copy of ``img`` in which every pixel is replaced by the colour
        assigned to its most probable mixture component.
    """
    colors = np.array([[0, 11, 111], [22, 22, 22]])
    if no_of_clusters > len(colors):
        # The fixed palette only covers two clusters; extend it with
        # reproducible extra colours so larger cluster counts do not
        # run off the end of the array.
        extra = np.random.RandomState(0).randint(
            0, 256, size=(no_of_clusters - len(colors), 3))
        colors = np.vstack([colors, extra])

    output = img.copy()
    samples = getsamples(img)
    em = cv2.EM(no_of_clusters)
    em.train(samples)

    rows, cols, _ = img.shape
    index = 0
    for i in range(rows):
        for j in range(cols):
            # predict() returns ((log-likelihood, label), probs). The label
            # comes back as a float, and NumPy rejects float indices, so it
            # has to be converted before looking up the colour.
            label = int(em.predict(samples[index])[0][1])
            output[i, j] = colors[label]
            index += 1
    return output
# Demo: segment '00.jpg' and show half-size views of input and result.
img = cv2.imread('00.jpg')
# Half-size copy of the input; it is only passed to imshow below.
smallImg = small = cv2.resize(img, (0,0), fx=0.5, fy=0.5) 
# Segmentation runs on the full-size image, not on smallImg.
output = EMSegmentation(img)
smallOutput = cv2.resize(output, (0,0), fx=0.5, fy=0.5) 
cv2.imshow('image', smallImg)
cv2.imshow('EM', smallOutput)
# Wait for a key press before closing the windows.
cv2.waitKey(0)
cv2.destroyAllWindows()**

1条回答
爷、活的狠高调
2楼-- · 2019-09-21 02:10

Converted from the C++ source to Python:

enter image description here

import cv2
import numpy as np


def getsamples(img):
    """Flatten an image into a table of per-pixel samples.

    Args:
        img: Array of shape (rows, cols, channels), e.g. a colour image
            loaded with ``cv2.imread``.

    Returns:
        A new float64 array of shape (rows * cols, channels). Pixels are in
        row-major order, so row ``i * cols + j`` holds ``img[i, j]``.

    Raises:
        ValueError: If ``img`` does not have exactly three dimensions.
    """
    rows, cols, channels = np.shape(img)
    # A single reshape replaces the per-pixel Python loop. astype() always
    # copies, so the caller's image is never aliased, and float64 matches
    # the dtype of the np.empty buffer the loop version filled.
    return np.asarray(img).reshape(rows * cols, channels).astype(np.float64)


def EMSegmentation(img, no_of_clusters=2):
    """Segment ``img`` by fitting a Gaussian mixture to its pixel values.

    Every pixel is assigned to its most probable mixture component, as
    reported by OpenCV's EM training, and painted with that component's
    colour.

    Args:
        img: Image array of shape (rows, cols, 3), e.g. from ``cv2.imread``.
        no_of_clusters: Number of mixture components to fit.

    Returns:
        An array with the same shape and dtype as ``img`` holding the
        per-pixel cluster colours.
    """
    colors = np.array([[0, 11, 111], [22, 22, 22]])
    if no_of_clusters > len(colors):
        # The fixed palette only covers two clusters; extend it with
        # reproducible extra colours so larger cluster counts do not
        # run off the end of the array.
        extra = np.random.RandomState(0).randint(
            0, 256, size=(no_of_clusters - len(colors), 3))
        colors = np.vstack([colors, extra])

    samples = getsamples(img)
    em = cv2.ml.EM_create()
    em.setClustersNumber(no_of_clusters)

    # trainEM returns (retval, logLikelihoods, labels, probs). `labels` is the
    # most probable component for each training sample, which here is every
    # pixel. Using it avoids a hand-rolled distance built from getMeans() and
    # getCovs(): that version multiplied by the covariance instead of its
    # inverse and picked the *largest* value, so it was not a likelihood-based
    # assignment.
    _, _, labels, _ = em.trainEM(samples)

    # One fancy-index lookup colours all pixels at once; samples are in
    # row-major order, so reshaping restores the image layout.
    segmented = colors[labels.ravel()].reshape(img.shape)
    return segmented.astype(img.dtype)


# Demo: segment an image from disk and show the input next to the result.
img = cv2.imread('dinosaur.jpg')
if img is None:
    # cv2.imread reports a missing or unreadable file by returning None
    # rather than raising, so fail here with a clear message.
    raise SystemExit("could not read image 'dinosaur.jpg'")
output = EMSegmentation(img)
cv2.imshow('image', img)
cv2.imshow('EM', output)
cv2.waitKey(0)  # block until a key is pressed
cv2.destroyAllWindows()
查看更多
登录 后发表回答