Union find implementation using Python

Published 2020-01-29 10:11

Question:

So here's what I want to do: I have a list that contains several equivalence relations:

l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]

I want to merge the sets that share at least one element. Here is my current implementation:

def union(lis):
  lis = [set(e) for e in lis]
  res = []
  # Repeat merge passes until a pass leaves the list of sets unchanged.
  while True:
    for i in range(len(lis)):
      a = lis[i]
      if res == []:
        res.append(a)
      else:
        pointer = 0
        # Merge a into the first result set it overlaps, if any.
        while pointer < len(res):
          if a & res[pointer]:
            res[pointer] = res[pointer].union(a)
            break
          pointer += 1
        if pointer == len(res):
          res.append(a)
    if res == lis:
      break
    lis, res = res, []
  return res

For the list above it returns

[set([1, 2, 3, 6, 7]), set([4, 5])]

This does the right thing, but it is far too slow when the list of equivalence relations is large. I looked up the description of the union-find algorithm (http://en.wikipedia.org/wiki/Disjoint-set_data_structure), but I am still having trouble coding a Python implementation.
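The disjoint-set structure described on that Wikipedia page can be sketched roughly as follows. This is a minimal illustration, not code from the question or from any of the answers: `DisjointSet` and `group_pairs` are hypothetical names, and it assumes every element is hashable. It combines path halving with union by size, which together give amortized near-constant time per `find`/`union`.

```python
from collections import defaultdict


class DisjointSet:
    """Minimal union-find sketch (hypothetical class, not a library API)."""

    def __init__(self):
        self.parent = {}  # element -> parent element
        self.size = {}    # root -> number of elements in its tree

    def find(self, x):
        # Lazily register an unseen element as its own singleton set.
        if x not in self.parent:
            self.parent[x] = x
            self.size[x] = 1
            return x
        # Path halving: point each visited node at its grandparent.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        # Union by size: hang the smaller tree under the larger root.
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]


def group_pairs(pairs):
    """Union every pair, then collect elements by their root."""
    ds = DisjointSet()
    for a, b in pairs:
        ds.union(a, b)
    groups = defaultdict(set)
    for x in ds.parent:
        groups[ds.find(x)].add(x)
    return list(groups.values())


if __name__ == "__main__":
    l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
    print(group_pairs(l))  # [{1, 2, 3, 6, 7}, {4, 5}]
```

Because each pair costs two `find` calls, the whole grouping does close to linear work in the number of pairs, instead of repeating full passes until nothing changes.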

Answer 1:

Solution that runs in O(n) time

from collections import defaultdict

def indices_dict(lis):
    d = defaultdict(list)
    for i,(a,b) in enumerate(lis):
        d[a].append(i)
        d[b].append(i)
    return d

def disjoint_indices(lis):
    d = indices_dict(lis)
    sets = []
    while len(d):
        que = set(d.popitem()[1])
        ind = set()
        while len(que):
            ind |= que 
            que = set([y for i in que 
                         for x in lis[i] 
                         for y in d.pop(x, [])]) - ind
        sets += [ind]
    return sets

def disjoint_sets(lis):
    return [set([x for i in s for x in lis[i]]) for s in disjoint_indices(lis)]

How it works:

>>> lis = [(1,2),(2,3),(4,5),(6,7),(1,7)]
>>> indices_dict(lis)
defaultdict(<type 'list'>, {1: [0, 4], 2: [0, 1], 3: [1], 4: [2], 5: [2], 6: [3], 7: [3, 4]})

indices_dict maps each element to the indices of the pairs in lis that contain it. E.g. 1 is mapped to indices 0 and 4, because lis[0] and lis[4] both contain 1.

>>> disjoint_indices(lis)
[set([0, 1, 3, 4]), set([2])]

disjoint_indices gives a list of disjoint sets of indices; each set holds the indices of the pairs that belong to one equivalence class. E.g. lis[0] and lis[3] are in the same class, but lis[2] is not.

>>> disjoint_sets(lis)
[set([1, 2, 3, 6, 7]), set([4, 5])]

disjoint_sets converts the disjoint index sets into the corresponding equivalence classes.


Time complexity

The O(n) time complexity is difficult to see, but I'll try to explain. Here n = len(lis), and dict and set operations are treated as O(1) on average.

  1. indices_dict clearly runs in O(n) time, because it is a single for-loop over lis.

  2. disjoint_indices is the hardest to see. Each key of d is removed exactly once (by popitem or by d.pop), and d has at most 2n keys, since each pair in lis contributes at most two distinct elements. Because que - ind discards indices that were already visited, each index enters que only once, so the comprehension calls d.pop at most 2n times and the lists it collects contain 2n indices in total. Therefore the function runs in O(n).

  3. disjoint_sets is difficult to see because of the nested for-loops. However, the index sets are disjoint, so across all of them i visits each of the n indices exactly once, and x runs over the two elements of each pair, so the total work is 2n = O(n).
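To make the counting argument in point 2 concrete, the sketch below re-implements disjoint_indices with two counters added purely for illustration (`disjoint_indices_counted`, `keys_popped` and `indices_added` are hypothetical names). It assumes, as the answer does, that lis holds 2-element pairs. On any input, keys_popped equals the number of distinct elements (at most 2n) and indices_added equals n, which is exactly why the total work is linear.

```python
import random
from collections import defaultdict


def disjoint_indices_counted(lis):
    """Same algorithm as disjoint_indices, plus two work counters."""
    d = defaultdict(list)
    for i, (a, b) in enumerate(lis):
        d[a].append(i)
        d[b].append(i)

    keys_popped = 0    # how many keys are removed from d in total
    indices_added = 0  # how many indices are added to some ind in total
    sets = []
    while d:
        que = set(d.popitem()[1])
        keys_popped += 1
        ind = set()
        while que:
            ind |= que
            indices_added += len(que)  # que never overlaps ind at this point
            nxt = set()
            for i in que:
                for x in lis[i]:
                    if x in d:  # each key can only be popped once
                        keys_popped += 1
                        nxt.update(d.pop(x))
            que = nxt - ind
        sets.append(ind)
    return sets, keys_popped, indices_added


if __name__ == "__main__":
    random.seed(0)
    n = 10_000
    lis = [(random.randrange(n), random.randrange(n)) for _ in range(n)]
    sets, popped, added = disjoint_indices_counted(lis)
    distinct = len({x for pair in lis for x in pair})
    print(popped == distinct, added == n)  # True True
```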



Answer 2:

I think this is an elegant solution, using the built-in set methods:

#!/usr/bin/python3

def union_find(lis):
    lis = map(set, lis)
    unions = []
    for item in lis:
        temp = []
        for s in unions:
            if not s.isdisjoint(item):
                item = s.union(item)
            else:
                temp.append(s)
        temp.append(item)
        unions = temp
    return unions



if __name__ == '__main__':
    l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
    print(union_find(l))

It returns a list of sets.



Answer 3:

Perhaps something like this?

#!/usr/local/cpython-3.3/bin/python

import copy
import pprint
import collections

def union(list_):
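    # Map each element to the set of elements it is directly paired with.
    dict_ = collections.defaultdict(set)

    for sublist in list_:
        dict_[sublist[0]].add(sublist[1])
        dict_[sublist[1]].add(sublist[0])

    # Add neighbours of neighbours until a full pass changes nothing; afterwards
    # each element maps to its whole equivalence class (itself included).
    change_made = True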
    while change_made:
        change_made = False
        for key, values in dict_.items():
            for value in copy.copy(values):
                for element in dict_[value]:
                    if element not in dict_[key]:
                        dict_[key].add(element)
                        change_made = True

    return dict_

list_ = [ [1, 2], [2, 3], [4, 5], [6, 7], [1, 7] ]
pprint.pprint(union(list_))
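Note that union() here returns a dict mapping every element to its full equivalence class, so each class appears once per member. To get a plain list of classes like the other answers return, the values can be de-duplicated. A minimal sketch; `distinct_classes` is a hypothetical helper, and the mapping in the demo is written by hand in the shape this union() produces for the example input.

```python
def distinct_classes(closure):
    """Collapse {element: its whole class} into a list of distinct classes."""
    seen = set()
    classes = []
    for members in closure.values():
        key = frozenset(members)  # sets are unhashable, frozensets are hashable
        if key not in seen:
            seen.add(key)
            classes.append(set(members))
    return classes


if __name__ == "__main__":
    # Hand-written mapping of the shape union() builds for the example list.
    closure = {
        1: {1, 2, 3, 6, 7}, 2: {1, 2, 3, 6, 7}, 3: {1, 2, 3, 6, 7},
        4: {4, 5}, 5: {4, 5},
        6: {1, 2, 3, 6, 7}, 7: {1, 2, 3, 6, 7},
    }
    print(distinct_classes(closure))  # [{1, 2, 3, 6, 7}, {4, 5}]
```

With the answer's code in scope, the same helper can be applied directly as distinct_classes(union(list_)).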


Answer 4:

This works by completely exhausting one equivalence class at a time. Once a pair has been assigned to its class, it is removed from the original set and never searched again.

def equiv_sets(lis):
    s = set(lis)  # the pairs must be hashable, e.g. tuples
    sets = []

    #loop while there are still items in original set
    while len(s):
        s1 = set(s.pop())
        length = 0
        #loop while there are still equivalences to s1
        while len(s1) != length:
            length = len(s1)
            for v in list(s):
                if v[0] in s1 or v[1] in s1:
                    s1 |= set(v)
                    s.discard(v)
        sets += [s1]
    return sets

print(equiv_sets([(1,2),(2,3),(4,5),(6,7),(1,7)]))

OUTPUT: [set([1, 2, 3, 6, 7]), set([4, 5])]