Passing a function with two arguments to filter()

Posted 2020-07-06 06:16

Question:

Given the following list:

DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT']

I want to filter strings longer than 3 characters. I achieve this with the following code:

With a for loop:

long_dna = []
for element in DNA_list:
    length = len(element)
    if int(length) > 3:
        long_dna.append(element)
print long_dna

But I want my code to be more general, so I can later filter strings of any length, so I use a function and a for loop:

def get_long(dna_seq, threshold):
    return len(dna_seq) > threshold

long_dna_loop2 = []
for element in DNA_list:
    if get_long(element, 3) is True:
        long_dna_loop2.append(element)
print long_dna_loop2

I want to achieve the same generality using filter() but I cannot achieve this. If I use the above function get_long(), I simply cannot pass arguments to it when I use it with filter(). Is it just not possible or is there a way around it?

My code with filter() for the specific case:

def is_long(dna):
    return len(dna) > 3

long_dna_filter = filter(is_long, DNA_list)
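For reference, a minimal Python 3 sketch of why passing get_long directly does not work: filter() calls its predicate with exactly one argument (the current item), so the threshold parameter is never supplied. In Python 3 the call happens lazily, so the error only surfaces when the filter is consumed.

```python
DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT']


def get_long(dna_seq, threshold):
    return len(dna_seq) > threshold


error_message = None
try:
    # filter() calls get_long(item) with one argument; list() forces the calls.
    list(filter(get_long, DNA_list))
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```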

Answer 1:

Use a lambda to supply the threshold, like this:

filter(lambda seq: get_long(seq, 3),
       DNA_list)
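A runnable Python 3 sketch of the same idea, assuming the threshold lives in a variable rather than being hard-coded. Since filter() returns an iterator in Python 3, the result is wrapped in list().

```python
DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT']


def get_long(dna_seq, threshold):
    return len(dna_seq) > threshold


threshold = 4
# The lambda accepts the single argument filter() passes and forwards it,
# together with the threshold, to the two-argument function.
long_dna = list(filter(lambda seq: get_long(seq, threshold), DNA_list))
print(long_dna)  # ['GTGTACGT', 'AAAAGGTT']
```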


Answer 2:

What you are trying to do is known as partial function application: you have a function with multiple arguments (in this case, 2) and want to get a function derived from it with one or more arguments fixed, which you can then pass to filter.
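To make the idea concrete, here is a minimal hand-rolled sketch of partial application built from a closure. The helper name simple_partial is hypothetical and it only freezes keyword arguments; the standard-library functools.partial also handles positional arguments.

```python
def simple_partial(func, **frozen_kwargs):
    """Hypothetical helper: return func with some keyword arguments fixed."""
    def wrapper(*args, **kwargs):
        # Keywords given at call time extend and override the frozen ones,
        # mirroring how functools.partial treats keywords.
        return func(*args, **{**frozen_kwargs, **kwargs})
    return wrapper


def get_long(dna_seq, threshold):
    return len(dna_seq) > threshold


DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT']
is_longer_than_3 = simple_partial(get_long, threshold=3)
long_dna = list(filter(is_longer_than_3, DNA_list))
print(long_dna)  # ['ATAT', 'GTGTACGT', 'AAAAGGTT']
```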

Some languages (especially functional ones) have this functionality "built in". In Python, you can use lambdas to do this (as others have shown), or you can use the functools module, in particular functools.partial. From the documentation:

The partial() is used for partial function application which “freezes” some portion of a function’s arguments and/or keywords resulting in a new object with a simplified signature. For example, partial() can be used to create a callable that behaves like the int() function where the base argument defaults to two:

>>> from functools import partial
>>> basetwo = partial(int, base=2)
>>> basetwo.__doc__ = 'Convert base 2 string to an int.'
>>> basetwo('10010')
18

So you can do:

filter(partial(get_long, threshold=3), DNA_list)
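A self-contained Python 3 sketch of the partial() approach, wrapping the filter object in list() to show the results:

```python
from functools import partial

DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT']


def get_long(dna_seq, threshold):
    return len(dna_seq) > threshold


# partial() freezes threshold by keyword; filter() then supplies dna_seq positionally.
longer_than_3 = list(filter(partial(get_long, threshold=3), DNA_list))
longer_than_4 = list(filter(partial(get_long, threshold=4), DNA_list))
print(longer_than_3)  # ['ATAT', 'GTGTACGT', 'AAAAGGTT']
print(longer_than_4)  # ['GTGTACGT', 'AAAAGGTT']
```

Freezing threshold by keyword matters here: partial(get_long, 3) would bind 3 to dna_seq, the first positional parameter, and the call would then fail.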


Answer 3:

Do you need to use filter()? Why not use a more Pythonic list comprehension?

Example:

>>> DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT']
>>> threshold = 3
>>> long_dna = [dna_seq for dna_seq in DNA_list if len(dna_seq) > threshold]
>>> long_dna
['ATAT', 'GTGTACGT', 'AAAAGGTT']

>>> threshold = 4
>>> [dna_seq for dna_seq in DNA_list if len(dna_seq) > threshold]
['GTGTACGT', 'AAAAGGTT']

This method has the advantage that it is trivial to convert into a generator expression, which can save memory and time depending on your application. For example, if you have a lot of DNA sequences and only want to iterate over them, building a list holds all of them in memory at once. The equivalent generator expression simply replaces the square brackets [] with round brackets ():

>>> long_dna = (dna_seq for dna_seq in DNA_list if len(dna_seq) > threshold)
>>> long_dna
<generator object <genexpr> at 0x7f50de229cd0>
>>> list(long_dna)
['GTGTACGT', 'AAAAGGTT']

In Python 2 this saving is not available with filter(), because it returns a list. In Python 3, filter() returns a lazy filter object that, like a generator, produces items only as they are requested.
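A small Python 3 sketch of that laziness: the predicate below records each call (the calls list is only for illustration), which shows that nothing is evaluated until items are requested, and only as far as needed.

```python
DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT']
calls = []


def is_longer_than_4(seq):
    calls.append(seq)  # record every call so we can see how much work was done
    return len(seq) > 4


lazy = filter(is_longer_than_4, DNA_list)  # no predicate calls yet
calls_before = len(calls)

first = next(lazy)  # evaluates items only until the first match
calls_after_first = len(calls)

rest = list(lazy)  # consumes the remaining items
print(first, rest)  # GTGTACGT ['AAAAGGTT']
```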



Answer 4:

You can make is_long return a function that accepts dna, like this:

>>> def is_long(length):
...     return lambda dna: len(dna) > length
... 

and then use it in filter, like this

>>> filter(is_long(3), DNA_list)
['ATAT', 'GTGTACGT', 'AAAAGGTT']
>>> filter(is_long(4), DNA_list)
['GTGTACGT', 'AAAAGGTT']
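The session above reflects Python 2, where filter() returns a list. A Python 3 sketch of the same factory, which also peeks at the closure to show that each returned lambda keeps its own length value (__closure__ and cell_contents are standard function attributes, used here only for illustration):

```python
DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT']


def is_long(length):
    return lambda dna: len(dna) > length


longer_than_3 = is_long(3)
longer_than_4 = is_long(4)

# Each returned lambda stores its own `length` in a closure cell.
print(longer_than_3.__closure__[0].cell_contents)  # 3
print(longer_than_4.__closure__[0].cell_contents)  # 4

# In Python 3, filter() returns an iterator, so wrap it in list().
print(list(filter(longer_than_3, DNA_list)))  # ['ATAT', 'GTGTACGT', 'AAAAGGTT']
print(list(filter(longer_than_4, DNA_list)))  # ['GTGTACGT', 'AAAAGGTT']
```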

Note: Don't use the is operator to compare booleans or numbers. Instead, rely on the truthiness of the data as much as possible. So, in your case, you could have written your second version like this:

if get_long(element, 3):
    long_dna_loop2.append(element)

Quoting the Programming Recommendations in PEP 8:

Don't compare boolean values to True or False using ==.

 Yes:   if greeting:
 No:    if greeting == True:
 Worse: if greeting is True:
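A short sketch of why `is True` is fragile: if a predicate returns a truthy value that is not the bool True, the identity check silently rejects everything. The long_enough function is hypothetical, written to return the length (or 0) instead of a bool.

```python
DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT', 'AAA']


def long_enough(dna_seq, threshold):
    # Hypothetical predicate: returns the length (truthy) or 0, not a bool.
    return len(dna_seq) if len(dna_seq) > threshold else 0


with_is_true = [s for s in DNA_list if long_enough(s, 3) is True]
with_truthiness = [s for s in DNA_list if long_enough(s, 3)]
print(with_is_true)     # []
print(with_truthiness)  # ['ATAT', 'GTGTACGT', 'AAAAGGTT']
```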


Answer 5:

You can generalize this further.

Since functions are objects in Python, you can write a function that returns the function you want.

def f(threshold):
    def g(x):
        return len(x) > threshold
    return g  # return the inner function itself, not its result

this_function = f(3)

DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT','AAA','AAAA']
filter(this_function, DNA_list)

output: ['ATAT', 'GTGTACGT', 'AAAAGGTT', 'AAAA']

g is the one-argument predicate you actually pass to filter, and f is the function that creates it.
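One practical reason to prefer a factory like f over writing lambdas inline: lambdas created in a loop all refer to the same loop variable and see its final value, while each call to f gets its own threshold. A Python 3 sketch (the thresholds list is just an illustrative input):

```python
def f(threshold):
    def g(x):
        return len(x) > threshold
    return g


DNA_list = ['ATAT', 'GTGTACGT', 'AAAAGGTT', 'AAA', 'AAAA']
thresholds = [2, 3, 4]

# Pitfall: every lambda looks up t when called, after the loop has ended (t == 4).
loop_lambdas = [lambda x: len(x) > t for t in thresholds]
# Factory: each call to f() captures its own threshold.
factory_preds = [f(t) for t in thresholds]

loop_counts = [len(list(filter(p, DNA_list))) for p in loop_lambdas]
factory_counts = [len(list(filter(p, DNA_list))) for p in factory_preds]
print(loop_counts)     # [2, 2, 2]
print(factory_counts)  # [5, 4, 2]
```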



Answer 6:

Here are a couple more ways using lambda. The first one uses a default keyword argument to hold the desired length. The second simply embeds the desired length in the lambda body.

# Create a list of strings
s = 'abcdefghi'
data = [s[:i+1] for i in range(len(s))]
print data

thresh = 3
print filter(lambda seq, n=thresh: len(seq) > n, data)

print filter(lambda seq: len(seq) > 5, data)

output

['a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef', 'abcdefg', 'abcdefgh', 'abcdefghi']
['abcd', 'abcde', 'abcdef', 'abcdefg', 'abcdefgh', 'abcdefghi']
['abcdef', 'abcdefg', 'abcdefgh', 'abcdefghi']

In the first example you could also do:

print filter(lambda seq, n=3: len(seq) > n, data)

Similarly, in the second example you could replace the literal 5 with a local (or global) variable, e.g.:

thresh = 5
print filter(lambda seq: len(seq) > thresh, data)
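The two forms differ in when the threshold is read: a default argument is evaluated once, when the lambda is created, while a variable in the body is looked up each time the lambda is called. A Python 3 sketch showing the difference when the variable is rebound before the predicates are used:

```python
data = ['a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef']

thresh = 3
frozen = lambda seq, n=thresh: len(seq) > n  # default evaluated now: n == 3
late = lambda seq: len(seq) > thresh         # thresh looked up at each call

thresh = 5  # rebind before the predicates are used

frozen_result = list(filter(frozen, data))
late_result = list(filter(late, data))
print(frozen_result)  # ['abcd', 'abcde', 'abcdef']
print(late_result)    # ['abcdef']
```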


Answer 7:

You can always write a callable that returns another callable suitable as a filter predicate, as the following example shows:

def main():
    dna_list = ['A', 'CA', 'TGATGATAC', 'GGGTAAAATC', 'TCG', 'AGGTCGCT', 'TT',
                'GGGTTGGA', 'C', 'TTGGAGGG']
    print('\n'.join(filter(length_at_least(3), dna_list)))


def length_at_least(value):
    return lambda item: len(item) >= value

# length_at_least = lambda value: lambda item: len(item) >= value

if __name__ == '__main__':
    main()
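A quick Python 3 check that the def version and the commented-out one-liner (a lambda that returns a lambda) behave the same on this list:

```python
dna_list = ['A', 'CA', 'TGATGATAC', 'GGGTAAAATC', 'TCG', 'AGGTCGCT', 'TT',
            'GGGTTGGA', 'C', 'TTGGAGGG']


def length_at_least(value):
    return lambda item: len(item) >= value


# The commented-out alternative from the answer, under a different name.
length_at_least_lambda = lambda value: lambda item: len(item) >= value

def_result = list(filter(length_at_least(3), dna_list))
lambda_result = list(filter(length_at_least_lambda(3), dna_list))
print(def_result)
# ['TGATGATAC', 'GGGTAAAATC', 'TCG', 'AGGTCGCT', 'GGGTTGGA', 'TTGGAGGG']
```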


Answer 8:

I used a different solution, with an inner function and a nonlocal variable, as shown below. I adapted it to the question's code, since my original code is different.

Hope this helps. :)

def outerfun():
    charlimit = 3
    def is_long(dna):
        nonlocal charlimit  # optional here: charlimit is only read, never assigned
        return len(dna) > charlimit
    long_dna_filter = filter(is_long, DNA_list)
    return long_dna_filter
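Reading an enclosing variable works without nonlocal; the declaration is only required when the inner function assigns to it. A Python 3 sketch where nonlocal is genuinely needed, assuming a hypothetical helper filter_long that takes the limit as a parameter and also counts rejected sequences:

```python
def filter_long(dna_list, charlimit):
    """Hypothetical helper: return sequences longer than charlimit and the reject count."""
    rejected = 0

    def is_long(dna):
        nonlocal rejected  # required: we assign to `rejected` below
        if len(dna) > charlimit:  # reading charlimit needs no declaration
            return True
        rejected += 1
        return False

    kept = list(filter(is_long, dna_list))  # consume the filter before reading `rejected`
    return kept, rejected


kept, rejected = filter_long(['ATAT', 'GTGTACGT', 'AAAAGGTT', 'AAA'], 4)
print(kept, rejected)  # ['GTGTACGT', 'AAAAGGTT'] 2
```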