Efficient table for Dynamic Programming in Haskell

Posted 2019-03-08 16:14

Question:

I've coded up the 0-1 Knapsack problem in Haskell. I'm fairly proud of the laziness and level of generality achieved so far.

I start by providing functions for creating and dealing with a lazy 2d matrix.

mkList f = map f [0..]
mkTable f = mkList (\i -> mkList (\j -> f i j))

tableIndex table i j = table !! i !! j

I then make a specific table for a given knapsack problem

knapsackTable = mkTable f
    where f 0 _ = 0
          f _ 0 = 0
          f i j | ws!!i > j = leaveI
                | otherwise = max takeI leaveI
              where takeI  = tableIndex knapsackTable (i-1) (j-(ws!!i)) + vs!!i
                    leaveI = tableIndex knapsackTable (i-1) j

-- weight value pairs; item i has weight ws!!i and value vs!!i
ws  = [0,1,2, 5, 6, 7] -- weights
vs  = [0,1,7,11,21,31] -- values

And finish off with a couple helper functions for looking at the table

viewTable table maxI maxJ = take (maxI+1) . map (take (maxJ+1)) $ table
printTable table maxI maxJ = mapM_ print $ viewTable table maxI maxJ

This much was pretty easy. But I want to take it a step further.

I want a better data structure for the table. Ideally, it should be

  • Unboxed (immutable) [edit] never mind this
  • Lazy
  • Unbounded
  • O(1) time to construct
  • O(1) time complexity for looking up a given entry,
    (more realistically, at worst O(log n), where n is i*j for looking up the entry at row i, column j)

Bonus points if you can explain why/how your solution satisfies these ideals.

Also bonus points if you can further generalize knapsackTable, and prove that it is efficient.

In improving the data structure you should try to satisfy the following goals:

  • If I ask for the solution where the maximum weight is 10 (in my current code, that would be tableIndex knapsackTable 5 10, where the 5 means include items 1-5), only the minimal amount of work necessary should be performed. Ideally this means no O(i*j) work for forcing the spine of each row of the table to the necessary column length. You could say this isn't "true" DP, if you believe DP means evaluating the entirety of the table.
  • If I ask for the entire table to be printed (something like printTable knapsackTable 5 10), the values of each entry should be computed once and only once. The values of a given cell should depend on the values of other cells (DP style: the idea being, never recompute the same subproblem twice)

Ideas:

  • Data.Array bounded :(
  • UArray strict :(
  • Memoization techniques (SO question about DP in Haskell) this might work

Answers that make some compromises to my stated ideals will be upvoted (by me, anyways) as long as they are informative. The answer with the least compromises will probably be the "accepted" one.

Answer 1:

First, your criterion for an unboxed data structure is probably a bit misguided. Unboxed values must be strict, and they have nothing to do with immutability. The solution I'm going to propose is immutable, lazy, and boxed. Also, I'm not sure in what sense you want construction and querying to be O(1). The structure I'm proposing is lazily constructed, but because it's potentially unbounded, its full construction would take infinite time. Querying the structure will take O(k) time for any particular key of size k, but of course the value you're looking up may take further time to compute.

The data structure is a lazy trie. I'm using Conal Elliott's MemoTrie library in my code. For genericity, it takes functions instead of lists for the weights and values.

import Data.MemoTrie (HasTrie, memo2)

knapsack :: (Enum a, Eq a, Num w, Num v, Num a, Ord w, Ord v, HasTrie a, HasTrie w) =>
            (a -> w) -> (a -> v) -> a -> w -> v
knapsack weight value = knapsackMem
  where knapsackMem = memo2 knapsack'
        knapsack' 0 _ = 0
        knapsack' _ 0 = 0
        knapsack' i w
          | weight i > w = knapsackMem (pred i) w
          -- the item's value is added only to the "take" branch
          | otherwise = max (knapsackMem (pred i) w)
                            (knapsackMem (pred i) (w - weight i) + value i)

Basically, it's implemented as a trie with a lazy spine and lazy values. It's bounded only by the key type. Because the entire thing is lazy, its construction before forcing it with queries is O(1). Each query forces a single path down the trie and its value, so a lookup is O(1) for a bounded key size, or O(log n) in the magnitude n of the key. As I already said, it's immutable, but not unboxed.

It will share all work in the recursive calls. It doesn't actually allow you to print the trie directly, but something like this should not do any redundant work:

mapM_ (print . uncurry (knapsack ws vs)) $ range ((0,0), (i,w))
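
As a way to check results, here is a minimal sketch of the same recurrence without any memoization, using only base. The names knapsackRef, weightOf, valueOf and entries are hypothetical helpers introduced for illustration, and the data is the question's ws and vs exposed as functions. Note the parentheses around the "take" branch: the item's value is added only there, not to the result of max.

```haskell
import Data.Ix (range)

-- Plain (un-memoised) form of the 0-1 knapsack recurrence.
-- Exponential in the worst case; intended only as a reference for small inputs.
knapsackRef :: (Int -> Int) -> (Int -> Int) -> Int -> Int -> Int
knapsackRef weight value = go
  where
    go 0 _ = 0
    go _ 0 = 0
    go i w
      | weight i > w = go (i - 1) w
      | otherwise    = max (go (i - 1) w)
                           (go (i - 1) (w - weight i) + value i)

-- The question's data, exposed as functions (hypothetical helper names).
weightOf, valueOf :: Int -> Int
weightOf i = [0, 1, 2, 5, 6, 7] !! i
valueOf  i = [0, 1, 7, 11, 21, 31] !! i

-- All entries for items 0..i and capacities 0..w, in row-major order,
-- enumerated with range just like the printing loop above.
entries :: Int -> Int -> [Int]
entries i w = map (uncurry (knapsackRef weightOf valueOf)) (range ((0, 0), (i, w)))
```

With this data, knapsackRef weightOf valueOf 5 10 evaluates to 39 (items of weight 1, 2 and 7), which is the value a memoised version should also produce.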


Answer 2:

Unboxed implies strict and bounded. Anything 100% Unboxed cannot be Lazy or Unbounded. The usual compromise is embodied in converting [Word8] to Data.ByteString.Lazy where there are unboxed chunks (strict ByteString) which are linked lazily together in an unbounded way.

A much more efficient table generator (enhanced to track individual items) could be made using "scanl", "zipWith", and my "takeOnto". This effectively avoids using (!!) while creating the table:

import Data.List(sort,genericTake)

type Table = [ [ Entry ] ]

data Entry = Entry { bestValue :: !Integer, pieces :: [[WV]] }
  deriving (Read,Show)

data WV = WV { weight, value :: !Integer }
  deriving (Read,Show,Eq,Ord)

instance Eq Entry where
  (==) a b = (==) (bestValue a) (bestValue b)

instance Ord Entry where
  compare a b = compare (bestValue a) (bestValue b)

solutions :: Entry -> Int
solutions = length . filter (not . null) . pieces

addItem :: Entry -> WV -> Entry
addItem e wv = Entry { bestValue = bestValue e + value wv, pieces = map (wv:) (pieces e) }

-- Utility function for improve
takeOnto :: ([a] -> [a]) -> Integer -> [a] -> [a]
takeOnto endF = go where
  go n rest | n <=0 = endF rest
            | otherwise = case rest of
                            (x:xs) -> x : go (pred n) xs
                            [] -> error "takeOnto: unexpected []"

improve oldList wv@(WV {weight=wi,value = vi}) = newList where
  newList | vi <=0 = oldList
          | otherwise = takeOnto (zipWith maxAB oldList) wi oldList
  -- Dual traversal of index (w-wi) and index w makes this a zipWith
  maxAB e2 e1 = let e2v = addItem e2 wv
                in case compare e1 e2v of
                     LT -> e2v
                     EQ -> Entry { bestValue = bestValue e1
                                 , pieces = pieces e1 ++ pieces e2v }
                     GT -> e1

-- Note that the returned table is finite
-- The dependence on only the previous row makes this a "scanl" operation
makeTable :: [Int] -> [Int] -> Table
makeTable ws vs =
  let wvs = zipWith WV (map toInteger ws) (map toInteger vs)
      nil = repeat (Entry { bestValue = 0, pieces = [[]] })
      totW = sum (map weight wvs)
  in map (genericTake (succ totW)) $ scanl improve nil wvs

-- Create specific table, note that values (1+7) at weights (2+3) tie with the single item of value 8 at weight 5
ws, vs :: [Int]
ws  = [2,3, 5, 5, 6, 7] -- weights
vs  = [1,7,8,11,21,31] -- values

t = makeTable ws vs

-- Investigate table

seeTable = mapM_ seeBestValue t
  where seeBestValue row = mapM_ (\v -> putStr (' ':(show (bestValue v)))) row >> putChar '\n'

ways = mapM_ seeWays t
  where seeWays row = mapM_ (\v -> putStr (' ':(show (solutions v)))) row >> putChar '\n'

-- This has two ways of satisfying a bestValue of 8 for 3 items up to total weight 5
interesting = print (t !! 3 !! 5) 
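
Stripped of the item tracking, the scanl/zipWith idea can be sketched in a few lines. This is a simplified illustration rather than the code above: nextRow, rows and best are hypothetical names, items are plain (weight, value) pairs, and weights are assumed to be non-negative.

```haskell
-- One DP row per item: (row !! w) is the best value using the items seen
-- so far with capacity w. Rows are infinite lazy lists, and no (!!) is
-- needed while building them.
nextRow :: [Int] -> (Int, Int) -> [Int]
nextRow old (wi, vi) = take wi old ++ zipWith max (drop wi old) (map (+ vi) old)
  -- Capacities below wi cannot hold the item, so they copy the old row.
  -- From wi onwards, compare "leave" (old at w) with "take" (old at w - wi, plus vi).

-- Each row depends only on the previous one, hence scanl.
rows :: [(Int, Int)] -> [[Int]]
rows = scanl nextRow (repeat 0)

-- Best value for all the items at the given capacity.
best :: [(Int, Int)] -> Int -> Int
best items cap = last (rows items) !! cap
```

For the first three items of the table above, (rows [(2,1),(3,7),(5,8)] !! 3) !! 5 is 8, matching the bestValue that the "interesting" entry reports.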


Answer 3:

Lazy storable vectors: http://hackage.haskell.org/package/storablevector

Unbounded, lazy, O(chunksize) time to construct, O(n/chunksize) indexing, where chunksize can be sufficiently large for any given purpose. Basically a lazy list with some significant constant factor benefits.
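
To illustrate the trade-off, here is a rough sketch of the same shape built from a lazy list of unboxed arrays. It does not use the storablevector package; it only assumes the array package that ships with GHC, and chunkSize, chunked, chunkedIndex and squares are hypothetical names.

```haskell
import Data.Array.Unboxed (UArray, listArray, (!))

-- Hypothetical chunk size; larger chunks mean fewer list hops per lookup.
chunkSize :: Int
chunkSize = 4

-- An unbounded lazy list of strict, unboxed chunks.
-- Chunk k holds f applied to indices k*chunkSize .. (k+1)*chunkSize - 1.
-- Forcing a chunk computes all of its entries at once.
chunked :: (Int -> Int) -> [UArray Int Int]
chunked f =
  [ listArray (0, chunkSize - 1)
              [ f (k * chunkSize + o) | o <- [0 .. chunkSize - 1] ]
  | k <- [0 ..] ]

-- Walk the list to the right chunk (O(n / chunkSize)), then index it in O(1).
chunkedIndex :: [UArray Int Int] -> Int -> Int
chunkedIndex chunks n = (chunks !! (n `div` chunkSize)) ! (n `mod` chunkSize)

-- Example table: the squares of the natural numbers.
squares :: [UArray Int Int]
squares = chunked (\n -> n * n)
```

Because each chunk is strict, an entry cannot depend lazily on another entry of the same chunk, which is the compromise this layout makes compared with a fully lazy table.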



Answer 4:

To memoize functions, I recommend a library like Luke Palmer's memo combinators. The library uses tries, which are unbounded and have O(key size) lookup. (In general, you can't do better than O(key size) lookup because you always have to touch every bit of the key.)

knapsack :: (Int,Int) -> Solution
knapsack = memo f
    where
    memo    = pair integral integral
    f (i,j) = ... knapsack (i-b,j) ...


Internally, the integral combinator probably builds an infinite data structure

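data IntTrie a = Branch (IntTrie a) a (IntTrie a)

integral f = \n -> lookup n table
     where
     table = build f
     -- the left subtrie holds the even keys, the right subtrie the odd keys
     build g = Branch (build (\n -> g (2*n))) (g 0) (build (\n -> g (2*n+1)))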

Lookup works like this:

lookup 0 (Branch l a r) = a
lookup n (Branch l a r) = if even n then lookup n2 l else lookup n2 r
     where n2 = n `div` 2

There are other ways to build infinite tries, but this one is popular.
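
A self-contained variant of this trie idea, applied to the question's knapsack data, might look like the sketch below. It is not the memo combinators library's actual implementation; NatTrie, tabulate, index, memoNat, memoNat2 and knap are hypothetical names, and keys are assumed to be non-negative.

```haskell
-- Infinite binary trie over the natural numbers.
data NatTrie a = Node (NatTrie a) a (NatTrie a)

-- Tabulate a function: the root holds key 0, the left subtrie the even
-- keys and the right subtrie the odd keys. Nothing is computed until forced.
tabulate :: (Int -> a) -> NatTrie a
tabulate f = Node (tabulate (\n -> f (2 * n))) (f 0) (tabulate (\n -> f (2 * n + 1)))

-- Follow the binary digits of the key, least significant first.
-- Takes O(log n) steps; keys must be >= 0.
index :: NatTrie a -> Int -> a
index (Node l a r) n
  | n == 0    = a
  | even n    = index l (n `div` 2)
  | otherwise = index r (n `div` 2)

memoNat :: (Int -> a) -> Int -> a
memoNat f = index (tabulate f)

-- Two-argument version: a trie of tries.
memoNat2 :: (Int -> Int -> a) -> Int -> Int -> a
memoNat2 f = memoNat (\i -> memoNat (f i))

ws, vs :: [Int]
ws = [0, 1, 2, 5, 6, 7]    -- weights from the question
vs = [0, 1, 7, 11, 21, 31] -- values from the question

-- Recursive calls go back through the memoised knap, so subproblems are shared.
knap :: Int -> Int -> Int
knap = memoNat2 go
  where
    go 0 _ = 0
    go _ 0 = 0
    go i j
      | wi > j    = leave
      | otherwise = max leave (knap (i - 1) (j - wi) + vs !! i)
      where wi    = ws !! i
            leave = knap (i - 1) j
```

With this data, knap 5 10 evaluates to 39, and only the cells that query actually reaches are ever forced.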



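Answer 5:

Why not use a Data.Map with another Data.Map nested inside it? As far as I know it's quite fast. It wouldn't be lazy in its structure, though: Data.Map is strict in its keys, even though the default (lazy) variant leaves the values unevaluated.

More than that, you can implement the Ord typeclass for your data type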

data Index = Index Int Int

and put a two dimensional index directly as a key. You can achieve laziness by generating this map as a list and then just use

fromList [(Index 0 0, value11), (Index 0 1, value12), ...]
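
A minimal sketch of that idea, assuming the containers package that ships with GHC: the keys are fixed up front, so the table is bounded, but the values stay as unevaluated thunks that refer back to the map, so each cell is computed at most once and only if something needs it. knapsackMap and cell are hypothetical names, and the data is taken from the question.

```haskell
import qualified Data.Map as Map

-- Two-dimensional index used directly as the map key.
data Index = Index Int Int deriving (Eq, Ord, Show)

ws, vs :: [Int]
ws = [0, 1, 2, 5, 6, 7]    -- weights from the question
vs = [0, 1, 7, 11, 21, 31] -- values from the question

-- Table for items 0..maxI and capacities 0..maxJ.
-- Data.Map (the lazy variant) is strict in keys but lazy in values,
-- so the cells may refer to the table they are stored in.
knapsackMap :: Int -> Int -> Map.Map Index Int
knapsackMap maxI maxJ = table
  where
    table = Map.fromList
      [ (Index i j, cell i j) | i <- [0 .. maxI], j <- [0 .. maxJ] ]
    cell 0 _ = 0
    cell _ 0 = 0
    cell i j
      | wi > j    = leave
      | otherwise = max leave (table Map.! Index (i - 1) (j - wi) + vs !! i)
      where wi    = ws !! i
            leave = table Map.! Index (i - 1) j
```

Looking up Index 5 10 in knapsackMap 5 10 gives 39. Building the map still costs O(i*j log(i*j)) for the keys, which is the compromise against the "minimal work" goal.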