Dynamic programming in the functional paradigm

Posted 2019-01-30 12:02

Question:

I'm looking at Problem 31 on Project Euler, which asks how many different ways there are of making £2 using any number of coins of 1p, 2p, 5p, 10p, 20p, 50p, £1 (100p) and £2 (200p).

There are recursive solutions, such as this one in Scala (credit to Pavel Fatin):

def f(ms: List[Int], n: Int): Int = ms match {
  case h :: t =>
    if (h > n) 0 else if (n == h) 1 else f(ms, n - h) + f(t, n)
  case _ => 0
} 
val r = f(List(1, 2, 5, 10, 20, 50, 100, 200), 200)

Although it runs fast enough, it's relatively inefficient: it calls the f function around 5.6 million times.
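To measure the call count yourself, one option is to instrument the recursion with a counter. This is a minimal sketch; the object name CountCalls and the calls counter are hypothetical additions, not part of Pavel Fatin's code, and the recursion itself is unchanged.

```scala
object CountCalls {
  // Hypothetical instrumentation: incremented on every invocation of f
  var calls = 0L

  // Pavel Fatin's recursion; coins must be sorted in ascending order
  def f(ms: List[Int], n: Int): Int = {
    calls += 1
    ms match {
      case h :: t =>
        if (h > n) 0 else if (n == h) 1 else f(ms, n - h) + f(t, n)
      case _ => 0
    }
  }

  def main(args: Array[String]): Unit = {
    val r = f(List(1, 2, 5, 10, 20, 50, 100, 200), 200)
    println(s"result = $r, calls = $calls")
  }
}
```

Because every way of making £2 ends in a separate call that returns 1, the count is necessarily larger than the answer itself.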

I saw someone else's solution in Java that uses dynamic programming (credit to wizeman from Portugal):

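public class Euler31 {
    final static int TOTAL = 200;

    public static void main(String[] args) {
        int[] coins = {1, 2, 5, 10, 20, 50, 100, 200};
        // ways[j] = number of ways to make j pence with the coins processed so far
        int[] ways = new int[TOTAL + 1];
        ways[0] = 1; // exactly one way to make 0p: use no coins

        for (int coin : coins) {
            // ascending j lets a coin be reused: ways[j - coin] already includes it
            for (int j = coin; j <= TOTAL; j++) {
                ways[j] += ways[j - coin];
            }
        }

        System.out.println("Result: " + ways[TOTAL]);
    }
}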

This is much more efficient: the body of the inner loop executes only 1220 times.
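Written as a recurrence (the notation here is mine, not the original poster's): let W_k(j) be the number of ways to make j pence from the first k coin values c_1, ..., c_k. The Java loop computes the recurrence below in place, and because j ascends, ways[j - coin] already holds the updated W_k value. The 1220 figure follows from the inner loop running 201 - c times for coin c.

```latex
W_0(0) = 1, \qquad W_0(j) = 0 \quad (0 < j \le 200)

W_k(j) = \begin{cases}
  W_{k-1}(j) & j < c_k \\
  W_{k-1}(j) + W_k(j - c_k) & j \ge c_k
\end{cases}

\sum_{k=1}^{8} (201 - c_k) = 8 \cdot 201 - (1 + 2 + 5 + 10 + 20 + 50 + 100 + 200) = 1608 - 388 = 1220
```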

While I could obviously translate this more or less verbatim into Scala using Array objects, is there an idiomatic functional way to do this using immutable data structures, preferably with similar conciseness and performance?

I tried recursively updating a List and got stuck, before deciding I'm probably just approaching it the wrong way.

Answer 1:

Whenever some part of a list of data is computed based on a previous element, I think of Stream recursion. Unfortunately, such recursion cannot happen inside method definitions or functions, so I had to turn a function into a class to make it work.

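class IterationForCoin(stream: Stream[Int], coin: Int) {
  // Amounts below `coin` are unaffected by this coin
  val (lower, higher) = stream splitAt coin
  // next(j) = stream(j) + next(j - coin) for j >= coin; `next` refers to itself lazily
  val next: Stream[Int] = lower #::: (higher zip next map { case (a, b) => a + b })
}
val coins = List(1, 2, 5, 10, 20, 50, 100, 200)
// Start with 1 way to make 0p and 0 ways for 1p..200p, then fold in each coin
val result = coins.foldLeft(1 #:: Stream.fill(200)(0)) { (stream, coin) =>
  new IterationForCoin(stream, coin).next
}.last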

The definitions of lower and higher are not strictly necessary; I could easily replace them with stream take coin and stream drop coin, but I think it's a little clearer (and more efficient) this way.
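Stream is deprecated as of Scala 2.13 in favour of LazyList. Assuming Scala 2.13 or later, a straightforward port keeps the same self-referential structure; this is a sketch, and the names LazyListCoins and countWays are hypothetical.

```scala
object LazyListCoins {
  class IterationForCoin(ways: LazyList[Int], coin: Int) {
    val (lower, higher) = ways.splitAt(coin)
    // next(j) = ways(j) for j < coin, else ways(j) + next(j - coin).
    // The tail refers back to `next`; #::: takes it by name, so it is only
    // evaluated once the elements it depends on already exist.
    val next: LazyList[Int] =
      lower #::: higher.zip(next).map { case (a, b) => a + b }
  }

  def countWays(coins: List[Int], target: Int): Int =
    coins.foldLeft(1 #:: LazyList.fill(target)(0)) { (ways, coin) =>
      new IterationForCoin(ways, coin).next
    }.last

  def main(args: Array[String]): Unit =
    println(countWays(List(1, 2, 5, 10, 20, 50, 100, 200), 200))
}
```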



Answer 2:

I don't know enough about Scala to comment on it specifically, but the typical way to translate a DP solution into a recursive one is to use memoization (http://en.wikipedia.org/wiki/Memoization). This basically means caching the result of your function for each argument it is called with, so every subproblem is solved only once.

I also found this: http://michid.wordpress.com/2009/02/23/function_mem/ HTH
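Since the question asks for immutable data structures, one way to memoize without any mutable cache is to thread an immutable Map through the recursion and return it alongside each result. This is a minimal sketch; the names ThreadedMemo, count and makeChange are hypothetical, and the recurrence does not require the coins to be sorted.

```scala
object ThreadedMemo {
  // Cache key: (index of first usable coin, remaining amount)
  type Memo = Map[(Int, Int), Long]

  // Ways to make n from coins(i), coins(i + 1), ...; returns the count
  // together with the (possibly extended) immutable cache.
  def count(coins: Vector[Int], i: Int, n: Int, memo: Memo): (Long, Memo) =
    if (n == 0) (1L, memo)
    else if (n < 0 || i == coins.length) (0L, memo)
    else memo.get((i, n)) match {
      case Some(v) => (v, memo)
      case None =>
        val (withCoin, m1) = count(coins, i, n - coins(i), memo) // use coins(i) again
        val (withoutCoin, m2) = count(coins, i + 1, n, m1)       // skip coins(i)
        val total = withCoin + withoutCoin
        (total, m2.updated((i, n), total))
    }

  def makeChange(coins: Vector[Int], target: Int): Long =
    count(coins, 0, target, Map.empty)._1
}
```

Threading the cache by hand is more verbose than a mutable memo table, but it keeps every function pure.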



Answer 3:

Functional dynamic programming can actually be really beautiful in a lazy language, such as Haskell (there's an article on it on the Haskell wiki). This is a dynamic programming solution to the problem:

import Data.Array

makeChange :: [Int] -> Int -> Int
makeChange coinsList target = arr ! (0,target)
  where numCoins = length coinsList
        coins    = listArray (0,numCoins-1) coinsList
        bounds   = ((0,0),(numCoins,target))
        arr      = listArray bounds . map (uncurry go) $ range bounds
        go i n   | i == numCoins = 0
                 | otherwise     = let c = coins ! i
                                   in case c `compare` n of
                                        GT -> 0
                                        EQ -> 1
                                        LT -> (arr ! (i, n-c)) + (arr ! (i+1,n))

main :: IO ()
main = putStrLn $  "Project Euler Problem 31: "
                ++ show (makeChange [1, 2, 5, 10, 20, 50, 100, 200] 200)

Admittedly, this uses O(cn) memory, where c is the number of coins and n is the target (as opposed to the Java version's O(n) memory); to get that down, you'd have to use some technique for capturing mutable state (probably an STArray). However, both run in O(cn) time. The idea is to encode the recursive solution almost directly, but instead of recursing within go, we look up the answer in the array. And how do we construct the array? By calling go on every index. Since Haskell is lazy, it only computes things when asked to, so the order-of-evaluation bookkeeping that dynamic programming needs is all handled transparently.
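To see what laziness is doing for you, here is a strict sketch of the same (coin index, amount) table in which the evaluation order is chosen by hand: rows from the last coin upward, and each row from left to right, so every entry's dependencies already exist when it is computed. Like the Haskell version it keeps all rows (O(cn) memory) and assumes the coins are sorted in ascending order; the name BottomUpTable is hypothetical.

```scala
object BottomUpTable {
  def makeChange(coins: Vector[Int], target: Int): Long = {
    val numCoins = coins.length
    // Row numCoins: no coins left, so every entry is 0 (go's base case)
    val lastRow = Vector.fill(target + 1)(0L)
    // Row i counts ways using coins(i), coins(i + 1), ...
    val rows = (numCoins - 1 to 0 by -1).scanLeft(lastRow) { (below, i) =>
      val c = coins(i)
      (0 to target).foldLeft(Vector.empty[Long]) { (row, n) =>
        val v =
          if (c > n) 0L              // GT case
          else if (c == n) 1L        // EQ case
          else row(n - c) + below(n) // LT case: arr(i, n - c) + arr(i + 1, n)
        row :+ v
      }
    }
    val firstRow = rows.last // row 0: all coins available
    firstRow(target)
  }
}
```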

And thanks to Scala's by-name parameters and lazy vals, we can mimic this solution in Scala:

class Lazy[A](x: => A) {
  lazy val value = x
}

object Lazy {
  def apply[A](x: => A) = new Lazy(x)
  implicit def fromLazy[A](z: Lazy[A]): A = z.value
  implicit def toLazy[A](x: => A): Lazy[A] = Lazy(x)
}

import Lazy._

def makeChange(coins: Array[Int], target: Int): Int = {
  val numCoins = coins.length
  lazy val arr: Array[Array[Lazy[Int]]]
    = Array.tabulate(numCoins+1,target+1) { (i,n) =>
        if (i == numCoins) {
          0
        } else {
          val c = coins(i)
          if (c > n)
            0
          else if (c == n)
            1
          else
            arr(i)(n-c) + arr(i+1)(n)
        }
      }
  arr(0)(target)
}

// makeChange(Array(1, 2, 5, 10, 20, 50, 100, 200), 200)

The Lazy class encodes values that are only evaluated on demand, and then we build an array full of them. Both of these solutions run practically instantly for a target value of 10000, although the count at that size already exceeds the range of Scala's 32-bit Int, so the Scala version would need Long or BigInt there. Go much larger and you'll run into either integer overflow or (in Scala, at least) a stack overflow.
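For large targets, one way to avoid both problems is a strict fold over the coins (the same recurrence as the Java loop) using immutable Vectors and BigInt: there is no recursion to overflow the stack, and BigInt cannot overflow. This is a sketch under those assumptions; the name BigIntChange is hypothetical.

```scala
object BigIntChange {
  def ways(coins: Seq[Int], target: Int): BigInt = {
    // With no coins, only 0p can be made (in exactly one way)
    val start: Vector[BigInt] = BigInt(1) +: Vector.fill(target)(BigInt(0))
    val table = coins.foldLeft(start) { (prev, c) =>
      // Build this coin's row left to right; acc(j - c) is already the
      // updated value, mirroring ways[j] += ways[j - coin] in the Java code.
      (0 to target).foldLeft(Vector.empty[BigInt]) { (acc, j) =>
        acc :+ (if (j >= c) prev(j) + acc(j - c) else prev(j))
      }
    }
    table(target)
  }

  def main(args: Array[String]): Unit =
    println(ways(Seq(1, 2, 5, 10, 20, 50, 100, 200), 10000))
}
```

Only two rows of size n + 1 are live at any time, so memory stays O(n) while the data structures remain immutable.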



Answer 4:

OK, here's the memoized version of Pavel Fatin's code. I'm using the Scalaz memoization utilities, though it's really simple to write your own memoization class.

import scalaz._
import Scalaz._

val memo = immutableHashMapMemo[(List[Int], Int), Int]
def f(ms: List[Int], n: Int): Int = ms match {
  case h :: t =>
    if (h > n) 0 else if (n == h) 1 else memo((f _).tupled)(ms, n - h) + memo((f _).tupled)(t, n)
  case _ => 0
} 
val r = f(List(1, 2, 5, 10, 20, 50, 100, 200), 200)
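As noted above, a memoization class is easy to write by hand. The following is a minimal sketch, not the Scalaz API: the names SimpleMemo and MemoizedCoins are hypothetical, and the cache is a mutable HashMap, so it is not thread-safe. Rather than wrapping each call site, it memoizes at the entry of f, using an explicit get/update instead of getOrElseUpdate so that recursive calls made while computing a value never interfere with the insertion.

```scala
import scala.collection.mutable

final class SimpleMemo[K, V] {
  private val cache = mutable.HashMap.empty[K, V]

  // Return the cached value for k, or evaluate `compute` once and store it
  def apply(k: K)(compute: => V): V = cache.get(k) match {
    case Some(v) => v
    case None =>
      val v = compute // may recurse back into this memo
      cache.update(k, v)
      v
  }
}

object MemoizedCoins {
  private val memo = new SimpleMemo[(List[Int], Int), Int]

  // Pavel Fatin's recursion, with every call routed through the cache
  def f(ms: List[Int], n: Int): Int = memo((ms, n)) {
    ms match {
      case h :: t =>
        if (h > n) 0 else if (n == h) 1 else f(ms, n - h) + f(t, n)
      case _ => 0
    }
  }
}
```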


Answer 5:

For the sake of completeness, here is a slight variant of the Stream-based answer above (Answer 1) that doesn't use Stream:

object coins {
  val coins = List(1, 2, 5, 10, 20, 50, 100, 200)
  val total = 200
  val result = coins.foldLeft(1 :: List.fill(total)(0)) { (list, coin) =>
    new IterationForCoin(list, coin).next(total)
  }.last
}

class IterationForCoin(list: List[Int], coin: Int) {
  val (lower, higher) = list splitAt coin
  def next (total: Int): List[Int] = {
    val listPart = if (total>coin) next(total-coin) else lower
    lower ::: (higher zip listPart map { case (a, b) => a + b })
  }
}