Python: simple list merging based on intersections

Published 2019-01-01 15:59

Consider there are some lists of integers as:

#--------------------------------------
0 [0,1,3]
1 [1,0,3,4,5,10,...]
2 [2,8]
3 [3,1,0,...]
...
n []
#--------------------------------------

The task is to merge all lists that share at least one common element, so the result for just the part shown above would be:

#--------------------------------------
0 [0,1,3,4,5,10,...]
2 [2,8]
#--------------------------------------

What is the most efficient way to do this on large data (the elements are just numbers)? Is a tree structure something to think about? I currently do the job by converting the lists to sets and iterating over intersections, but it is slow! Furthermore, I have a feeling my approach is too elementary. In addition, the implementation lacks something (unknown), because some lists sometimes remain unmerged! Having said that, if you propose a self-implemented solution, please be generous and provide simple sample code [apparently Python is my favorite :)] or pseudo-code.
Update 1: Here is the code I was using:

#--------------------------------------
lsts = [[0,1,3],
        [1,0,3,4,5,10,11],
        [2,8],
        [3,1,0,16]];
#--------------------------------------

The function is (buggy!!):

#--------------------------------------
def merge(lsts):
    sts = [set(l) for l in lsts]
    i = 0
    while i < len(sts):
        j = i + 1
        while j < len(sts):
            if sts[i] & sts[j]:      # at least one common element?
                sts[i] |= sts[j]
                sts.pop(j)
                # Bug: sets skipped earlier in this scan may now intersect
                # the enlarged sts[i], but they are never re-checked.
            else:
                j += 1               # ---corrected
        i += 1
    return [list(s) for s in sts]
#--------------------------------------
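The bug in the function above: after a union enlarges sts[i], the inner loop continues from the current j, so sets that were skipped earlier in the scan are never re-checked. A corrected sketch (my own restatement, not the original code) simply restarts the inner scan after every merge:

```python
def merge_fixed(lsts):
    sts = [set(l) for l in lsts]
    i = 0
    while i < len(sts):
        j = i + 1
        while j < len(sts):
            if sts[i] & sts[j]:       # at least one common element
                sts[i] |= sts[j]
                sts.pop(j)
                # Restart: the enlarged sts[i] may now intersect
                # sets that were already skipped in this scan.
                j = i + 1
            else:
                j += 1
        i += 1
    return [list(s) for s in sts]

print(merge_fixed([[0, 1, 3], [1, 0, 3, 4, 5, 10, 11], [2, 8], [3, 1, 0, 16]]))
```

This stays quadratic in the worst case, but it no longer leaves mergeable sets unmerged.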

The result is:

#--------------------------------------
>>> merge(lsts)
[[0, 1, 3, 4, 5, 10, 11, 16], [8, 2]]
#--------------------------------------

Update 2: In my experience, the code given by Niklas Baumstark below is a bit faster for the simple cases. I have not tested the method given by "Hooked" yet, since it is a completely different approach (and it seems interesting). Verifying the results of any of these methods could be really hard, or even impossible: the real data set I will use is so large and complex that it is impossible to trace any error just by inspection. That means I need to be 100% confident of the reliability of the method before putting it in its place as a module within a large code. For now, Niklas's method is faster, and its answer for the simple sets is of course correct.
However, how can I be sure that it works well for a really large data set, given that I will not be able to trace the errors visually?

Update 3: Note that reliability of the method is much more important than speed for this problem. I hope to eventually be able to translate the Python code to Fortran for maximum performance.

Update 4:
There are many interesting points in this post and in the generously given answers and constructive comments. I recommend reading all of them thoroughly. Please accept my appreciation for the development of the question, the amazing answers, and the constructive comments and discussion.

15 answers
人间绝色
#2 · 2019-01-01 16:00

Firstly, I'm not exactly sure the benchmarks are fair.

Adding the following code to the start of my function:

from collections import Counter
from itertools import chain

c = Counter(chain(*lists))
print(c[1])
# 88

This means that the value 1 alone appears 88 times across all the lists, so duplicates are very common in the benchmark data. Usually in the real world duplicates are rarer, and you would expect many more distinct values (of course, I don't know where your data comes from, so I can't make assumptions).

Because duplicates are common here, the sets are much less likely to be disjoint. This means the set.isdisjoint() method will be much faster, because after only a few tests it will find that the sets aren't disjoint.

Having said all that, I do believe the methods presented that use isdisjoint are the fastest anyway; I'm just saying that instead of being 20x faster, maybe they would only be 10x faster than the other methods with different benchmark data.
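To make the short-circuiting concrete, here is a small standalone timing sketch (my own illustration, not part of the benchmark): both expressions below answer "do the sets overlap?", but isdisjoint can return as soon as it meets one common element, while a & b always builds the full intersection first.

```python
import timeit

a = set(range(100000))
b = set(range(100000))  # heavily overlapping, so isdisjoint can bail out early

t_disjoint = timeit.timeit(lambda: a.isdisjoint(b), number=200)
t_intersect = timeit.timeit(lambda: len(a & b) > 0, number=200)

# Both checks agree that the sets overlap, but the first is far cheaper.
print(a.isdisjoint(b), len(a & b) > 0)   # False True
print(t_disjoint, t_intersect)
```

With mostly-disjoint sets the gap narrows, since the intersection being built is then small.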

Anyway, I thought I would try a slightly different technique to solve this: order everything first, then do a single merging pass. However, it turned out to be too slow; this method is about 20x slower than the two fastest methods in the benchmark:

import heapq
from itertools import chain

def merge6(lists):
    for l in lists:
        l.sort()
    # iterating through one_list takes 25 seconds!!
    one_list = heapq.merge(*[zip(l, [i] * len(l)) for i, l in enumerate(lists)])
    previous = next(one_list)

    d = {i: i for i in range(len(lists))}
    for current in one_list:
        if current[0] == previous[0]:
            d[current[1]] = d[previous[1]]
        previous = current

    groups = [[] for i in range(len(lists))]
    for k in d:
        groups[d[k]].append(lists[k])  # add each list to its group

    # since each sublist in each g is sorted, it would be faster to merge
    # these sublists, removing duplicates along the way
    return [set(chain(*g)) for g in groups if g]


lists = [[1, 2, 3], [3, 5, 6], [8, 9, 10], [11, 12, 13]]
print(merge6(lists))
# [{1, 2, 3, 5, 6}, {8, 9, 10}, {11, 12, 13}]



import timeit
# setup (defined in the benchmark script) provides lsts and the merge functions
print(timeit.timeit("merge1(lsts)", setup=setup, number=10))
print(timeit.timeit("merge4(lsts)", setup=setup, number=10))
print(timeit.timeit("merge6(lsts)", setup=setup, number=10))
5000 lists, 5 classes, average size 74, max size 1000
1.26732238315
5000 lists, 5 classes, average size 74, max size 1000
1.16062907437
5000 lists, 5 classes, average size 74, max size 1000
30.7257182826
千与千寻千般痛.
#3 · 2019-01-01 16:01

Here's my answer. I haven't checked it against today's batch of answers.

The intersection-based algorithms are O(N^2), since they check each new set against all the existing ones, so I used an approach that indexes each number and runs in close to O(N) (if we accept that dictionary lookups are O(1)). Then I ran the benchmarks and felt like a complete idiot because it ran slower, but on closer inspection it turned out that the test data ends up with only a handful of distinct result sets, so the quadratic algorithms don't have a lot of work to do. Test it with more than 10-15 distinct bins and my algorithm is much faster; try test data with more than 50 distinct bins and it is enormously faster.

(Edit: there was also a problem with the way the benchmark is run, but I was wrong in my diagnosis. I altered my code to work with the way the repeated tests are run.)

def mergelists5(data):
    """Check each number in our arrays only once, merging when we find
    a number we have seen before.
    """

    bins = list(range(len(data)))  # Initialize each bin[n] == n
    nums = dict()

    data = [set(m) for m in data]  # Convert to sets
    for r, row in enumerate(data):
        for num in row:
            if num not in nums:
                # New number: tag it with a pointer to this row's bin
                nums[num] = r
                continue
            else:
                dest = locatebin(bins, nums[num])
                if dest == r:
                    continue  # already in the same bin

                if dest > r:
                    dest, r = r, dest  # always merge into the lower-numbered bin

                data[dest].update(data[r])
                data[r] = None
                # Update our indices to reflect the move
                bins[r] = dest
                r = dest

    # Filter out the empty bins
    have = [m for m in data if m]
    print(len(have), "groups in result")
    return have


def locatebin(bins, n):
    """
    Find the bin where list n has ended up: Follow bin references until
    we find a bin that has not moved.
    """
    while bins[n] != n:
        n = bins[n]
    return n
只若初见
#4 · 2019-01-01 16:01

This can be solved in essentially O(n) time by using the union-find (disjoint-set) algorithm. Given the first two rows of your data, the edges to use in the union-find are the adjacent pairs: (0,1), (1,3), (1,0), (0,3), (3,4), (4,5), (5,10)
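A minimal sketch of that idea (my own illustration; the function names are made up): union consecutive elements of each list, then group every number by its root. With path compression this runs in near-linear time.

```python
def merge_lists(lists):
    """Merge lists sharing at least one element, using union-find."""
    parent = {}

    def find(x):
        # Walk to the root, then compress the path behind us.
        root = x
        while parent[root] != root:
            root = parent[root]
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    def union(a, b):
        parent.setdefault(a, a)
        parent.setdefault(b, b)
        parent[find(a)] = find(b)

    for lst in lists:
        if lst:
            parent.setdefault(lst[0], lst[0])  # handle singleton lists too
        for a, b in zip(lst, lst[1:]):         # adjacent pairs are enough
            union(a, b)

    groups = {}
    for x in parent:
        groups.setdefault(find(x), set()).add(x)
    return list(groups.values())

print(merge_lists([[0, 1, 3], [1, 0, 3, 4, 5, 10], [2, 8], [3, 1, 0]]))
```

Empty input lists simply vanish from the result, which matches the question's expected output.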

裙下三千臣
#5 · 2019-01-01 16:02

EDIT: OK, the other question has been closed, so I'm posting here.

Nice question! It's much simpler if you think of it as a connected-components problem in a graph. The following code uses the excellent networkx graph library and the pairs function from this question.

def pairs(lst):
    i = iter(lst)
    first = prev = item = next(i)
    for item in i:
        yield prev, item
        prev = item
    yield item, first

lists = [[1,2,3],[3,5,6],[8,9,10],[11,12,13]]

import networkx
g = networkx.Graph()
for sub_list in lists:
    for edge in pairs(sub_list):
            g.add_edge(*edge)

print(list(networkx.connected_components(g)))
# [{1, 2, 3, 5, 6}, {8, 9, 10}, {11, 12, 13}]

Explanation

We create a new (empty) graph g. For each sub-list in lists, consider its elements as nodes of the graph and add an edge between them. (Since we only care about connectedness, we don't need to add all the edges -- only adjacent ones!) Note that add_edge takes two objects, treats them as nodes (and adds them if they aren't already there), and adds an edge between them.

Then, we just find the connected components of the graph -- a solved problem! -- and output them as our intersecting sets.
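If you want to avoid the networkx dependency, the same connected-components computation can be sketched with the standard library only (my own illustration, not the answer's code): build an adjacency map from adjacent pairs, then collect the components with a breadth-first search.

```python
from collections import deque

def components_merge(lists):
    # Adjacency map: link adjacent elements within each list
    # (adjacent pairs are enough for connectedness).
    adj = {}
    for lst in lists:
        for x in lst:
            adj.setdefault(x, set())
        for a, b in zip(lst, lst[1:]):
            adj[a].add(b)
            adj[b].add(a)

    # Breadth-first search, collecting one component at a time.
    seen, result = set(), []
    for start in adj:
        if start in seen:
            continue
        comp, queue = set(), deque([start])
        seen.add(start)
        while queue:
            x = queue.popleft()
            comp.add(x)
            for y in adj[x] - seen:
                seen.add(y)
                queue.append(y)
        result.append(comp)
    return result

print(components_merge([[1, 2, 3], [3, 5, 6], [8, 9, 10], [11, 12, 13]]))
```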

高级女魔头
#6 · 2019-01-01 16:04

This is slower than the solution offered by Niklas (I got 3.9s on test.txt instead of 0.5s for his solution), but it yields the same result and might be easier to implement in, e.g., Fortran, since it doesn't use sets, only a sort of the total collection of elements followed by a single run through all of them.

It returns a list with the joined ids of the merged lists, so it also keeps track of empty lists; they stay unmerged.

def merge(lsts):
    # this is an index list that stores the joined id for each list
    joined = list(range(len(lsts)))
    # create an ordered list with indices
    indexed_list = sorted((el, index) for index, lst in enumerate(lsts) for el in lst)
    # loop through the ordered list, and if two elements are the same and
    # the lists are not yet joined, alter the list with the joined id
    el_0, idx_0 = None, None
    for el, idx in indexed_list:
        if el == el_0 and joined[idx] != joined[idx_0]:
            old = joined[idx]
            rep = joined[idx_0]
            joined = [rep if j == old else j for j in joined]
        el_0, idx_0 = el, idx
    return joined
临风纵饮
#7 · 2019-01-01 16:07

Here's a function (Python 3.1) to check if the result of a merge function is OK. It checks:

  • Are the result sets disjoint? (number of elements of union == sum of numbers of elements)
  • Are the elements of the result sets the same as of the input lists?
  • Is every input list a subset of a result set?
  • Is every result set minimal, i.e. is it impossible to split it into two smaller sets?
  • It does not check if there are empty result sets - I don't know if you want them or not...


from itertools import chain

def check(lsts, result):
    lsts = [set(s) for s in lsts]
    all_items = set(chain(*lsts))
    all_result_items = set(chain(*result))
    num_result_items = sum(len(s) for s in result)
    if num_result_items != len(all_result_items):
        print("Error: result sets overlap!")
        print(num_result_items, len(all_result_items))
        print(sorted(map(len, result)), sorted(map(len, lsts)))
    if all_items != all_result_items:
        print("Error: result doesn't match input lists!")
    if not all(any(s.issubset(t) for t in result) for s in lsts):
        print("Error: not all input lists are contained in a result set!")

    seen = set()
    todo = list(filter(bool, lsts))
    done = False
    while not done:
        deletes = []
        for i, s in enumerate(todo): # intersection with seen, or with unseen result set, is OK
            if not s.isdisjoint(seen) or any(t.isdisjoint(seen) for t in result if not s.isdisjoint(t)):
                seen.update(s)
                deletes.append(i)
        for i in reversed(deletes):
            del todo[i]
        done = not deletes
    if todo:
        print("Error: A result set should be split into two or more parts!")
        print(todo)