Find Nth term of provided sequence

Posted 2020-03-07 12:18

f(0) = p

f(1) = q

f(2) = r

for n > 2

f(n) = a*f(n-1) + b*f(n-2) + c*f(n-3) + g(n)

where g(n) = n * n * (n+1)

p, q, r, a, b and c are given. The question is how to find the nth term of this series.

Please help me in finding a better solution for this.

I have tried solving this using recursion, but that approach consumes a lot of memory.

2 answers
爷的心禁止访问
#2 · 2020-03-07 12:33

The problem is that each call to f with n > 2 results in three more calls to f. For example, if we call f(5), we get the following calls:

- f(5)
    - f(4)
        - f(3)
            - f(2)
            - f(1)
            - f(0)
            - g(3)
        - f(2)
        - f(1)
        - g(4)
    - f(3)
        - f(2)
        - f(1)
        - f(0)
        - g(3)
    - f(2)
    - g(5)

We thus make one call to f(5), one call to f(4), two calls to f(3), four calls to f(2), three calls to f(1), and two calls to f(0).

Because f(3), for example, is called more than once, its work is repeated each time, and every one of those calls in turn makes further calls of its own. The total number of calls therefore grows exponentially with n.
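The call counts above can be checked with a small sketch. The values chosen for p, q, r, a, b, c below are hypothetical (the question leaves them open), and f_naive and calls are illustrative names that do not appear in the question:

```python
from collections import Counter

# Hypothetical example values; the question leaves p, q, r, a, b, c open.
p, q, r = 1, 2, 3
a, b, c = 1, 1, 1

calls = Counter()  # how many times f_naive is entered for each argument


def g(n):
    return n * n * (n + 1)


def f_naive(n):
    # Plain recursion without any caching, as described above.
    calls[n] += 1
    if n <= 2:
        return (p, q, r)[n]
    return a * f_naive(n - 1) + b * f_naive(n - 2) + c * f_naive(n - 3) + g(n)


result = f_naive(5)
for k in sorted(calls, reverse=True):
    print(f"f({k}) was called {calls[k]} time(s)")
```

With these values the counter should report 1, 1, 2, 4, 3 and 2 calls for f(5) down to f(0), matching the tree.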

We can let Python store the result of each function call and return the stored result on later calls, for example with lru_cache from the functools module. This technique is called memoization:

from functools import lru_cache

def g(n):
    return n * n * (n+1)

@lru_cache(maxsize=32)
def f(n):
    if n <= 2:
        return (p, q, r)[n]
    else:
        return a*f(n-1) + b*f(n-2) + c*f(n-3) + g(n)

This will result in a call graph like:

- f(5)
    - f(4)
        - f(3)
            - f(2)
            - f(1)
            - f(0)
            - g(3)
        - g(4)
    - g(5)

So now we only calculate f(3) once. The lru_cache stores the result, and if we call f(3) a second time, the body of f is not evaluated again; the cache returns the precomputed value.
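As a rough check of this behaviour, the cache statistics can be inspected with cache_info(), which functools.lru_cache adds to the decorated function. This is only a sketch: the values of p, q, r, a, b, c are hypothetical, the name f_cached is illustrative, and maxsize=None (an unbounded cache) is an assumption that differs from the maxsize=32 used above.

```python
from functools import lru_cache

# Hypothetical example values; the question leaves p, q, r, a, b, c open.
p, q, r = 1, 2, 3
a, b, c = 1, 1, 1


def g(n):
    return n * n * (n + 1)


@lru_cache(maxsize=None)  # assumption: unbounded cache instead of maxsize=32
def f_cached(n):
    if n <= 2:
        return (p, q, r)[n]
    return a * f_cached(n - 1) + b * f_cached(n - 2) + c * f_cached(n - 3) + g(n)


value = f_cached(5)
info = f_cached.cache_info()
# misses = arguments whose body was actually evaluated,
# hits   = calls answered straight from the cache
print(value, info.misses, info.hits)
```

For f(5) there should be six misses (one per distinct argument 0 to 5) and four hits. Note that the memoized version still recurses about n levels deep on a cold cache, so for large n it can run into Python's recursion limit.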

The above can be optimized further. Since each step only uses f(n-1), f(n-2) and f(n-3), we only need to store the last three values. In each iteration we calculate the next value from those three and then shift the variables, like:

def f(n):
    if n <= 2:
        return (p, q, r)[n]
    f3, f2, f1 = p, q, r
    for i in range(3, n+1):
        f3, f2, f1 = f2, f1, a * f1 + b * f2 + c * f3 + g(i)
    return f1
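To make the loop easier to try out, here is a self-contained sketch of the same idea that takes the constants as arguments instead of reading them from globals. The function name nth_term and the sample values are assumptions made for illustration only:

```python
def g(n):
    return n * n * (n + 1)


def nth_term(n, p, q, r, a, b, c):
    """Iteratively compute f(n) while keeping only the last three values."""
    if n <= 2:
        return (p, q, r)[n]
    # f3, f2, f1 hold f(i-3), f(i-2), f(i-1) at the start of each iteration.
    f3, f2, f1 = p, q, r
    for i in range(3, n + 1):
        f3, f2, f1 = f2, f1, a * f1 + b * f2 + c * f3 + g(i)
    return f1


# Hypothetical inputs: p, q, r = 1, 2, 3 and a = b = c = 1.
print([nth_term(n, 1, 2, 3, 1, 1, 1) for n in range(6)])
# → [1, 2, 3, 42, 127, 322]
```

Worked by hand for these inputs: g(3) = 36, so f(3) = 3 + 2 + 1 + 36 = 42; g(4) = 80, so f(4) = 42 + 3 + 2 + 80 = 127; g(5) = 150, so f(5) = 127 + 42 + 3 + 150 = 322.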
做自己的国王
#3 · 2020-03-07 12:46

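A better way than naive recursion is to compute the values iteratively from the bottom up. You only need to keep the last three values of f. In Python, a solution could look like this (g is the function from the question, and p, q, r, a, b, c are the given constants):

def f(n):
    # The first three terms are given directly.
    if n == 0:
        return p
    elif n == 1:
        return q
    elif n == 2:
        return r
    # f_n3, f_n2, f_n1 hold f(i-3), f(i-2), f(i-1).
    f_n3, f_n2, f_n1 = p, q, r
    for i in range(3, n + 1):
        # Use g(i), the term for the current index, not g(n).
        f_new = a * f_n1 + b * f_n2 + c * f_n3 + g(i)
        # Shift the window forward by one position.
        f_n3, f_n2, f_n1 = f_n2, f_n1, f_new
    return f_new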

This way you don't need to call the function recursively and keep every intermediate call on the stack; you just keep four variables in memory.

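This takes a number of steps proportional to n and a constant number of variables, so it should run much faster and use much less memory than the naive recursion.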
