
Using dynamic programming in Haskell? [Warning: Pr

Posted 2020-07-11 09:10

Question:

In solving projecteuler.net's problem #31 [SPOILERS AHEAD] (counting the number of ways to make £2 with British coins), I wanted to use dynamic programming. I started with OCaml and wrote the following short and very efficient program:

open Num

let make_dyn_table amount coins =
  let t = Array.make_matrix (Array.length coins) (amount+1) (Int 1) in
  for i = 1 to (Array.length t) - 1 do
    for j = 0 to amount do
      if j < coins.(i) then
        t.(i).(j) <- t.(i-1).(j)
      else
        t.(i).(j) <- t.(i-1).(j) +/ t.(i).(j - coins.(i))
    done
  done;
  t

let _ =
  let t = make_dyn_table 200 [|1;2;5;10;20;50;100;200|] in
  let last_row = Array.length t - 1 in
  let last_col = Array.length t.(last_row) - 1 in
  Printf.printf "%s\n" (string_of_num (t.(last_row).(last_col)))

This executes in ~8ms on my laptop. If I increase the amount from 200 pence to one million, the program still finds an answer in less than two seconds.
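Read as a recurrence, the OCaml program fills its table as follows. Here W(i, j) stands for t.(i).(j), the number of ways to make j pence using only the coins c_0 to c_i; A is the amount and n the number of coins. Row 0 is simply left at 1, which is correct only because the first coin, c_0, is 1p.

```latex
W(0, j) = 1 \qquad (0 \le j \le A)

W(i, j) =
\begin{cases}
  W(i-1, j)                 & \text{if } j < c_i \\
  W(i-1, j) + W(i, j - c_i) & \text{if } j \ge c_i
\end{cases}
\qquad (1 \le i \le n - 1)

\text{answer} = W(n-1, A) \qquad \text{e.g. } W(7, 200) \text{ for the call above}
```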

I translated the program to Haskell (which was definitely not fun in itself), and though it terminates with the right answer for 200 pence, if I increase that number to 10000, my laptop comes to a screeching halt (lots of thrashing). Here's the code:

import Data.Array

createDynTable :: Int -> Array Int Int -> Array (Int, Int) Int
createDynTable amount coins =
    let numCoins = (snd . bounds) coins
        t = array ((0, 0), (numCoins, amount))
            [((i, j), 1) | i <- [0 .. numCoins], j <- [0 .. amount]]
    in t

populateDynTable :: Array (Int, Int) Int -> Array Int Int -> Array (Int, Int) Int
populateDynTable t coins =
    go t 1 0
        where go t i j
                 | i > maxX = t
                 | j > maxY = go t (i+1) 0
                 | j < coins ! i = go (t // [((i, j), t ! (i-1, j))]) i (j+1)
                 | otherwise = go (t // [((i, j), t!(i-1,j) + t!(i, j - coins!i))]) i (j+1)
              ((_, _), (maxX, maxY)) = bounds t

changeCombinations amount coins =
    let coinsArray = listArray (0, length coins - 1) coins
        dynTable = createDynTable amount coinsArray
        dynTable' = populateDynTable dynTable coinsArray
        ((_, _), (i, j)) = bounds dynTable
    in
      dynTable' ! (i, j)

main =
    print $ changeCombinations 200 [1,2,5,10,20,50,100,200]

I'd love to hear from somebody who knows Haskell well why the performance of this solution is so bad.

Answer 1:

Haskell is pure. The purity means that values are immutable, and thus in the step

j < coins ! i = go (t // [((i, j), t ! (i-1, j))]) i (j+1)

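you create an entire new array for each entry you update. Each (//) copies all (numCoins+1) × (amount+1) cells, so filling the table takes time quadratic in its size. That's already very expensive for a small amount like £2, but it becomes utterly obscene for an amount of £100 (10000 pence, about 80,000 cells).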
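To make that cost concrete, here is a small sketch contrasting one (//) call per element with a single batched call; oneAtATime and batched are made-up names for illustration. Both produce the same array, but the first copies the whole array once per update. Batching is not an option for the DP table itself, because each new cell depends on cells written just before it, which is why the table needs mutation or laziness instead.

```haskell
import Data.Array (Array, listArray, (//))
import Data.List (foldl')

-- Hypothetical helper: one (//) per update. Every call allocates and fills
-- a fresh copy of the whole array, so k updates on an n-cell array cost O(k * n).
oneAtATime :: Array Int Int -> [(Int, Int)] -> Array Int Int
oneAtATime = foldl' (\arr upd -> arr // [upd])

-- Hypothetical helper: all updates in a single (//), i.e. one copy in total.
-- Only possible when no new value depends on another new value.
batched :: Array Int Int -> [(Int, Int)] -> Array Int Int
batched arr upds = arr // upds
```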

Furthermore, the arrays are boxed, which means they contain pointers to the entries. That worsens locality, uses more storage, and allows thunks to build up that are also slower to evaluate when they are finally forced.
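A small illustration of that difference, as a sketch using Data.Array.Unboxed from the array package (lazyCells and strictCells are made-up names): a boxed Array cell is a pointer that may refer to a not-yet-evaluated expression, while a UArray cell holds the raw Int itself, so all its elements are evaluated when the array is built.

```haskell
import Data.Array.Unboxed (Array, UArray, listArray, (!))

-- Boxed: cell 1 points to an unevaluated thunk. Nothing ever forces it,
-- so building the array and reading cell 0 is fine.
lazyCells :: Array Int Int
lazyCells = listArray (0, 1) [1, error "cell 1 was forced"]

-- Unboxed: each cell stores a machine Int directly, so every element is
-- evaluated when the array is built; an error in any element would be
-- raised as soon as the array itself is used.
strictCells :: UArray Int Int
strictCells = listArray (0, 1) [1, 2]
```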

The algorithm relies on a mutable data structure for its efficiency, but the mutation is confined to the computation, so we can use the tool designed for safely encapsulated computations with temporarily mutable data: the ST state-transformer monad family and its associated (for efficiency, unboxed) arrays.
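As a minimal illustration of that encapsulation, unrelated to coins (sumST is a made-up name): the STRef below is mutated in place, but runST guarantees it cannot escape, so sumST is an ordinary pure function to its callers.

```haskell
import Control.Monad.ST (runST)
import Data.STRef (modifySTRef', newSTRef, readSTRef)

-- Sum a list using a mutable accumulator. The STRef lives only inside
-- runST; its type ties it to this one computation, so it cannot leak out.
sumST :: [Int] -> Int
sumST xs = runST $ do
  acc <- newSTRef 0
  mapM_ (\x -> modifySTRef' acc (+ x)) xs
  readSTRef acc
```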

Give me half an hour or so to translate the algorithm into code using STUArrays, and you'll get a Haskell version that is not too ugly and ought to perform comparably to the O'Caml version (some roughly constant factor difference is to be expected; whether it is larger or smaller than 1, I don't know).

Here it is:

module Main (main) where

import System.Environment (getArgs)

import Data.Array.ST
import Control.Monad.ST
import Data.Array.Unboxed

standardCoins :: [Int]
standardCoins = [1,2,5,10,20,50,100,200]

changeCombinations :: Int -> [Int] -> Int
changeCombinations amount coins = runST $ do
    let coinBound = length coins - 1
        coinsArray :: UArray Int Int
        coinsArray = listArray (0, coinBound) coins
    table <- newArray((0,0),(coinBound, amount)) 1 :: ST s (STUArray s (Int,Int) Int)
    let go i j
            | i > coinBound = readArray table (coinBound,amount)
            | j > amount   = go (i+1) 0
            | j < coinsArray ! i = do
                v <- readArray table (i-1,j)
                writeArray table (i,j) v
                go i (j+1)
            | otherwise = do
                v <- readArray table (i-1,j)
                w <- readArray table (i, j - coinsArray!i)
                writeArray table (i,j) (v+w)
                go i (j+1)
    go 1 0

main :: IO ()
main = do
    args <- getArgs
    let amount = case args of
                   a:_ -> read a
                   _   -> 200
    print $ changeCombinations amount standardCoins

runs in not too shabby time,

$ time ./mutArr
73682

real    0m0.002s
user    0m0.000s
sys     0m0.001s
$ time ./mutArr 1000000
986687212143813985

real    0m0.439s
user    0m0.128s
sys     0m0.310s

and it uses checked array accesses; with unchecked accesses, the time could be reduced somewhat.
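One way to drop the bounds checks is sketched below. It assumes unsafeRead, unsafeWrite and unsafeAt from Data.Array.Base in the array package, which take a 0-based linear index (row-major for tuple indices) and perform no checks; that module is not intended for general use, so treat this as an experiment. The name changeCombinationsUnsafe is hypothetical. It still uses Int, so large amounts overflow just as in mutArr, and like the original it assumes amount >= 0 and a first coin of 1.

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (ST, runST)
import Data.Array.Base (unsafeAt, unsafeRead, unsafeWrite)
import Data.Array.ST (STUArray, newArray)
import Data.Array.Unboxed (UArray, listArray)

-- Hypothetical variant of changeCombinations without bounds checks.
-- unsafeRead/unsafeWrite/unsafeAt take a 0-based linear index, so the
-- 2-D index (i, j) is flattened by hand, row-major: i * (amount + 1) + j.
-- Assumes amount >= 0 and that the first coin is 1 (row 0 stays all 1s).
-- Still uses Int, so large amounts overflow exactly as in mutArr.
changeCombinationsUnsafe :: Int -> [Int] -> Int
changeCombinationsUnsafe amount coins = runST $ do
  let coinBound = length coins - 1
      width = amount + 1
      idx i j = i * width + j
      coinsArray :: UArray Int Int
      coinsArray = listArray (0, coinBound) coins
  table <- newArray ((0, 0), (coinBound, amount)) 1 :: ST s (STUArray s (Int, Int) Int)
  forM_ [1 .. coinBound] $ \i -> do
    let c = coinsArray `unsafeAt` i
    forM_ [0 .. amount] $ \j -> do
      v <- unsafeRead table (idx (i - 1) j)
      if j < c
        then unsafeWrite table (idx i j) v
        else do
          w <- unsafeRead table (idx i (j - c))
          unsafeWrite table (idx i j) (v + w)
  unsafeRead table (idx coinBound amount)
```

Called as changeCombinationsUnsafe 200 [1,2,5,10,20,50,100,200], it should give the same 73682 as the checked version.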


Ah, I just learned that your O'Caml code uses arbitrary-precision integers, so using Int in Haskell puts O'Caml at an unfair disadvantage. The changes needed to compute the results with arbitrary-precision Integers are minimal:

$ diff mutArr.hs mutArrIgr.hs
12c12
< changeCombinations :: Int -> [Int] -> Int
---
> changeCombinations :: Int -> [Int] -> Integer
17c17
<     table <- newArray((0,0),(coinBound, amount)) 1 :: ST s (STUArray s (Int,Int) Int)
---
>     table <- newArray((0,0),(coinBound, amount)) 1 :: ST s (STArray s (Int,Int) Integer)
28c28
<                 writeArray table (i,j) (v+w)
---
>                 writeArray table (i,j) $! (v+w)

Only two type signatures needed to be adapted. The array necessarily becomes boxed (STUArray cannot hold Integer), so we need to make sure we're not writing thunks into it in line 28, hence the $!, and

$ time ./mutArrIgr 
73682

real    0m0.002s
user    0m0.000s
sys     0m0.002s
$ time ./mutArrIgr 1000000
99341140660285639188927260001

real    0m1.314s
user    0m1.157s
sys     0m0.156s

the computation with the large result, which overflowed for Ints, takes noticeably longer but is, as expected, comparable to the O'Caml version.
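To see why the $! in the Integer version matters, here is a minimal sketch (sumIntoCell is a made-up name) that repeatedly adds into one cell of a boxed STArray. readArray returns whatever is stored, evaluated or not, so without $! each write would store a new unevaluated old + k on top of the previous one, and the whole chain would only be collapsed by the final read.

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STArray, newArray, readArray, writeArray)

-- Add 1..n into cell 0 of a boxed STArray. ($!) evaluates old + k before
-- it is written, so the cell always holds a plain Integer, never a
-- growing chain of pending additions.
sumIntoCell :: Integer -> Integer
sumIntoCell n = runST $ do
  arr <- newArray (0, 0) 0 :: ST s (STArray s Int Integer)
  forM_ [1 .. n] $ \k -> do
    old <- readArray arr 0
    writeArray arr 0 $! old + k
  readArray arr 0
```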


Having spent some time understanding the O'Caml, I can offer a closer, somewhat shorter, and arguably nicer translation:

module Main (main) where

import System.Environment (getArgs)

import Data.Array.ST
import Control.Monad.ST
import Data.Array.Unboxed
import Control.Monad (forM_)

standardCoins :: [Int]
standardCoins = [1,2,5,10,20,50,100,200]

changeCombinations :: Int -> [Int] -> Integer
changeCombinations amount coins = runST $ do
    let coinBound = length coins - 1
        coinsArray :: UArray Int Int
        coinsArray = listArray (0, coinBound) coins
    table <- newArray((0,0),(coinBound, amount)) 1 :: ST s (STArray s (Int,Int) Integer)
    forM_ [1 .. coinBound] $ \i ->
        forM_ [0 .. amount] $ \j ->
            if j < coinsArray!i
              then do
                v <- readArray table (i-1,j)
                writeArray table (i,j) v
              else do
                v <- readArray table (i-1,j)
                w <- readArray table (i, j - coinsArray!i)
                writeArray table (i,j) $! (v+w)
    readArray table (coinBound,amount)

main :: IO ()
main = do
    args <- getArgs
    let amount = case args of
                   a:_ -> read a
                   _   -> 200
    print $ changeCombinations amount standardCoins

that runs about equally fast:

$ time ./mutArrIgrM 1000000
99341140660285639188927260001

real    0m1.440s
user    0m1.273s
sys     0m0.164s
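If the whole table is wanted rather than just its last cell, a variant of the forM_ translation can return it with runSTArray from Data.Array.ST, which freezes the mutable array without copying it. This is a sketch; changeTable is a hypothetical name, and it keeps the original's assumption that the first coin is 1.

```haskell
import Control.Monad (forM_)
import Data.Array (Array, listArray, (!))
import Data.Array.ST (newArray, readArray, writeArray, runSTArray)

-- Hypothetical variant: fill the table in ST, then hand back the whole
-- table as an immutable Array. Row 0 stays all 1s (first coin assumed 1).
changeTable :: Int -> [Int] -> Array (Int, Int) Integer
changeTable amount coins = runSTArray $ do
  let coinBound = length coins - 1
      coinsArray = listArray (0, coinBound) coins :: Array Int Int
  table <- newArray ((0, 0), (coinBound, amount)) 1
  forM_ [1 .. coinBound] $ \i ->
    forM_ [0 .. amount] $ \j -> do
      v <- readArray table (i - 1, j)
      if j < coinsArray ! i
        then writeArray table (i, j) v
        else do
          w <- readArray table (i, j - coinsArray ! i)
          writeArray table (i, j) $! v + w
  return table
```

The answer for £2 is then the corner cell, changeTable 200 [1,2,5,10,20,50,100,200] ! (7, 200), and any other entry (for example the count for a smaller amount) can be read from the same table.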


Answer 2:

You could take advantage of Haskell's laziness: instead of scheduling the array filling yourself, rely on lazy evaluation to do it in the right order. (For large inputs you'll need to increase the stack size.)

import Data.Array

createDynTable :: Integer -> Array Int Integer -> Array (Int, Integer) Integer
createDynTable amount coins =
    let numCoins = (snd . bounds) coins
        t = array ((0, 0), (numCoins, amount))
            [((i, j), go i j) | i <- [0 .. numCoins], j <- [0 .. amount]]
        go i j | i == 0        = 1
               | j < coins ! i = t ! (i-1, j)
               | otherwise     = t ! (i-1, j) + t ! (i, j - coins!i)
    in t


changeCombinations amount coins =
    let coinsArray = listArray (0, length coins - 1) coins
        dynTable = createDynTable amount coinsArray
        ((_, _), (i, j)) = bounds dynTable
    in
       dynTable ! (i, j)

main =
    print $ changeCombinations 200 [1,2,5,10,20,50,100,200]
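One possible alternative to enlarging the stack, sketched here rather than taken from the answer: each cell depends only on the row above and on cells to its left, so forcing all cells in index order (row-major for tuple indices) means every dependency is already evaluated when a cell is forced, and no deep chain of pending thunks has to be unwound at the end. The names lazyTable and countWays are hypothetical; the table is the answer's, restated so the block is self-contained.

```haskell
import Data.Array (Array, array, bounds, elems, listArray, (!))
import Data.List (foldl')

-- The lazy table from the answer: cells are defined in terms of other cells.
-- Like the original, it assumes the first coin is 1 (row 0 is all 1s).
lazyTable :: Integer -> [Integer] -> Array (Int, Integer) Integer
lazyTable amount coinList = t
  where
    numCoins = length coinList - 1
    coins = listArray (0, numCoins) coinList :: Array Int Integer
    t = array ((0, 0), (numCoins, amount))
          [((i, j), go i j) | i <- [0 .. numCoins], j <- [0 .. amount]]
    go i j
      | i == 0        = 1
      | j < coins ! i = t ! (i - 1, j)
      | otherwise     = t ! (i - 1, j) + t ! (i, j - coins ! i)

-- Force every cell in index order (row-major for tuple indices) before
-- reading the last one. Each cell only depends on the row above and on
-- cells to its left, which are already evaluated by then.
countWays :: Integer -> [Integer] -> Integer
countWays amount coinList = forced ! snd (bounds forced)
  where
    t = lazyTable amount coinList
    forced = foldl' (flip seq) () (elems t) `seq` t
```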