Implementing a memoization function in Haskell

Posted 2019-08-02 15:02

I'm fairly new to Haskell, and I'm trying to implement a basic memoization function which uses a Data.Map to store computed values. My example is for Project Euler Problem 15, which involves computing the number of possible paths from one corner to the other in a 20x20 grid.

This is what I have so far. I haven't tried compiling yet because I know it won't compile. I'll explain below.

import qualified Data.Map as Map

main = print getProblem15Value

getProblem15Value :: Integer
getProblem15Value = getNumberOfPaths 20 20

getNumberOfPaths :: Integer -> Integer -> Integer
getNumberOfPaths x y = memoize getNumberOfPaths' (x,y)
  where getNumberOfPaths' mem (0,_) = 1
        getNumberOfPaths' mem (_,0) = 1
        getNumberOfPaths' mem (x,y) = (mem (x-1,y)) + (mem (x,y-1))

memoize :: ((a -> b) -> a -> b) -> a -> b
memoize func x = fst $ memoize' Map.empty func x
    where memoize' map func' x' = case (Map.lookup x' map) of (Just y) -> (y, map)
                                                              Nothing -> (y', map'')
           where y' = func' mem x'
                 mem x'' = y''
                 (y'', map') = memoize' map func' x''
                 map'' = Map.insert x' y' map'

So basically, the way I have this structured is that memoize is a combinator (by my understanding). The memoization works because memoize provides a function (in this case getNumberOfPaths') with a function to call (mem) for recursion, instead of having getNumberOfPaths' call itself, which would remove the memoization after the first iteration.
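To make the open-recursion idea concrete, here is a minimal sketch (not part of the original question). getNumberOfPaths' calls a supplied mem instead of itself, and Data.Function.fix closes the loop by passing the function back in as its own mem. The name pathsNoMemo is hypothetical, and this version does no memoization at all: it only shows the shape that a memoize combinator has to fill in.

```haskell
import Data.Function (fix)

-- Open-recursive step: every recursive call goes through `mem`.
getNumberOfPaths' :: ((Integer, Integer) -> Integer) -> (Integer, Integer) -> Integer
getNumberOfPaths' _ (0, _) = 1
getNumberOfPaths' _ (_, 0) = 1
getNumberOfPaths' mem (x, y) = mem (x - 1, y) + mem (x, y - 1)

-- `fix` feeds the finished function back in as `mem`. This is plain
-- recursion with no table; a memoize combinator would instead supply
-- a `mem` that consults stored results.
pathsNoMemo :: (Integer, Integer) -> Integer
pathsNoMemo = fix getNumberOfPaths'
```

pathsNoMemo (2, 2) gives 6, but the number of calls grows very quickly with the grid size, which is exactly what memoize is meant to avoid.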

My implementation of memoize takes a function (in this case getNumberOfPaths') and an initial value (in this case a tuple (x,y) representing the number of grid cell distances from the other corner of the grid). It calls memoize' which has the same structure, but includes an empty Map to hold values, and returns a tuple containing the return value and a new computed Map. memoize' does a map lookup and returns the value and the original map if there is a value present. If there is no value present, it returns the computed value and a new map.

This is where my algorithm breaks down. To compute the new value, I call func' (getNumberOfPaths') with mem and x'. mem simply returns y'', where y'' is contained in the result of calling memoize' again. memoize' also returns a new map, to which we then add the new value and use as the return value of memoize'.

The issue here is that the line (y'', map') = memoize' map func' x'' should be under mem because it's dependent on x'', which is a parameter of mem. I can certainly do that, but then I will lose the map' value, which I need because it contains memoized values from intermediate computations. However, I don't want to introduce the Map into the return value of mem because then the function passed to memoize will have to handle the Map.

Sorry if that sounded confusing. A lot of this ultra-high-order functional stuff is confusing to me.

I'm sure that there is a way to do this. What I want is a generic memoize function that allows recursive calling exactly like in the definition of getNumberOfPaths, where the computation logic doesn't have to care exactly how the memoization is done.

5 answers
贪生不怕死
Post #2 · 2019-08-02 15:26

This might not directly help you implement memoization, but you can use someone else's... monad-memo. Adapting one of their examples...

{-# LANGUAGE FlexibleContexts #-}

import Control.Monad.Memo

main = print $ startEvalMemo (getNumberOfPaths 20 20)

getNumberOfPaths :: (MonadMemo (Integer, Integer) Integer m) => Integer -> Integer -> m Integer
getNumberOfPaths 0 _ = return 1
getNumberOfPaths _ 0 = return 1
getNumberOfPaths x y = do
  n1 <- for2 memo getNumberOfPaths (x-1) y
  n2 <- for2 memo getNumberOfPaths x (y-1)
  return (n1 + n2)

If you want to implement something similar yourself, you could look at their source: https://github.com/EduardSergeev/monad-memo

Bombasti
Post #3 · 2019-08-02 15:30

However, I don't want to introduce the Map into the return value of mem because then the function passed to memoize will have to handle the Map.

If I understand correctly, you will have to do something like this, at least if your aim is to store the memoized values in a map that gets a new, updated version for each new value found. I want to draw attention to something that I don't think makes sense in terms of memoization:

getNumberOfPaths' mem (x,y) = (mem (x-1,y)) + (mem (x,y-1))

This means that any values memoized in one branch, mem (x-1,y), can't be used in the other, mem (x,y-1), because the same mem is used in both and carries the same information, whatever value or function mem ends up being. You have to somehow pass the memoized values from one branch to the other. So the function called to recurse can't just return an Integer: it has to return an Integer along with the memoized values found while computing it.

There are a number of ways of doing this. Although it spreads the details of the memoization around, which may be undesirable, you can pass the map around explicitly.

import qualified Data.Map as Map

getNumberOfPaths :: (Integer, Integer) -> Integer
getNumberOfPaths (x, y) = snd $ memoize Map.empty getNumberOfPaths' (x, y)

getNumberOfPaths' :: Map.Map (Integer, Integer) Integer -> (Integer, Integer) -> (Map.Map (Integer, Integer) Integer, Integer)
getNumberOfPaths' map (0,_) = (map, 1)
getNumberOfPaths' map (_,0) = (map, 1)
getNumberOfPaths' map (x,y) = (map'', first + second) where
  (map',   first) = memoize map  getNumberOfPaths' (x-1, y)
  (map'', second) = memoize map' getNumberOfPaths' (x, y-1)

memoize :: Ord a => Map.Map a b -> (Map.Map a b -> a -> (Map.Map a b, b)) -> a -> (Map.Map a b, b)
memoize map f x = case Map.lookup x map of
  (Just y) -> (map, y)
  Nothing  -> (map'', y) where
    (map', y) = f map x
    map''     = Map.insert x y map'

getNumberOfPaths' does need to pass the map around, and needs to know its type, but at least it doesn't need to look values up or insert them itself: that is done in memoize, so I don't think it's that bad.

If you only want to pass a function around, you can: a chain of functions can act as a poor man's map, although the functions do have to return a Maybe:

getNumberOfPaths :: (Integer, Integer) -> Integer
getNumberOfPaths (x, y) = snd $ memoize (const Nothing) getNumberOfPaths' (x, y) 

getNumberOfPaths' :: ((Integer, Integer) -> Maybe Integer) -> (Integer, Integer) -> ((Integer, Integer) -> Maybe Integer, Integer)
getNumberOfPaths' mem (0,_) = (mem, 1)
getNumberOfPaths' mem (_,0) = (mem, 1)
getNumberOfPaths' mem (x,y) = (mem'', first + second) where
  (mem',   first) = memoize mem  getNumberOfPaths' (x-1, y)
  (mem'', second) = memoize mem' getNumberOfPaths' (x, y-1)

memoize :: Eq a => (a -> Maybe b) -> ((a-> Maybe b) -> a -> ((a -> Maybe b), b)) -> a -> ((a -> Maybe b), b)
memoize mem f x = case mem x of
  (Just y) -> (mem, y)
  Nothing  -> (mem'', y) where
    (mem', y) = f mem x
    mem''     = \x' -> if x' == x then Just y else mem' x'

Perhaps you wanted to both (a) use a map to store the values and (b) pass a function around as mem. I suspect this would be tricky: while you can pass a function that looks up a value in a map and returns it, you can't then get the map back out of that function to insert something into it.

There is also the possibility of creating a monad for this (or using State), but I'll leave that to another answer.

戒情不戒烟
Post #4 · 2019-08-02 15:42

Provided your inputs are small enough, one thing you can do is allocate the memo table as an Array instead of a Map, containing all the results ahead of time, but calculated lazily:

import Data.Array ((!), array)

numPaths :: Integer -> Integer -> Integer
numPaths w h = get (w - 1) (h - 1)
  where

    -- Row-major layout: element y * w + x holds go x y.
    table = array (0, w * h - 1)
      [ (y * w + x, go x y)
      | y <- [0 .. h - 1]
      , x <- [0 .. w - 1]
      ]

    get x y = table ! fromInteger (y * w + x)

    go 0 _ = 1
    go _ 0 = 1
    go x y = get (x - 1) y + get x (y - 1)

You can also split this into separate functions if you prefer:

numPaths w h = withTable w h go (w - 1) (h - 1)
  where
    go mem 0 _ = 1
    go mem _ 0 = 1
    go mem x y = mem (x - 1) y + mem x (y - 1)

withTable w h f = f'
  where
    f' = f get
    get x y = table ! fromInteger (y * w + x)
    table = makeTable w h f'

-- Row-major layout: x ranges over columns 0 .. w - 1, y over rows 0 .. h - 1.
makeTable w h f = array (0, w * h - 1)
  [ (y * w + x, f x y)
  | y <- [0 .. h - 1]
  , x <- [0 .. w - 1]
  ]
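As a variant of the table above, Data.Array can index directly by (x, y) pairs, because tuples of Ix types are themselves Ix instances, which removes the manual y * w + x arithmetic. This is a sketch under my own conventions: gridPaths is a hypothetical name, and here w and h count grid squares, so coordinates run from 0 to w and from 0 to h inclusive (in numPaths above, w and h count lattice points instead).

```haskell
import Data.Array (Array, array, range, (!))

gridPaths :: Integer -> Integer -> Integer
gridPaths w h = table ! (w, h)
  where
    bnds = ((0, 0), (w, h))
    -- One lazily evaluated element per (x, y); each is computed at most once.
    table :: Array (Integer, Integer) Integer
    table = array bnds [((x, y), go x y) | (x, y) <- range bnds]
    -- Recursive calls read from the table instead of calling go again.
    go 0 _ = 1
    go _ 0 = 1
    go x y = table ! (x - 1, y) + table ! (x, y - 1)
```

gridPaths 2 2 evaluates to 6, matching the 2x2 example in the problem statement.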

And I won’t spoil it for you, but there’s also a non-recursive formula for the answer.

Animai°情兽
Post #5 · 2019-08-02 15:48

What I want is a generic memoize function that allows recursive calling exactly like in the definition of getNumberOfPaths, where the computation logic doesn't have to care exactly how the memoization is done.

The State monad lends itself well to handling updates to a state, for example updates to a map of memoized values, without having to pass it around explicitly in the "business logic" part of the code, as the other answer at https://stackoverflow.com/a/44492608/1319998 does.

To separate the details of the memoization from the recursive function, you can hide the fact that a map, and even State, are being used behind a type. All the recursive function's definition needs to know is that it must return a MyMemo a b and that, instead of calling itself directly, it must pass itself and its next argument to myMemo:

import qualified Data.Map as Map
import Control.Monad.State.Strict

main = print $ runMyMemo getNumberOfPaths (20, 20)

getNumberOfPaths :: (Integer, Integer) -> MyMemo (Integer, Integer) Integer
getNumberOfPaths (0, _) = return 1
getNumberOfPaths (_, 0) = return 1
getNumberOfPaths (x, y) = do
  n1 <- myMemo getNumberOfPaths (x-1,y)
  n2 <- myMemo getNumberOfPaths (x,y-1)
  return (n1 + n2)

-------

type MyMemo a b = State (Map.Map a b) b

myMemo :: Ord a => (a -> MyMemo a b) -> a -> MyMemo a b
myMemo f x = gets (Map.lookup x) >>= maybe y' return
  where
    y' = do
      y <- f x
      modify $ Map.insert x y
      return y

runMyMemo :: Ord a => (a -> MyMemo a b) -> a -> b
runMyMemo f x = evalState (f x) Map.empty

The above is essentially a roll-your-own version of https://stackoverflow.com/a/44478219/1319998 (well, rolling on top of State).


Thanks to https://stackoverflow.com/a/44515364/1319998 for suggestions on the code in myMemo.
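If you would rather not depend on mtl, the same structure can be rolled by hand. The sketch below defines a hypothetical Memo newtype that is just a state-passing function over a Data.Map, with the Functor, Applicative and Monad instances written out; memo and evalMemo (also hypothetical names) play the roles of myMemo and runMyMemo. It illustrates roughly what State does under the hood rather than replacing it.

```haskell
import qualified Data.Map as Map

-- Hypothetical state-passing type: given the current memo table,
-- produce a result together with the (possibly extended) table.
newtype Memo k v a = Memo {runMemo :: Map.Map k v -> (a, Map.Map k v)}

instance Functor (Memo k v) where
  fmap f (Memo g) = Memo $ \m -> let (a, m') = g m in (f a, m')

instance Applicative (Memo k v) where
  pure a = Memo $ \m -> (a, m)
  Memo mf <*> Memo ma = Memo $ \m ->
    let (f, m') = mf m
        (a, m'') = ma m'
     in (f a, m'')

instance Monad (Memo k v) where
  -- Thread the table from the first computation into the second.
  Memo ma >>= k = Memo $ \m ->
    let (a, m') = ma m
     in runMemo (k a) m'

-- Look x up; on a miss, compute it with f, store it, and return it.
memo :: Ord k => (k -> Memo k v v) -> k -> Memo k v v
memo f x = Memo $ \m -> case Map.lookup x m of
  Just y -> (y, m)
  Nothing ->
    let (y, m') = runMemo (f x) m
     in (y, Map.insert x y m')

-- Run a memoized computation starting from an empty table.
evalMemo :: (k -> Memo k v v) -> k -> v
evalMemo f x = fst (runMemo (f x) Map.empty)

getNumberOfPaths :: (Integer, Integer) -> Memo (Integer, Integer) Integer Integer
getNumberOfPaths (0, _) = pure 1
getNumberOfPaths (_, 0) = pure 1
getNumberOfPaths (x, y) = do
  n1 <- memo getNumberOfPaths (x - 1, y)
  n2 <- memo getNumberOfPaths (x, y - 1)
  pure (n1 + n2)
```

The business logic in getNumberOfPaths is unchanged from the State version; only the plumbing behind memo differs.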

做自己的国王
Post #6 · 2019-08-02 15:49

You won't be able to implement memoize :: ((a -> b) -> a -> b) -> a -> b. To store the result for some a, you need a place in memory for that a, which means you need some idea of what the possible a values are.

A ham-fisted approach would be to add a type class for the types whose values you can list exhaustively, like Universe.

class Universe a where
    universe :: [a]

You could then implement memoize :: (Ord a, Universe a) => ((a -> b) -> a -> b) -> a -> b by building a Map that contains a b value for every a value in universe :: [a]. The memoized function is a lookup in that Map; it is passed to func, and each b in the Map is defined lazily by calling func with that same memoized function.
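A minimal sketch of this idea, using the Universe class from this answer. Coord is a hypothetical key type that only covers the values 0 to 20, enough for a 20x20 grid; asking for a key outside its universe makes Map.! fail with an error. The key point is that each entry of table is a lazy thunk that calls func with memoed, the lookup into that same table, so every result is computed at most once.

```haskell
import qualified Data.Map as Map

class Universe a where
  universe :: [a]

-- Hypothetical bounded key type for this sketch: coordinates 0 .. 20.
newtype Coord = Coord Integer deriving (Eq, Ord, Show)

instance Universe Coord where
  universe = map Coord [0 .. 20]

instance (Universe a, Universe b) => Universe (a, b) where
  universe = [(x, y) | x <- universe, y <- universe]

memoize :: (Ord a, Universe a) => ((a -> b) -> a -> b) -> a -> b
memoize func = memoed
  where
    -- Data.Map's default API is lazy in its values, so no entry is
    -- computed until it is looked up.
    table = Map.fromList [(k, func memoed k) | k <- universe]
    memoed k = table Map.! k

-- The question's open-recursive function, adapted to Coord keys.
pathsOpen :: ((Coord, Coord) -> Integer) -> (Coord, Coord) -> Integer
pathsOpen _ (Coord 0, _) = 1
pathsOpen _ (_, Coord 0) = 1
pathsOpen mem (Coord x, Coord y) =
  mem (Coord (x - 1), Coord y) + mem (Coord x, Coord (y - 1))
```

With this, memoize pathsOpen (Coord 2, Coord 2) evaluates to 6, and the full 20x20 case only ever fills the 441 entries in the table.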

This won't work for Integer because there are infinitely many of them. It won't even work in practice for Int, because there are far too many of them to build a table of. To memoize types like Integer, you can use the approach taken by MemoTrie: build a lazy, infinite data structure that holds the values at its leaves.

Here's one possible structure for Integers.

data IntegerTrie b = IntegerTrie {
    negative :: [b],
    zero :: b,
    positive :: [b]
}
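A sketch of memoizing with this IntegerTrie, using only base: buildTrie fills the three fields lazily with f's results, and lookupTrie walks the appropriate list. The helper names (buildTrie, lookupTrie, memoInteger, memoInteger2, numberOfPaths) are hypothetical. A two-argument function is handled by memoizing the outer argument to a memoized function of the inner one.

```haskell
import Data.List (genericIndex)

data IntegerTrie b = IntegerTrie
  { negative :: [b] -- f (-1), f (-2), f (-3), ...
  , zero :: b -- f 0
  , positive :: [b] -- f 1, f 2, f 3, ...
  }

-- Lazily tabulate every result of f; nothing is computed until looked up.
buildTrie :: (Integer -> b) -> IntegerTrie b
buildTrie f = IntegerTrie (map f [-1, -2 ..]) (f 0) (map f [1 ..])

-- Walk the right list; the cost is linear in the absolute value of n.
lookupTrie :: IntegerTrie b -> Integer -> b
lookupTrie t n
  | n < 0 = negative t `genericIndex` (negate n - 1)
  | n == 0 = zero t
  | otherwise = positive t `genericIndex` (n - 1)

-- Build the trie once and return a function that only looks things up.
memoInteger :: (Integer -> b) -> Integer -> b
memoInteger f = lookupTrie trie
  where
    trie = buildTrie f

-- Two arguments: memoize x to an (inner) memoized function of y.
memoInteger2 :: (Integer -> Integer -> b) -> Integer -> Integer -> b
memoInteger2 f = memoInteger (\x -> memoInteger (f x))

-- Recursive calls go through the memoized numberOfPaths, not through go.
numberOfPaths :: Integer -> Integer -> Integer
numberOfPaths = memoInteger2 go
  where
    go 0 _ = 1
    go _ 0 = 1
    go x y = numberOfPaths (x - 1) y + numberOfPaths x (y - 1)
```

Because each lookup walks a list, its cost grows linearly with the size of the key, which is the inefficiency addressed in the next paragraph.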

A more efficient structure would allow jumping deep into the trie, avoiding lookups whose cost grows linearly with the key's magnitude (that is, exponentially in its number of bits). For Integers, MemoTrie takes the approach of converting the keys into lists of bits with a pair of functions a -> [Bool] and [Bool] -> a, and using approximately the following trie.

data BitsTrie b = BitsTrie {
    nil :: b,
    false :: BitsTrie b,
    true :: BitsTrie b
}
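A sketch of the bit-based approach, restricted to non-negative Integers for simplicity. The encoding functions toBits and fromBits are my own hypothetical choice (little-endian bits with no trailing False, so each number has exactly one encoding), not MemoTrie's actual encoding, and the other helper names are hypothetical too. A lookup now follows roughly log2 n branches instead of walking n list cells.

```haskell
data BitsTrie b = BitsTrie
  { nil :: b -- result for the empty bit list
  , false :: BitsTrie b -- subtrie for lists starting with False
  , true :: BitsTrie b -- subtrie for lists starting with True
  }

-- Lazily tabulate f over every list of bits.
buildBits :: ([Bool] -> b) -> BitsTrie b
buildBits f = BitsTrie (f []) (buildBits (f . (False :))) (buildBits (f . (True :)))

-- Follow one branch per bit, then read the value at that node.
lookupBits :: BitsTrie b -> [Bool] -> b
lookupBits t [] = nil t
lookupBits t (False : bs) = lookupBits (false t) bs
lookupBits t (True : bs) = lookupBits (true t) bs

-- Hypothetical encoding: little-endian bits with no trailing False.
toBits :: Integer -> [Bool]
toBits n
  | n < 0 = error "toBits: negative input"
  | n == 0 = []
  | otherwise = odd n : toBits (n `div` 2)

fromBits :: [Bool] -> Integer
fromBits = foldr (\b acc -> (if b then 1 else 0) + 2 * acc) 0

-- Memoize a function of a non-negative Integer through the bit trie.
memoNatural :: (Integer -> b) -> Integer -> b
memoNatural f = lookupBits trie . toBits
  where
    trie = buildBits (f . fromBits)

numberOfPaths :: Integer -> Integer -> Integer
numberOfPaths = memoNatural (\x -> memoNatural (go x))
  where
    go 0 _ = 1
    go _ 0 = 1
    go x y = numberOfPaths (x - 1) y + numberOfPaths x (y - 1)
```

Trie nodes for non-canonical bit lists (ones ending in False) are never visited, so laziness means they are never built.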

MemoTrie goes on to abstract over the types that have an associated trie that can be used to memoize them, and provides ways to compose those tries together.
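To illustrate how such an abstraction composes, here is a heavily simplified sketch in the spirit of MemoTrie's HasTrie class. HasMemoTrie, Trie, build, apply and memo are hypothetical names, not MemoTrie's API. It uses an associated data family (the TypeFamilies extension) so that each key type chooses its own trie representation, and the pair instance shows composition by currying: a trie for (a, b) is an a-trie of b-tries.

```haskell
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)

-- Each key type picks a lazy structure holding one result per key.
class HasMemoTrie a where
  data Trie a :: Type -> Type
  build :: (a -> b) -> Trie a b
  apply :: Trie a b -> a -> b

-- Bool has two keys, so its trie is just two lazy slots.
instance HasMemoTrie Bool where
  data Trie Bool b = BoolTrie b b
  build f = BoolTrie (f False) (f True)
  apply (BoolTrie x _) False = x
  apply (BoolTrie _ y) True = y

-- Pairs compose existing tries: first choose by x, then by y.
instance (HasMemoTrie a, HasMemoTrie b) => HasMemoTrie (a, b) where
  newtype Trie (a, b) c = PairTrie (Trie a (Trie b c))
  build f = PairTrie (build (\x -> build (\y -> f (x, y))))
  apply (PairTrie t) (x, y) = apply (apply t x) y

-- Build the trie once; the returned function only does lookups.
memo :: HasMemoTrie a => (a -> b) -> a -> b
memo f = apply (build f)
```

For example, memo applied to a function of type (Bool, (Bool, Bool)) -> Int builds nested tries with one lazily evaluated slot per input, eight in total; instances for other key types (such as a bit-based one for Integer) would plug into the same class.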
