Java 8 Stream, How to get Top N count? [closed]

Posted 2019-03-21 15:46

Question:

I need your advice to simplify the code below. I have a list of players, where each entry carries the ID of a match that player won. I want to extract the 2 best players from this list (the 2 players with the most match IDs). Once they are extracted, I have to return their entries from the initial list so I can do other operations. I think this code could be improved in terms of performance or readability. Any help is appreciated.

public class PlayerStatistics {
    int id;
    String name;
    int idMatchWon; // key from Match

    // getter , setter
}

public static void main(String[] args) throws Exception {

    List<PlayerStatistics> _players = new ArrayList<PlayerStatistics>();

    _players.add(initialize(1, "John", 4));
    _players.add(initialize(2, "Teddy", 2));
    _players.add(initialize(3, "Kevin", 3));

    // How to get Top 2
    List<PlayerStatistics> _top2Players = extractTop2Players(_players);
}

private static List<PlayerStatistics> extractTop2Players (List<PlayerStatistics> _list) {

    List<PlayerStatistics> _topPlayers = new ArrayList<PlayerStatistics>();

    // 1. Group and count 
    Map<String, Long> _players = _list
            .stream()
            .filter(x -> (!"".equals(x.getName()) && x.getName()!= null) )
            .collect(
                    Collectors.groupingBy(
                            PlayerStatistics::getName, Collectors.counting()
                    )
            );

    // 2. Best players
    Set<String> _sortedPlayers = _players.entrySet().stream()
            .sorted(Map.Entry.comparingByValue(Collections.reverseOrder()))
            .limit(2)
            .map(Map.Entry::getKey)
            .collect(Collectors.toSet())
    ;

    // 3. Rebuild list 
    _topPlayers = _list
            .stream()
            .filter(x -> _sortedPlayers.contains(x.getName()))
            .collect(Collectors.toList())
    ;

    return _topPlayers;
}


// Assumes PlayerStatistics exposes fluent withXxx setters (not shown above).
private static PlayerStatistics initialize (int id, String name, int won) {
    return new PlayerStatistics()
            .withId(id)
            .withName(name)
            .withIdMatchWon(won);
}

Answer 1:

First of all, let's state that your code is absolutely correct. It does what needs to be done and it's even optimized by using sets. It can be further improved in two ways, though:

  1. Time complexity: you are sorting the whole dataset, which has a time complexity of O(m log m), with m being the size of your initial list of players. Immediately afterwards, you keep only the top N elements, with N << m.

    Below I show a way to reduce the time complexity of the algorithm to O(m log N), which in your specific case becomes O(m) (because N = 2, so log N = log₂ 2 = 1).

  2. You are traversing the dataset 3 times: first you're iterating the list of players to create the map of counts, then you are iterating this map to get a set with the top N players, and finally you're iterating the list of players again to check whether each player belongs to the set of top N players.

    This can be improved to perform only 2 passes over the dataset: the first one to create a map of counts (similar to what you've already done) and the other one to create a structure that will keep only the top N elements, sorted by count descending, with the result ready to be returned once the traversal has finished.

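Important: the solution below uses PlayerStatistics instances as map keys, so the class must implement hashCode and equals consistently. Entries that belong to the same player must compare equal (for example by id or name, not by idMatchWon); otherwise every entry is counted as a separate player.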

First we have a generic method topN that (not surprisingly) extracts the top N elements from any given map. It does this by comparing its entries by value, descending (in this version, values V must be Comparable<V>, but this algorithm can be easily extended to support values that don't implement Comparable<V> by providing a custom Comparator<V>):

public static 
<K, V extends Comparable<? super V>, T extends Comparable<? super T>>
Collection<K> 
topN(
        Map<K, V> map, 
        int N,
        Function<? super K, ? extends T> tieBreaker) {

    TreeMap<Map.Entry<K, V>, K> topN = new TreeMap<>(
        Map.Entry.<K, V>comparingByValue()      // by value descending, then by key
            .reversed()                         // to allow entries with duplicate values
            .thenComparing(e -> tieBreaker.apply(e.getKey())));

    map.entrySet().forEach(e -> {
      topN.put(e, e.getKey());
      if (topN.size() > N) topN.pollLastEntry();
    });

    return topN.values();
}

Here the topN TreeMap behaves as a priority queue of size N (though we add up to N+1 elements). First we put the entry into the topN map, then, if the map has more than N entries, we immediately invoke the pollLastEntry method on it, which removes the entry with the lowest priority (according to the order of the keys of the TreeMap). This guarantees that upon traversal, the topN map will only contain the top N entries, already sorted.
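For comparison, the same "keep only N, evict the weakest" idea can be sketched with a real java.util.PriorityQueue used as a bounded min-heap. This is an alternative illustration, not the method used in this answer: the class name BoundedHeapTopN and this two-argument topN helper are made up for the example, and because a heap tolerates equal values it needs no tie-breaker.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

public class BoundedHeapTopN {

    // Hypothetical helper (not from the answer): top n keys by value, best first.
    static <K, V extends Comparable<? super V>> List<K> topN(Map<K, V> map, int n) {
        // Min-heap ordered by value: the head is the weakest entry kept so far.
        PriorityQueue<Map.Entry<K, V>> heap =
                new PriorityQueue<>(Map.Entry.<K, V>comparingByValue());

        for (Map.Entry<K, V> e : map.entrySet()) {
            heap.offer(e);
            if (heap.size() > n) {
                heap.poll(); // evict the entry with the smallest value
            }
        }

        // The heap drains in ascending order, so prepend to end up descending.
        List<K> result = new ArrayList<>();
        while (!heap.isEmpty()) {
            result.add(0, heap.poll().getKey());
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Long> counts = new HashMap<>();
        counts.put("John", 4L);
        counts.put("Teddy", 2L);
        counts.put("Kevin", 3L);
        System.out.println(topN(counts, 2)); // [John, Kevin]
    }
}
```

The heap never holds more than N+1 entries, so each insertion costs O(log N), the same bound as the TreeMap approach. The TreeMap has the advantage that its values come out already sorted, without the final drain step.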

Note that I'm using a comparator that first sorts the TreeMap<Map.Entry<K, V>, K> by values V in descending order, and then by keys K. This is achieved with the help of the Function<? super K, ? extends T> tieBreaker function, which transforms each key K to a value T that must be Comparable<T>. All this allows the map to contain entries with duplicate values of V, without requiring keys K to also be Comparable<K>.
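To see why the tie-breaker matters, here is a small sketch (the class TieBreakerDemo and its helper names are invented for this illustration). A TreeMap treats two keys as the same key whenever its comparator returns 0, so with a value-only comparator two players with the same count collapse into a single entry:

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.Comparator;
import java.util.Map;
import java.util.TreeMap;

public class TieBreakerDemo {

    // Inserts two entries with the same count and reports how many the TreeMap kept.
    static int entriesKept(Comparator<Map.Entry<String, Long>> order) {
        TreeMap<Map.Entry<String, Long>, String> top = new TreeMap<>(order);
        top.put(new SimpleEntry<>("John", 3L), "John");
        top.put(new SimpleEntry<>("Kevin", 3L), "Kevin");
        return top.size();
    }

    // Value descending only: entries with equal counts compare as 0.
    static Comparator<Map.Entry<String, Long>> byValueOnly() {
        return Map.Entry.<String, Long>comparingByValue().reversed();
    }

    // Value descending, then key: equal counts are told apart by name.
    static Comparator<Map.Entry<String, Long>> byValueThenKey() {
        return byValueOnly().thenComparing(e -> e.getKey());
    }

    public static void main(String[] args) {
        System.out.println(entriesKept(byValueOnly()));    // 1
        System.out.println(entriesKept(byValueThenKey())); // 2
    }
}
```

The same reasoning means the tieBreaker function should map different keys to different values; two distinct keys with the same count and the same tie-break value would still be merged.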

Finally, you'd use the above method as follows:

Map<PlayerStatistics, Long> counts = yourInitialListOfPlayers.stream()
    .filter(x -> !"".equals(x.getName()) && x.getName() != null)
    .collect(Collectors.groupingBy(x -> x, Collectors.counting()));

Collection<PlayerStatistics> top2 = topN(counts, 2, PlayerStatistics::getName);
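Putting the pieces together, a self-contained sketch could look like the following. The nested Player class is a simplified stand-in for PlayerStatistics, and the sample data and the name TopPlayersDemo are assumptions made for the example. The topN method is the one shown above, repeated only so the snippet compiles on its own. Note that equals and hashCode deliberately ignore idMatchWon, so that all wins of one player land in the same group.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;

public class TopPlayersDemo {

    // Simplified stand-in for PlayerStatistics (assumption): one instance per match won.
    static final class Player {
        final int id;
        final String name;
        final int idMatchWon;

        Player(int id, String name, int idMatchWon) {
            this.id = id;
            this.name = name;
            this.idMatchWon = idMatchWon;
        }

        String getName() { return name; }

        // Identity is the player, not the match, so idMatchWon is excluded.
        @Override public boolean equals(Object o) {
            if (!(o instanceof Player)) return false;
            Player other = (Player) o;
            return id == other.id && Objects.equals(name, other.name);
        }

        @Override public int hashCode() { return Objects.hash(id, name); }
    }

    // Same algorithm as topN above, repeated so the example is self-contained.
    static <K, V extends Comparable<? super V>, T extends Comparable<? super T>>
    Collection<K> topN(Map<K, V> map, int n, Function<? super K, ? extends T> tieBreaker) {
        TreeMap<Map.Entry<K, V>, K> top = new TreeMap<>(
                Map.Entry.<K, V>comparingByValue()
                        .reversed()
                        .thenComparing(e -> tieBreaker.apply(e.getKey())));
        for (Map.Entry<K, V> e : map.entrySet()) {
            top.put(e, e.getKey());
            if (top.size() > n) top.pollLastEntry();
        }
        return top.values();
    }

    static List<String> top2Names() {
        // Made-up sample data: John won 3 matches, Kevin 2, Teddy 1.
        List<Player> wins = Arrays.asList(
                new Player(1, "John", 10), new Player(1, "John", 11), new Player(1, "John", 12),
                new Player(2, "Teddy", 13),
                new Player(3, "Kevin", 14), new Player(3, "Kevin", 15));

        Map<Player, Long> counts = wins.stream()
                .filter(p -> p.getName() != null && !p.getName().isEmpty())
                .collect(Collectors.groupingBy(p -> p, Collectors.counting()));

        return topN(counts, 2, Player::getName).stream()
                .map(Player::getName)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(top2Names()); // [John, Kevin]
    }
}
```

If you still need the original entries of the top players rather than one representative per player, you can filter the initial list against the returned collection, as step 3 of your code already does.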