Reducing space usage of depth first tree traversal

Posted 2019-05-11 18:24

Question:

In Haskell, one can filter, sum, and so on over infinite lists in constant space, because Haskell produces list nodes only when they are needed and garbage-collects the ones it has finished with.
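To make the list case concrete, here is a minimal sketch of a filter plus a strict sum over an infinite list. The name sumMultiplesOf5 and the bound are made up for illustration. With a strict fold like foldl', I'd expect memory use to stay small and flat, since each list cell becomes garbage as soon as the fold has moved past it.

```haskell
import Data.List (foldl')

-- Sum the multiples of 5 among 1..limit, drawing from the infinite list [1 ..].
-- (sumMultiplesOf5 is a hypothetical name used only for this illustration.)
sumMultiplesOf5 :: Int -> Int
sumMultiplesOf5 limit =
  foldl' (+) 0                      -- strict accumulator, no thunk build-up
    . filter (\k -> k `rem` 5 == 0) -- keep multiples of 5
    . takeWhile (<= limit)          -- cut the infinite list off lazily
    $ [1 ..]

main :: IO ()
main = print (sumMultiplesOf5 1000000)
```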

I'd like this to work with infinite trees.

Below is a rather silly program that generates an infinite binary tree with nodes representing the natural numbers.

I've then written a function that does a depth-first traversal of this tree, spitting out the nodes at a particular level.

Then I've done a quick sum of the nodes divisible by 5.

In theory, this algorithm could be implemented in O(n) space for a tree of depth n with O(2^n) nodes: just generate the tree on the fly, discarding the nodes you have finished processing.

Haskell does generate the tree on the fly, but it doesn't seem to garbage-collect the nodes.

Below is the code. I'd like to see code with a similar effect that doesn't require O(2^n) space.

import Data.List (foldl')

data Tree = Tree Int Tree Tree

tree n = Tree n (tree (2 * n)) (tree (2 * n + 1))
treeOne = tree 1

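-- Collects the values of all nodes at depth n, in left-to-right order.
depthNTree :: Int -> Tree -> [Int]
depthNTree n x = go n x id [] where
  go :: Int -> Tree -> ([Int] -> [Int]) -> [Int] -> [Int]
  go 0 (Tree x _ _) acc rest = acc (x:rest)
  go n (Tree _ left right) acc rest = t2 rest where
    t1 = go (n - 1) left acc   -- closure that refers to the left subtree
    t2 = go (n - 1) right t1   -- the right traversal carries t1 along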

main = do
  x <- getLine
  print . foldl' (+) 0 . filter (\x -> x `rem` 5 == 0) $ depthNTree (read x) treeOne

Answer 1:

Your depthNTree uses O(2^n) space because you keep the left subtree around through t1 while you're traversing the right subtree. The recursive call on the right subtree must hold no reference to the left one; that is a necessary condition for a traversal whose nodes can be garbage-collected incrementally.
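One way to make that condition concrete is to skip the intermediate list and fold over the level directly. This is only a sketch, not the original code: foldDepth and sumDivisibleBy5 are names invented here. The left subtree is finished before the right call starts, and the right call receives only the accumulated value. I would expect this to need stack proportional to the depth, though I haven't measured it.

```haskell
{-# LANGUAGE BangPatterns #-}

data Tree = Tree Int Tree Tree

tree :: Int -> Tree
tree n = Tree n (tree (2 * n)) (tree (2 * n + 1))

-- Strict left fold over the nodes at depth d, visited left to right.
-- (foldDepth is a hypothetical helper, not part of the original code.)
foldDepth :: (a -> Int -> a) -> a -> Int -> Tree -> a
foldDepth f = go
  where
    go !acc 0 (Tree x _ _) = f acc x
    go !acc d (Tree _ l r) =
      -- Finish the left subtree first; only its result is passed on,
      -- so the right call holds no reference to l.
      let !acc' = go acc (d - 1) l
      in go acc' (d - 1) r

-- Sum of the nodes at depth d that are divisible by 5.
sumDivisibleBy5 :: Int -> Int
sumDivisibleBy5 d = foldDepth step 0 d (tree 1)
  where
    step s x
      | x `rem` 5 == 0 = s + x
      | otherwise      = s

main :: IO ()
main = do
  line <- getLine
  print (sumDivisibleBy5 (read line))
```

The trade-off is that the result is no longer a lazy list that other list functions can consume, which is what the versions below preserve.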

The naive version works acceptably in this example:

depthNTree n t = go n t where
  go 0 (Tree x _ _) = [x]
  go n (Tree _ l r) = go (n - 1) l ++ go (n - 1) r

Now main with input 24 uses 2 MB of space, whereas the original version used 1820 MB. The optimal solution here is similar to the one above, except that it uses difference lists:

depthNTree n t = go n t [] where
  go 0 (Tree x _ _) = (x:)
  go n (Tree _ l r) = go (n - 1) l . go (n - 1) r

This isn't much faster than the plain list version in many cases, because with tree depths around 20-30 the left nesting of ++ isn't very costly. The difference becomes more pronounced if we use large tree depths:

print $ sum $ take 10 $ depthNTree 1000000 treeOne

On my computer, this runs in 0.25 secs with difference lists and 1.6 secs with lists.
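For anyone who wants to reproduce the comparison, here is a sketch that puts both variants in one file and picks one from the command line. The names depthList and depthDList, the argument convention, the default depth, and the file name Main.hs are assumptions made for this example. The timings and memory figures will of course differ from machine to machine.

```haskell
-- Build:  ghc -O2 -rtsopts Main.hs        (Main.hs is a hypothetical file name)
-- Run:    ./Main dlist 1000000 +RTS -s    (-s prints time and memory statistics)
import System.Environment (getArgs)

data Tree = Tree Int Tree Tree

tree :: Int -> Tree
tree n = Tree n (tree (2 * n)) (tree (2 * n + 1))

-- Plain-list variant: concatenates the results of the two subtrees.
depthList :: Int -> Tree -> [Int]
depthList 0 (Tree x _ _) = [x]
depthList n (Tree _ l r) = depthList (n - 1) l ++ depthList (n - 1) r

-- Difference-list variant: composes "prepend" functions instead of using ++.
depthDList :: Int -> Tree -> [Int]
depthDList n t = go n t []
  where
    go :: Int -> Tree -> [Int] -> [Int]
    go 0 (Tree x _ _) = (x :)
    go k (Tree _ l r) = go (k - 1) l . go (k - 1) r

main :: IO ()
main = do
  args <- getArgs
  let (variant, depth) = case args of
        ["list", d]  -> (depthList, read d)
        ["dlist", d] -> (depthDList, read d)
        _            -> (depthDList, 20) -- arbitrary default for this sketch
  print . sum . take 10 $ variant depth (tree 1)
```

The RTS options after +RTS are consumed by the runtime and do not show up in getArgs, so the pattern match only sees the variant name and the depth.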