Memoization in Haskell?

Posted 2019-01-04 04:21

Any pointers on how to efficiently evaluate the following function in Haskell for large numbers (n > 10^8)?

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

I've seen examples of memoization in Haskell for Fibonacci numbers, which involved (lazily) computing all the Fibonacci numbers up to the required n. But in this case, for a given n, we only need to compute very few intermediate results.
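To make "very few intermediate results" concrete, here is a rough sketch, assuming n/k means integer division and f(0) = 0. The names naiveF and reachable are made up for illustration. Because dividing by 4 is the same as dividing by 2 twice, every argument the recursion reaches has the form n `div` (2^a * 3^b), so the set of distinct arguments stays small even for large n.

```haskell
import qualified Data.Set as Set

-- Direct transcription of the recurrence (assumption: n/k is integer
-- division and f(0) = 0). No sharing, so only usable for small n.
naiveF :: Integer -> Integer
naiveF 0 = 0
naiveF n = max n (naiveF (n `div` 2) + naiveF (n `div` 3) + naiveF (n `div` 4))

-- Hypothetical helper: every distinct argument the recursion reaches from n.
reachable :: Integer -> Set.Set Integer
reachable = go Set.empty
  where
    go seen n
      | n `Set.member` seen = seen
      | n == 0              = Set.insert 0 seen
      | otherwise           =
          foldl go (Set.insert n seen) [n `div` 2, n `div` 3, n `div` 4]
```

For example, naiveF 12 evaluates to 13 (f(6) + f(4) + f(3) = 6 + 4 + 3), and Set.size (reachable 12) is 7.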

Thanks

8 answers
叼着烟拽天下
Reply #2 · 2019-01-04 05:15

Edward's answer is such a wonderful gem that I've duplicated it and provided implementations of memoList and memoTree combinators that memoize a function in open-recursive form.
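In case "open-recursive form" is unfamiliar, here is a minimal sketch of the idea. The names openF and plainF are only for illustration. The function takes "itself" as an extra argument instead of calling itself directly, so the caller decides what recursive calls do. Tying the knot with fix gives back the ordinary, unmemoized function. A memo combinator ties the same knot through a lookup structure instead.

```haskell
import Data.Function (fix)

-- Open-recursive form: recursive calls go through the `self` argument.
openF :: (Integer -> Integer) -> Integer -> Integer
openF _    0 = 0
openF self n = max n (self (n `div` 2) + self (n `div` 3) + self (n `div` 4))

-- Plain recursion recovered with fix: no memoization, exponential work.
plainF :: Integer -> Integer
plainF = fix openF
```

With this definition plainF 12 is 13, the same value a memoized version should return.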

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

f :: (Integer -> Integer) -> Integer -> Integer
f mf 0 = 0
f mf n = max n $ mf (div n 2) +
                 mf (div n 3) +
                 mf (div n 4)


-- Memoizing using a list

-- Keep this eta-reduced (no explicit Integer argument): the memo list is
-- then shared between calls instead of being rebuilt for each argument.
memoList :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoList f = memoList_f
  where memoList_f = (memo !!) . fromInteger
        memo = map (f memoList_f) [0..]

faster_f :: Integer -> Integer
faster_f = memoList f


-- Memoizing using a tree

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q

nats :: Tree Integer
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- Keep this eta-reduced (no explicit Integer argument): the memo tree is
-- then shared between calls instead of being rebuilt for each argument.
memoTree :: ((Integer -> Integer) -> Integer -> Integer) -> Integer -> Integer
memoTree f = memoTree_f
  where memoTree_f = index memo
        memo = fmap (f memoTree_f) nats

fastest_f :: Integer -> Integer
fastest_f = memoTree f
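The tree version works because nats stores each number at its own index, so fmap (f memoTree_f) nats puts f n at position n, and index finds it in O(log n) steps. Below is a self-contained sketch that checks this property. The helper natsHoldOwnIndex is a made-up name, and index is rewritten with guards only so that the pattern match is total.

```haskell
{-# LANGUAGE BangPatterns #-}

-- Infinite binary tree; node k holds the value for index k.
data Tree a = Tree (Tree a) a (Tree a)

-- Odd indices live in the left subtree, even (non-zero) ones in the right.
index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n
  | even (n - 1) = index l q
  | otherwise    = index r q
  where q = (n - 1) `div` 2

-- Same construction as in the answer: root 0, then 1 and 2, and so on.
nats :: Tree Integer
nats = go 0 1
  where
    go !n !s = Tree (go (n + s) s') n (go (n + 2 * s) s')
      where s' = s * 2

-- Hypothetical check: looking up n in nats gives n back.
natsHoldOwnIndex :: Integer -> Bool
natsHoldOwnIndex limit = all (\n -> index nats n == n) [0 .. limit]
```

Only the path from the root to the requested index is ever forced, so looking up an index near 10^8 touches roughly 27 nodes.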
不美不萌又怎样
Reply #3 · 2019-01-04 05:21

A couple years later, I looked at this and realized there's a simple way to memoize this in linear time using zipWith and a helper function:

dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

dilate has the handy property that dilate n xs !! i == xs !! div i n.
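A quick way to see that property is to try it on a short list. The check dilateProperty below is a made-up helper for illustration, not part of the answer.

```haskell
dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

-- Each element is repeated n times in place.
example :: [Int]
example = dilate 2 [1, 2, 3]   -- [1,1,2,2,3,3]

-- Hypothetical check of the stated property on a finite prefix.
dilateProperty :: Int -> [Int] -> Bool
dilateProperty n xs =
  and [ dilate n xs !! i == xs !! (i `div` n) | i <- [0 .. n * length xs - 1] ]
```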

So, supposing we're given f(0), this simplifies the computation to

fs = f0 : zipWith max [1..] (tail $ fs#/2 .+. fs#/3 .+. fs#/4)
  where (.+.) = zipWith (+)
        infixl 6 .+.
        (#/) = flip dilate
        infixl 7 #/

This looks a lot like the original problem description and gives a linear-time solution: sum $ take n fs takes O(n).
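For anyone who wants to run it, here is a self-contained sketch of the same definition, assuming f(0) = 0 (the snippet above leaves f0 open) and using Integer values. The three dilated lists are given names (halves, thirds, quarters) purely for readability.

```haskell
dilate :: Int -> [x] -> [x]
dilate n xs = replicate n =<< xs

-- fs !! n is f(n). Element n only depends on elements n `div` 2,
-- n `div` 3 and n `div` 4, which are already defined, so the
-- self-reference is productive.
fs :: [Integer]
fs = 0 : zipWith max [1 ..] (tail (halves .+. thirds .+. quarters))
  where
    (.+.) = zipWith (+)
    infixl 6 .+.
    halves   = dilate 2 fs
    thirds   = dilate 3 fs
    quarters = dilate 4 fs
```

With this, fs !! 12 gives 13 and fs !! 24 gives 27. Note that (!!) still walks the list, so reaching index n costs O(n) time and keeps that much of the list alive.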
