How to vectorize a searching function in MATLAB?

Posted 2019-08-09 04:39

Question:

Here is a MATLAB coding problem (there is a slightly different version of it that uses intersect instead of setdiff):

Given a rating matrix A with 3 columns: the 1st column is the user ID (which may be duplicated), the 2nd column is the item ID (which may be duplicated), and the 3rd column is the user's rating of the item, ranging from 1 to 5.

Now I have a subset of user IDs, smallUserIDList, and a subset of item IDs, smallItemIDList. I want to find the rows of A rated by users in smallUserIDList, collect the items each of those users rated, and do some calculations on them, such as a setdiff with smallItemIDList followed by counting the result, as the following code does:

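userStat = zeros(length(smallUserIDList), 1);
for i = 1:length(smallUserIDList)
    % All rows of A rated by the i-th user in the list
    A2 = A(A(:,1) == smallUserIDList(i), :);
    % Distinct items this user rated
    itemIDList_each = unique(A2(:,2));

    % Items rated by the user that are NOT in smallItemIDList
    setDiff = setdiff(itemIDList_each, smallItemIDList);
    userStat(i) = length(setDiff);
end
userStat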

However, the profile viewer shows that the loop above is inefficient. How can I improve this piece of code with vectorization, without the help of a for loop?

For example:

Input:

A = [
1 11 1
2 22 2
2 66 4
4 44 5
6 66 5
7 11 5
7 77 5
8 11 2
8 22 3
8 44 3
8 66 4
8 77 5    
]

smallUserIDList = [1 2 7 8]
smallItemIDList = [11 22 33 55 77]

Output:

userStat =

 0
 1
 0
 2
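The example can be turned into a self-checking script. This is a minimal sketch that assumes the loop above defines the reference behaviour; the variable name itemsOfUser is my own, and the separate unique call is dropped because setdiff already returns sorted values without repetitions.

```matlab
% Example data from the question: [userID itemID rating]
A = [1 11 1; 2 22 2; 2 66 4; 4 44 5; 6 66 5; 7 11 5; 7 77 5; ...
     8 11 2; 8 22 3; 8 44 3; 8 66 4; 8 77 5];
smallUserIDList = [1 2 7 8];
smallItemIDList = [11 22 33 55 77];

% Reference loop: for each listed user, count the distinct items that
% user rated which are NOT in smallItemIDList
% (itemsOfUser is an illustrative name, not from the question)
userStat = zeros(numel(smallUserIDList), 1);
for i = 1:numel(smallUserIDList)
    itemsOfUser = A(A(:,1) == smallUserIDList(i), 2);
    % setdiff returns unique values, so no separate unique() is needed
    userStat(i) = numel(setdiff(itemsOfUser, smallItemIDList));
end
disp(userStat)
```

Any vectorized replacement can be checked against this userStat before it is timed.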

Answer 1:

Vanilla MATLAB:

As far as I can tell, your code is equivalent to:

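%// Create matrix such that: user_item_rating(user,item)==rating
%// (explicit sizes keep IDs that appear only in the lists within range)
user_item_rating = sparse(A(:,1),A(:,2),A(:,3), ...
    max([A(:,1); smallUserIDList(:)]), max([A(:,2); smallItemIDList(:)]));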

%// Keep all BUT the items in smallItemIDList
user_item_rating(:,smallItemIDList) = [];

%// Keep only those users in `smallUserIDList` and use order of this list
user_item_rating = user_item_rating(smallUserIDList,:);

%// Count the number of ratings
userStat = sum(user_item_rating~=0, 2);

This will work if there is at most one rating per (user,item) combination. It should also be quite efficient.
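A minimal sketch of the sparse approach on the question's example, with one extra duplicated (user,item) row added for illustration. sparse adds up values that share the same subscripts, and every rating is at least 1, so a repeated pair still produces a single nonzero entry and the count is unaffected. The explicit matrix sizes (nUsers, nItems, names of my own) are an addition so that IDs present only in the lists do not cause an index error.

```matlab
% Question's example plus a hypothetical duplicate: user 8 rates item 44 twice
A = [1 11 1; 2 22 2; 2 66 4; 4 44 5; 6 66 5; 7 11 5; 7 77 5; ...
     8 11 2; 8 22 3; 8 44 3; 8 66 4; 8 77 5; 8 44 1];
smallUserIDList = [1 2 7 8];
smallItemIDList = [11 22 33 55 77];

% Explicit dimensions so every ID in the lists is a valid index
nUsers = max([A(:,1); smallUserIDList(:)]);
nItems = max([A(:,2); smallItemIDList(:)]);
user_item_rating = sparse(A(:,1), A(:,2), A(:,3), nUsers, nItems);

% The duplicated pair was summed (3 + 1), so the entry is still nonzero
disp(full(user_item_rating(8, 44)))   % displays 4

user_item_rating(:, smallItemIDList) = [];               % drop excluded items
user_item_rating = user_item_rating(smallUserIDList, :);  % keep listed users, in list order
userStat = full(sum(user_item_rating ~= 0, 2));          % nonzeros per user
disp(userStat)
```

The result matches the expected output of the question despite the duplicate row.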

Clean approach without reinventing the wheel:

Check out grpstats from the Statistics Toolbox! An implementation could look similar to this:

%// Create ratings table
ratings = array2table(A, 'VariableNames', {'user','item','rating'});

%// Remove items we don't care about (smallItemIDList)
ratings = ratings(~ismember(ratings.item, smallItemIDList),:);

%// Keep only users we care about (smallUserIDList) 
ratings = ratings(ismember(ratings.user, smallUserIDList),:);

%// Compute the statistics grouped by 'user'. 
userStat = grpstats(ratings, 'user');
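grpstats reports one row per user that still has ratings, so users 1 and 7 of the example would not appear, and GroupCount counts rows rather than distinct items. If the Statistics Toolbox is unavailable, or a plain column aligned with smallUserIDList (zeros included) is wanted, the following base-MATLAB sketch is one option. It assumes smallUserIDList has no repeated IDs; keep, pairs and pos are my own names.

```matlab
% Question's example data: [userID itemID rating]
A = [1 11 1; 2 22 2; 2 66 4; 4 44 5; 6 66 5; 7 11 5; 7 77 5; ...
     8 11 2; 8 22 3; 8 44 3; 8 66 4; 8 77 5];
smallUserIDList = [1 2 7 8];        % assumed: no repeated user IDs
smallItemIDList = [11 22 33 55 77];

% Rows rated by a listed user on an item that is NOT excluded
keep = ismember(A(:,1), smallUserIDList) & ~ismember(A(:,2), smallItemIDList);

% One row per distinct (user,item) pair, mirroring unique() in the loop
pairs = unique(A(keep, 1:2), 'rows');

% Position of each pair's user within smallUserIDList
[~, pos] = ismember(pairs(:,1), smallUserIDList);

% Count pairs per position; listed users without any pair get 0
userStat = accumarray(pos(:), ones(numel(pos), 1), [numel(smallUserIDList) 1]);
disp(userStat)
```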


Answer 2:

This could be one vectorized approach -

%// Take care of equality between first column of A and smallUserIDList to 
%// find the matching row and column indices.
%// NOTE: This corresponds to "A(:,1) == smallUserIDList(i)" from OP.
%// smallUserIDList(:).' is always a row, so row or column input both work.
[R,C] = find(bsxfun(@eq,A(:,1),smallUserIDList(:).')); %//'

%// Take care of non-equality between second column of A and smallItemIDList. 
%// NOTE: This corresponds to SETDIFF in the original loopy code from OP.
mask1 = ~ismember(A(R,2),smallItemIDList);

AR2 = A(R,2); %// Elements from 2nd col of A that have matches from first step

%// Keep only those elements of C and AR2 where mask1 is true
C1 = C(mask1);
AR2 = AR2(mask1);

%// Initialize output array
userStat = zeros(numel(smallUserIDList),1);

if ~isempty(C1) %// At least one matched item survives mask1, so do further processing

    %// Find the count of duplicate elements for each ID in C1 indexed into AR2.
    %// NOTE: This corresponds to "unique(A2(:,2))" from OP.
    dup_counts = accumarray(C1,AR2,[],@(x) numel(x)-numel(unique(x)));

    %// Count, for each position in smallUserIDList, the matched rows kept by mask1.
    %// NOTE: Together with dup_counts, this corresponds to:
    %//       "length(setdiff(itemIDList_each , smallItemIDList))" from OP.
    accums = accumarray(C,double(mask1));

    %// Store the counts in output array and also subtract the dup counts
    userStat(1:numel(accums)) = accums;
    userStat(1:numel(dup_counts)) = userStat(1:numel(dup_counts)) - dup_counts;
end
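The key trick above is that accumarray with a function handle can compute "elements minus distinct elements" per group. A small sketch on made-up toy values (C1 and AR2 below are hypothetical, not taken from the question) shows how the total count and the duplicate count combine into the number of distinct items per group:

```matlab
% Hypothetical toy input: group index per element (like C1) and item IDs (like AR2)
C1  = [1; 1; 2; 2; 2];
AR2 = [10; 10; 5; 6; 5];

% Per group: number of elements minus number of distinct elements
dup_counts = accumarray(C1, AR2, [], @(x) numel(x) - numel(unique(x)));

% Per group: total number of elements
totals = accumarray(C1, ones(size(C1)));

% Distinct items per group = total minus duplicates
distinct_counts = totals - dup_counts;
disp([totals dup_counts distinct_counts])
```

Group 1 holds item 10 twice (one duplicate, one distinct item); group 2 holds 5, 6, 5 (one duplicate, two distinct items).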

Benchmarking

The code listed next compares runtimes for the proposed approach against the original loopy code -

%// Size parameters and random inputs with them
A_nrows    = 5000;
IDlist_len = 5000;
max_userID = 1000;
max_itemID = 1000;
A = [randi(max_userID,A_nrows,1) randi(max_itemID,A_nrows,1) randi(5,A_nrows,1)];
smallUserIDList = randi(max_userID,IDlist_len,1);
smallItemIDList = randi(max_itemID,IDlist_len,1);

disp('---------------------------- With Original Approach')
tic
%//   Original posted code
toc

disp('---------------------------- With Proposed Approach')
tic
%//   Proposed approach code
toc

The runtimes thus obtained with three sets of data sizes were -

Case #1:

A_nrows    = 500;
IDlist_len = 500;
max_userID = 100;
max_itemID = 100;
---------------------------- With Original Approach
Elapsed time is 0.136630 seconds.
---------------------------- With Proposed Approach
Elapsed time is 0.004163 seconds.

Case #2:

A_nrows    = 5000;
IDlist_len = 5000;
max_userID = 100;
max_itemID = 100;
---------------------------- With Original Approach
Elapsed time is 1.579468 seconds.
---------------------------- With Proposed Approach
Elapsed time is 0.050498 seconds.

Case #3:

A_nrows    = 5000;
IDlist_len = 5000;
max_userID = 1000;
max_itemID = 1000;
---------------------------- With Original Approach
Elapsed time is 1.252294 seconds.
---------------------------- With Proposed Approach
Elapsed time is 0.044198 seconds.

Conclusion: in all three cases the proposed approach runs roughly 28 to 33 times faster than the original loopy code.
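The benchmark harness leaves both code bodies as placeholders, and timings only mean something if the two versions return the same counts. A self-contained sketch at the Case #1 sizes that fills in both bodies and checks agreement first; ref, t_loop and t_vec are my own names, no random seed is set, so the data differ per run, and the vectorized part uses smallUserIDList(:).' so either orientation works.

```matlab
% Case #1 sizes; no seed is set, so the data differ on every run
A_nrows = 500; IDlist_len = 500; max_userID = 100; max_itemID = 100;
A = [randi(max_userID,A_nrows,1) randi(max_itemID,A_nrows,1) randi(5,A_nrows,1)];
smallUserIDList = randi(max_userID,IDlist_len,1);
smallItemIDList = randi(max_itemID,IDlist_len,1);

% Original loop
tic
ref = zeros(numel(smallUserIDList),1);
for i = 1:numel(smallUserIDList)
    A2 = A(A(:,1) == smallUserIDList(i), :);
    ref(i) = numel(setdiff(unique(A2(:,2)), smallItemIDList));
end
t_loop = toc;

% Proposed vectorized approach
tic
[R,C] = find(bsxfun(@eq, A(:,1), smallUserIDList(:).'));
mask1 = ~ismember(A(R,2), smallItemIDList);
AR2 = A(R,2);
C1 = C(mask1);
AR2 = AR2(mask1);
userStat = zeros(numel(smallUserIDList),1);
if ~isempty(C1)
    dup_counts = accumarray(C1, AR2, [], @(x) numel(x)-numel(unique(x)));
    accums = accumarray(C, double(mask1));
    userStat(1:numel(accums)) = accums;
    userStat(1:numel(dup_counts)) = userStat(1:numel(dup_counts)) - dup_counts;
end
t_vec = toc;

% Compare outputs before reading anything into the timings
assert(isequal(userStat, ref))
fprintf('loop: %.4f s, vectorized: %.4f s\n', t_loop, t_vec);
```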



Answer 3:

I think you are trying to remove a fixed set of ratings for a subset of users and count the number of remaining ratings:

Does the following work?

Asub = A(ismember(A(:,1), smallUserIDList),1:2);
Bremove = allcomb(smallUserIDList, smallItemIDList);
Akeep = setdiff(Asub, Bremove, 'rows');
T = varfun(@sum, array2table(Akeep), 'InputVariables', 'Akeep2', 'GroupingVariables', 'Akeep1');
% userStat = T.GroupCount;   % users with no remaining ratings are absent from T

You need the allcomb function from the MATLAB Central File Exchange; it returns the Cartesian product of two vectors, and it is easy to implement yourself anyway.
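As a sketch of implementing it yourself: the built-in ndgrid can stand in for allcomb here (this substitution is my own, not part of the answer). Run on the question's example, the remaining (user,item) pairs are the ones the question counts:

```matlab
% Question's example data: [userID itemID rating]
A = [1 11 1; 2 22 2; 2 66 4; 4 44 5; 6 66 5; 7 11 5; 7 77 5; ...
     8 11 2; 8 22 3; 8 44 3; 8 66 4; 8 77 5];
smallUserIDList = [1 2 7 8];
smallItemIDList = [11 22 33 55 77];

% Cartesian product via ndgrid (stand-in for File Exchange allcomb):
% every (user,item) pair with user in smallUserIDList and item in smallItemIDList
[U, I] = ndgrid(smallUserIDList(:), smallItemIDList(:));
Bremove = [U(:) I(:)];

% (user,item) pairs of the listed users, minus the pairs to remove;
% setdiff with 'rows' also drops repeated pairs
Asub = A(ismember(A(:,1), smallUserIDList), 1:2);
Akeep = setdiff(Asub, Bremove, 'rows');
disp(Akeep)
```

User 2 keeps one pair and user 8 keeps two, which matches the expected userStat for those users.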