Python: how to check if an item was added to a set

Posted 2020-02-12 01:16

I was wondering if there is a clear, concise way to add something to a set and check whether it was actually added, without hashing the item and looking it up twice.

This is what you might do, but it hashes the item twice:

if item not in some_set:  # <-- first hash & lookup
    some_set.add(item)    # <-- second hash & lookup (add() checks membership again)

    other_task()

This works with a single hash and lookup but is a bit ugly.

some_set_len = len(some_set)
some_set.add(item)
if some_set_len != len(some_set):

    other_task()

Is there a better way to do this using Python's set API?
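For reference, the hash counts claimed for the two snippets above can be checked with a key type that counts calls to its own `__hash__`. This is only a sketch: `CountingKey`, `count_hashes`, `check_then_add` and `add_and_compare_len` are hypothetical names introduced for the illustration, and built-in types such as `str` cache their hash, so for them the cost of the second step is mostly the repeated lookup.

```python
class CountingKey:
    """Hypothetical helper: a hashable key that counts __hash__ calls."""
    hash_calls = 0

    def __init__(self, value):
        self.value = value

    def __hash__(self):
        CountingKey.hash_calls += 1
        return hash(self.value)

    def __eq__(self, other):
        return isinstance(other, CountingKey) and self.value == other.value


def check_then_add(some_set, item):
    # First snippet from the question: membership test, then add.
    if item not in some_set:
        some_set.add(item)
        return True
    return False


def add_and_compare_len(some_set, item):
    # Second snippet from the question: add, then compare lengths.
    before = len(some_set)
    some_set.add(item)
    return len(some_set) != before


def count_hashes(action):
    # Reset the counter, run the action, report how many hashes it caused.
    CountingKey.hash_calls = 0
    action()
    return CountingKey.hash_calls


two_step = count_hashes(lambda: check_then_add(set(), CountingKey(1)))
one_step = count_hashes(lambda: add_and_compare_len(set(), CountingKey(1)))
print(two_step, one_step)  # 2 1
```

Note that the two-step version only hashes twice when the item is new; for an item that is already present, the `add()` call is skipped.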

2 answers
smile是对你的礼貌
#2 · 2020-02-12 02:02

I don't think there's a built-in way to do this. You could, of course, write your own function:

def do_add(s, x):
  l = len(s)
  s.add(x)
  return len(s) != l

s = set()
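print(do_add(s, 1))  # True: 1 was added
print(do_add(s, 2))  # True: 2 was added
print(do_add(s, 1))  # False: 1 was already present
print(do_add(s, 2))  # False: 2 was already present
print(do_add(s, 4))  # True: 4 was added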

Or, if you prefer cryptic one-liners:

def do_add(s, x):
  return len(s) != (s.add(x) or len(s))

(This relies on the left-to-right evaluation order and on the fact that set.add() always returns None, which is falsey.)

All this aside, I would only consider doing this if the double hashing/lookup is demonstrably a performance bottleneck and if using a function is demonstrably faster.
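One way to check whether that condition holds is to time both variants on data shaped like your own. The following is a rough sketch rather than a rigorous benchmark: `check_then_add`, `run` and the synthetic `items` list (10,000 values, 10% of them duplicates) are assumptions made for the example, and the timings will vary with the Python version and the item type.

```python
import timeit


def check_then_add(s, x):
    # Two-step version from the question.
    if x not in s:
        s.add(x)
        return True
    return False


def do_add(s, x):
    # Length-comparison version from this answer.
    l = len(s)
    s.add(x)
    return len(s) != l


def run(adder, items):
    # Feed every item through `adder` and count how many were new.
    s = set()
    added = 0
    for x in items:
        if adder(s, x):
            added += 1
    return added


# Synthetic input: 10,000 values of which 1,000 repeat earlier ones.
items = [i % 9000 for i in range(10000)]

for adder in (check_then_add, do_add):
    best = min(timeit.repeat(lambda: run(adder, items), number=5, repeat=3))
    print(adder.__name__, round(best, 4))
```

Both variants should report the same number of additions; only the timings are expected to differ.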

ら.Afraid
#3 · 2020-02-12 02:02

Dictionaries have the handy setdefault method, which avoids a whole class of problems related to the "double lookup" mentioned in the question. Since, in CPython at least, most of the set code is shared with dictionaries, I tried using it when working with a very large set (500k+ additions, roughly 10% of them duplicates).

In addition, to reduce the overhead of Python's name lookups, I wrapped it in a higher-order function. The compiler then builds a closure, so the inner function can use the index-based LOAD_FAST/LOAD_DEREF opcodes instead of the more expensive name-based LOAD_ATTR/LOAD_GLOBAL lookups:

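def cache():
  s = {}
  setdefault = s.setdefault  # bind the method once, outside the hot path
  n = 0

  def add(x):
    # Returns True if x was already present, False if it was just added.
    nonlocal n
    n += 1
    # A new key stores and returns n; an existing key returns its older, smaller value.
    return setdefault(x, n) != n

  return add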


# Usage
cached = cache()

for i in my_large_generator_with_duplicates():
  if not cached(i):
    do_something()

In my particular use case, this solution runs more than 20% faster than the one suggested in the other answer. Of course, your mileage may vary, so you should run your own tests.
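To make the behaviour concrete, here is the same closure run on a small hand-made input instead of the large generator. The input list and the `results` name are assumptions made for the illustration; the function body is the one shown above.

```python
def cache():
    s = {}
    setdefault = s.setdefault
    n = 0

    def add(x):
        nonlocal n
        n += 1
        # New key: n is stored and returned, so the comparison is False.
        # Existing key: the older, smaller counter value comes back, so it is True.
        return setdefault(x, n) != n

    return add


cached = cache()
results = [cached(x) for x in ["a", "b", "a", "c", "b"]]
print(results)  # [False, False, True, False, True]
```

The ever-increasing counter is what makes this work: with a constant default value, setdefault would return the same object for new and existing keys, and the two cases could not be told apart.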


For reference, here is the disassembled code of both solutions (Python 3.5 running on Linux):

def do_add(s, x):
  l = len(s)
  s.add(x)
  return len(s) != l


import dis
dis.dis(do_add)

 20           0 LOAD_GLOBAL              0 (len)
              3 LOAD_FAST                0 (s)
              6 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
              9 STORE_FAST               2 (l)

 21          12 LOAD_FAST                0 (s)
             15 LOAD_ATTR                1 (add)
             18 LOAD_FAST                1 (x)
             21 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             24 POP_TOP

 22          25 LOAD_GLOBAL              0 (len)
             28 LOAD_FAST                0 (s)
             31 CALL_FUNCTION            1 (1 positional, 0 keyword pair)
             34 LOAD_FAST                2 (l)
             37 COMPARE_OP               3 (!=)

def cache():
  s = {}
  setdefault = s.setdefault
  n = 0

  def add(x):
    nonlocal n
    n+=1
    return setdefault(x,n) != n

  return add


dis.dis(cache.__code__.co_consts[2])


 13           0 LOAD_DEREF               0 (n)
              3 LOAD_CONST               1 (1)
              6 INPLACE_ADD
              7 STORE_DEREF              0 (n)

 14          10 LOAD_DEREF               1 (setdefault)
             13 LOAD_FAST                0 (x)
             16 LOAD_DEREF               0 (n)
             19 CALL_FUNCTION            2 (2 positional, 0 keyword pair)
             22 LOAD_DEREF               0 (n)
             25 COMPARE_OP               3 (!=)
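
To reproduce a listing like the one above on another interpreter, the inner function's code object can be located by type instead of by a fixed `co_consts` index, since that index is not guaranteed to be the same across Python versions. This is a sketch under that assumption; `inner_code`, `buffer` and `listing` are names introduced here, and newer CPython releases use different opcode names and offsets, so the output will not match the 3.5 listing exactly.

```python
import dis
import io
import types


def cache():
    s = {}
    setdefault = s.setdefault
    n = 0

    def add(x):
        nonlocal n
        n += 1
        return setdefault(x, n) != n

    return add


# The nested function is stored as a code object among the outer function's constants.
inner_code = next(
    const for const in cache.__code__.co_consts
    if isinstance(const, types.CodeType)
)

# dis.dis() accepts a file argument, so the listing can be captured as text.
buffer = io.StringIO()
dis.dis(inner_code, file=buffer)
listing = buffer.getvalue()
print(listing)
```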