Recursion with versus without memoization

Posted 2020-05-06 14:15

Question:

I got a homework assignment at school to calculate Catalan numbers with recursion. First, without memoization:

def catalan_rec(n):
    res = 0
    if n == 0:
        return 1
    else:
        for i in range(n):
            res += (catalan_rec(i))*(catalan_rec(n-1-i))
        return res

Second, with memoization:

def catalan_mem(n, memo = None):
    if memo == None:
        memo = {0: 1}
    res = 0
    if n not in memo:
        for i in range (n):
            res += (catalan_mem(i))*(catalan_mem(n-1-i))
        memo[n] = res
    return memo[n]

The weirdest thing happened: the memoized version takes twice as much time, when it should be the other way around!

Can someone please explain this to me?

Answer 1:

This question inspired me to investigate the relative speed of various Catalan number algorithms and various memoization schemes. The code below contains functions for the recursive algorithm given in the question as well as a simpler algorithm that only needs one recursive call, which is also easy to implement iteratively. There's also an iterative version based on the binomial coefficient. All of these algorithms are given in the Wikipedia article on Catalan numbers.

It's not easy to get accurate timings for most of the memoized versions. Normally, when using the timeit module, one runs the function under test many times in a loop, but that gives misleading results here: after the first call the cache is populated, so later calls only measure a dictionary lookup. Getting true results would require clearing the caches between runs. That's possible, but it's a bit messy and slow, and the clearing would have to happen outside the timed region so that its overhead isn't added to the time of the actual Catalan number calculations. So this code generates timing info by simply calculating one large Catalan number, with no looping.
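For anyone who does want repeated measurements, here is a minimal sketch of clearing the cache outside the timed region. It relies on timeit running the setup once per timeit() call without including it in the reported time. The names catalan_memo, memo and reset_cache are made up for this illustration and are not part of the script below.

```python
from timeit import Timer

# Hypothetical stand-alone cache and function, for illustration only
memo = {0: 1}

def catalan_memo(n):
    ''' double recursion, memoized with a module-level dict '''
    if n in memo:
        return memo[n]
    res = 0
    for i in range(n):
        res += catalan_memo(i) * catalan_memo(n - 1 - i)
    memo[n] = res
    return res

def reset_cache():
    ''' restore the cache to its initial state '''
    memo.clear()
    memo[0] = 1

# setup runs before each timed run and is not included in the measurement
timer = Timer(lambda: catalan_memo(100), setup=reset_cache)

# repeat=5, number=1: five independent runs, each starting from a cold cache
times = timer.repeat(repeat=5, number=1)
best = min(times)
print('best of 5: %f seconds' % best)
```

Taking the minimum of several cold-cache runs is the usual way to reduce the effect of system noise, but the numbers will still vary from machine to machine.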

As well as the timing code there's also a function, verify(), which verifies that all the Catalan number functions produce the same results, and there's a function that can print the bytecode for each Catalan number function. Both of those functions have been commented out. Note that verify() populates the caches, so calling verify() before time_test() would cause the timing information to be invalid.

The code below was written and tested using Python 2.6.6, but it also runs correctly on Python 3.6.0.

#!/usr/bin/env python

''' Catalan numbers

    Test speeds of various algorithms

    1, 1, 2, 5, 14, 42, 132, 429, 1430, 4862, 16796, 58786, 208012, 742900, ...

    See https://en.wikipedia.org/wiki/Catalan_number
    and http://stackoverflow.com/q/33959795/4014959

    Written by PM 2Ring 2015.11.28
'''

from __future__ import print_function, division

from timeit import Timer
import dis


#Use xrange if running on Python 2
try:
    range = xrange
except NameError:
    pass

def catalan_rec_plain(n):
    ''' no memoization. REALLY slow! Eg, 26 seconds for n=16 '''
    if n < 2:
        return 1

    res = 0
    for i in range(n):
        res += catalan_rec_plain(i) * catalan_rec_plain(n-1-i)
    return res


#Most recursive versions have recursion limit: n=998, except where noted
cache = {0: 1}
def catalan_rec_extern(n):
    ''' memoize with an external cache '''
    if n in cache:
        return cache[n]

    res = 0
    for i in range(n):
        res += catalan_rec_extern(i) * catalan_rec_extern(n-1-i)
    cache[n] = res
    return res


def catalan_rec_defarg(n, memo={0: 1}):
    ''' memoize with a default keyword arg cache '''
    if n in memo:
        return memo[n]

    res = 0
    for i in range(n):
        res += catalan_rec_defarg(i) * catalan_rec_defarg(n-1-i)
    memo[n] = res
    return res


def catalan_rec_funcattr(n):
    ''' memoize with a function attribute cache '''
    memo = catalan_rec_funcattr.memo
    if n in memo:
        return memo[n]

    res = 0
    for i in range(n):
        res += catalan_rec_funcattr(i) * catalan_rec_funcattr(n-1-i)
    memo[n] = res
    return res

catalan_rec_funcattr.memo = {0: 1}


def make_catalan():
    memo = {0: 1}
    def catalan0(n):
        ''' memoize with a simple closure to hold the cache '''
        if n in memo:
            return memo[n]

        res = 0
        for i in range(n):
            res += catalan0(i) * catalan0(n-1-i)
        memo[n] = res
        return res
    return catalan0

catalan_rec_closure = make_catalan()
catalan_rec_closure.__name__ = 'catalan_rec_closure'


#Simple memoization, with initialised cache
def initialise(memo={}):    
    def memoize(f):
        def memf(x):
            if x in memo:
                return memo[x]
            else:
                res = memo[x] = f(x)
                return res
        memf.__name__ = f.__name__
        memf.__doc__ = f.__doc__
        return memf
    return memoize

#maximum recursion depth exceeded at n=499
@initialise({0: 1})
def catalan_rec_decorator(n):
    ''' memoize with a decorator closure to hold the cache '''
    res = 0
    for i in range(n):
        res += catalan_rec_decorator(i) * catalan_rec_decorator(n-1-i)
    return res

# ---------------------------------------------------------------------

#Product formula
# C_n+1 = C_n * 2 * (2*n + 1) / (n + 2)
# C_n = C_n-1 * 2 * (2*n - 1) / (n + 1)

#maximum recursion depth exceeded at n=999
def catalan_rec_prod(n):
    ''' recursive, using product formula '''
    if n < 2:
        return 1
    return (4*n - 2) * catalan_rec_prod(n-1) // (n + 1)

#Note that memoizing here gives no benefit when calculating a single value
def catalan_rec_prod_memo(n, memo={0: 1}):
    ''' recursive, using product formula, with a default keyword arg cache '''
    if n in memo:
        return memo[n]
    memo[n] = (4*n - 2) * catalan_rec_prod_memo(n-1) // (n + 1)
    return memo[n]


def catalan_iter_prod0(n):
    ''' iterative, using product formula '''
    p = 1
    for i in range(3, n + 2):
        p *= 4*i - 6 
        p //= i 
    return p


def catalan_iter_prod1(n):
    ''' iterative, using product formula, with incremented m '''
    p = 1
    m = 6
    for i in range(3, n + 2):
        p *= m
        m += 4 
        p //= i 
    return p

#Add memoization to catalan_iter_prod1
@initialise({0: 1})
def catalan_iter_memo(n):
    ''' iterative, using product formula, with incremented m and memoization '''
    p = 1
    m = 6
    for i in range(3, n + 2):
        p *= m
        m += 4 
        p //= i 
    return p

def catalan_iter_prod2(n):
    ''' iterative, using product formula, with zip '''
    p = 1
    for i, m in zip(range(3, n + 2), range(6, 4*n + 2, 4)):
        p *= m
        p //= i 
    return p


def catalan_iter_binom(n):
    ''' iterative, using binomial coefficient '''
    m = 2 * n
    n += 1
    p = 1
    for i in range(1, n):
        p *= m
        p //= i
        m -= 1
    return p // n


#All the functions, in approximate speed order
funcs = (
    catalan_iter_prod1,
    catalan_iter_memo,
    catalan_iter_prod0,
    catalan_iter_binom,
    catalan_iter_prod2,

    catalan_rec_prod,
    catalan_rec_prod_memo,
    catalan_rec_defarg,
    catalan_rec_closure,
    catalan_rec_extern,
    catalan_rec_decorator,
    catalan_rec_funcattr,
    #catalan_rec_plain,
)

# ---------------------------------------------------------------------

def show_bytecode():
    for func in funcs:
        fname = func.__name__
        print('\n%s' % fname)
        dis.dis(func)

#Check that all functions give the same results
def verify(n):
    range_n = range(n)
    #range_n = [n]
    func = funcs[0]
    table = [func(i) for i in range_n]
    #print(table)
    for func in funcs[1:]:
        print(func.__name__, [func(i) for i in range_n] == table)

def time_test(n):
    ''' Print timing stats for all the functions '''
    res = []
    for func in funcs:
        fname = func.__name__
        print('\n%s: %s' % (fname, func.__doc__))
        setup = 'from __main__ import cache, ' + fname
        cmd = '%s(%d)' % (fname, n)
        t = Timer(cmd, setup)
        r = t.timeit(1)
        print(r)
        res.append((r, fname))

    ##Sort results from fast to slow
    #print()
    #res.sort()
    #for t, fname in res:
        #print('%s:\t%s' % (fname, t))
        ##print('%s,' % fname)


#show_bytecode()

#verify(50)
#verify(997)

time_test(450)

#for i in range(20):
    #print('%2d: %d' % (i, catalan_iter_binom(i)))

Typical results:

catalan_iter_prod1:  iterative, using product formula, with incremented m 
0.00119090080261

catalan_iter_memo:  iterative, using product formula, with incremented m and memoization 
0.001140832901

catalan_iter_prod0:  iterative, using product formula 
0.00202202796936

catalan_iter_binom:  iterative, using binomial coefficient 
0.00141906738281

catalan_iter_prod2:  iterative, using product formula, with zip 
0.00123286247253

catalan_rec_prod:  recursive, using product formula 
0.00263595581055

catalan_rec_prod_memo:  recursive, using product formula, with a default keyword arg cache 
0.00210690498352

catalan_rec_defarg:  memoize with a default keyword arg cache 
0.46977186203

catalan_rec_closure:  memoize with a simple closure to hold the cache 
0.474807024002

catalan_rec_extern:  memoize with an external cache 
0.47812795639

catalan_rec_decorator:  memoize with a decorator closure to hold the cache 
0.47876906395

catalan_rec_funcattr:  memoize with a function attribute cache 
0.516775131226

The above results were produced by a 2GHz Pentium 4, with minimal system load. However, there is quite a bit of variance from run to run, especially with the faster algorithms.

As you can see, using a default argument for the cache is actually quite a good approach for the double recursion algorithm used in the question. The default dictionary is created once, when the function is defined, so every call shares it. So a cleaned-up version of your recursive function is:

def catalan_rec(n, memo={0: 1}):
    ''' recursive Catalan numbers, with memoization '''
    if n in memo:
        return memo[n]

    res = 0
    for i in range(n):
        res += catalan_rec(i) * catalan_rec(n-1-i)
    memo[n] = res
    return res
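As for why the memoized version in the question was slower: the recursive calls in catalan_mem omit the memo argument, so each call starts again with a fresh {0: 1} dictionary and nothing is ever reused. The sketch below counts calls to illustrate the difference. The names catalan_fresh, catalan_shared and calls are invented for this illustration.

```python
# Call counters for the two variants (illustrative names)
calls = {'fresh': 0, 'shared': 0}

def catalan_fresh(n, memo=None):
    ''' like the question's version: recursive calls do not pass memo '''
    calls['fresh'] += 1
    if memo is None:
        memo = {0: 1}      # a brand new cache on every call
    if n not in memo:
        res = 0
        for i in range(n):
            res += catalan_fresh(i) * catalan_fresh(n - 1 - i)
        memo[n] = res
    return memo[n]

def catalan_shared(n, memo=None):
    ''' same structure, but the cache is handed down to recursive calls '''
    calls['shared'] += 1
    if memo is None:
        memo = {0: 1}      # created once, by the outermost call
    if n not in memo:
        res = 0
        for i in range(n):
            res += catalan_shared(i, memo) * catalan_shared(n - 1 - i, memo)
        memo[n] = res
    return memo[n]

print(catalan_fresh(10), catalan_shared(10))
print(calls)
```

Both functions return the same value, but the first makes vastly more calls because it does all the work of the plain recursion plus the overhead of building a dictionary each time.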

However, it's much more efficient to use one of the iterative algorithms, eg catalan_iter_prod1. If you intend to call the function multiple times with a high likelihood of repeated arguments then use the memoized version, catalan_iter_memo.
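If you only need to support Python 3.2 or later, the standard library's functools.lru_cache can stand in for the hand-written initialise decorator. This is a swapped-in technique, not what the script above uses, and the name catalan_cached is made up for the example.

```python
from functools import lru_cache

@lru_cache(maxsize=None)   # unbounded cache, keyed on the argument n
def catalan_cached(n):
    ''' iterative product formula (same loop as catalan_iter_prod1), cached '''
    p = 1
    m = 6
    for i in range(3, n + 2):
        p *= m
        m += 4
        p //= i
    return p

first = catalan_cached(450)    # computed by the loop
second = catalan_cached(450)   # served from the cache
print(catalan_cached.cache_info())
```

Unlike the initialise decorator, lru_cache cannot be pre-seeded with {0: 1}, which doesn't matter here because the loop already returns 1 for n = 0.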

In conclusion, I should mention that it's best to avoid recursion unless it's appropriate to the problem domain (eg when working with recursive data structures like trees). Python cannot perform tail call elimination and it imposes a recursion limit. So if there's an iterative algorithm it's almost always a better choice than a recursive one. Of course, if you're learning about recursion and your teacher wants you to write recursive code then you don't have much choice. :)
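To illustrate the recursion limit, here is a small sketch that reuses catalan_rec_prod and catalan_iter_prod1 from the script above and calls both with an argument just past the interpreter's limit. It assumes the default recursion limit (normally 1000). RuntimeError is caught because it covers both Python 2 and the RecursionError subclass in Python 3.5+.

```python
import sys

def catalan_rec_prod(n):
    ''' recursive, using product formula '''
    if n < 2:
        return 1
    return (4*n - 2) * catalan_rec_prod(n-1) // (n + 1)

def catalan_iter_prod1(n):
    ''' iterative, using product formula, with incremented m '''
    p = 1
    m = 6
    for i in range(3, n + 2):
        p *= m
        m += 4
        p //= i
    return p

# Pick an argument slightly deeper than the interpreter allows
n = sys.getrecursionlimit() + 10

try:
    catalan_rec_prod(n)
    rec_ok = True
except RuntimeError:
    # maximum recursion depth exceeded
    rec_ok = False

# The iterative version needs no stack depth, so it is unaffected
iter_result = catalan_iter_prod1(n)
print('recursive succeeded:', rec_ok)
print('iterative result is positive:', iter_result > 0)
```

Raising the limit with sys.setrecursionlimit would let the recursive version go further, but only until the underlying C stack runs out, so the iterative form remains the safer choice.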