I'm attempting to structure an AST using the Free monad based on some helpful literature that I've read online.
I have some questions about working with these kinds of ASTs in practice, which I've boiled down to the following example.
Suppose my language allows for the following commands:
{-# LANGUAGE DeriveFunctor #-}
data Command next
  = DisplayChar Char next
  | DisplayString String next
  | Repeat Int (Free Command ()) next
  | Done
  deriving (Eq, Show, Functor)
and I define the Free monad boilerplate manually:
displayChar :: Char -> Free Command ()
displayChar ch = liftF (DisplayChar ch ())
displayString :: String -> Free Command ()
displayString str = liftF (DisplayString str ())
repeat :: Int -> Free Command () -> Free Command ()
repeat times block = liftF (Repeat times block ())
done :: Free Command r
done = liftF Done
which allows me to specify programs like the following:
prog :: Free Command r
prog =
  do displayChar 'A'
     displayString "abc"
     repeat 5 $
       displayChar 'Z'
     displayChar '\n'
     done
Now, I'd like to execute my program, which seems simple enough.
execute :: Free Command r -> IO ()
execute (Free (DisplayChar ch next)) = putChar ch >> execute next
execute (Free (DisplayString str next)) = putStr str >> execute next
execute (Free (Repeat n block next)) = forM_ [1 .. n] (\_ -> execute block) >> execute next
execute (Free Done) = return ()
execute (Pure r) = return ()
and
λ> execute prog
AabcZZZZZ
Okay. That's all nice, but now I want to learn things about my AST and perform transformations on it, much like optimizations in a compiler.
Here's a simple one: if a Repeat block only contains DisplayChar commands, then I'd like to replace the whole thing with an appropriate DisplayString. In other words, I'd like to transform repeat 2 (displayChar 'A' >> displayChar 'B') into displayString "ABAB".
Here's my attempt:
optimize (Free (Repeat n block next)) =
  if all isJust charsToDisplay then
    let chars = catMaybes charsToDisplay
    in
      displayString (concat $ replicate n chars) >> optimize next
  else
    -- rebuild just the loop; its continuation is optimized separately
    repeat n block >> optimize next
  where
    charsToDisplay = project getDisplayChar block
optimize (Free (DisplayChar ch next)) = displayChar ch >> optimize next
optimize (Free (DisplayString str next)) = displayString str >> optimize next
optimize (Free Done) = done
optimize c@(Pure r) = c
getDisplayChar (Free (DisplayChar ch _)) = Just ch
getDisplayChar _ = Nothing
project :: (Free Command a -> Maybe u) -> Free Command a -> [Maybe u]
project f = maybes
  where
    maybes (Pure a) = []
    maybes c@(Free cmd) =
      let build next = f c : maybes next
      in
        case cmd of
          DisplayChar _ next -> build next
          DisplayString _ next -> build next
          Repeat _ _ next -> build next
          Done -> []
Observing the AST in GHCi shows that this works correctly, and indeed
λ> optimize $ repeat 3 (displayChar 'A' >> displayChar 'B')
Free (DisplayString "ABABAB" (Pure ()))
λ> execute . optimize $ prog
AabcZZZZZ
λ> execute prog
AabcZZZZZ
But I'm not happy. In my opinion, this code is repetitive. I have to define how to traverse my AST every time I want to examine it, or define functions like my project that give me a view into it. I have to do the same thing when I want to modify the tree.
So, my question: is this approach my only option? Can I pattern-match on my AST without dealing with tonnes of nesting? Can I traverse the tree in a consistent and generic way (maybe Zippers, or Traversable, or something else)? What approaches are commonly taken here?
The whole file is below:
{-# LANGUAGE DeriveFunctor #-}
module Main where
import Prelude hiding (repeat)
import Control.Monad.Free
import Control.Monad (forM_)
import Data.Maybe (catMaybes, isJust)
main :: IO ()
main = execute prog
prog :: Free Command r
prog =
  do displayChar 'A'
     displayString "abc"
     repeat 5 $
       displayChar 'Z'
     displayChar '\n'
     done
optimize (Free (Repeat n block next)) =
  if all isJust charsToDisplay then
    let chars = catMaybes charsToDisplay
    in
      displayString (concat $ replicate n chars) >> optimize next
  else
    -- rebuild just the loop; its continuation is optimized separately
    repeat n block >> optimize next
  where
    charsToDisplay = project getDisplayChar block
optimize (Free (DisplayChar ch next)) = displayChar ch >> optimize next
optimize (Free (DisplayString str next)) = displayString str >> optimize next
optimize (Free Done) = done
optimize c@(Pure r) = c
getDisplayChar (Free (DisplayChar ch _)) = Just ch
getDisplayChar _ = Nothing
project :: (Free Command a -> Maybe u) -> Free Command a -> [Maybe u]
project f = maybes
  where
    maybes (Pure a) = []
    maybes c@(Free cmd) =
      let build next = f c : maybes next
      in
        case cmd of
          DisplayChar _ next -> build next
          DisplayString _ next -> build next
          Repeat _ _ next -> build next
          Done -> []
execute :: Free Command r -> IO ()
execute (Free (DisplayChar ch next)) = putChar ch >> execute next
execute (Free (DisplayString str next)) = putStr str >> execute next
execute (Free (Repeat n block next)) = forM_ [1 .. n] (\_ -> execute block) >> execute next
execute (Free Done) = return ()
execute (Pure r) = return ()
data Command next
  = DisplayChar Char next
  | DisplayString String next
  | Repeat Int (Free Command ()) next
  | Done
  deriving (Eq, Show, Functor)
displayChar :: Char -> Free Command ()
displayChar ch = liftF (DisplayChar ch ())
displayString :: String -> Free Command ()
displayString str = liftF (DisplayString str ())
repeat :: Int -> Free Command () -> Free Command ()
repeat times block = liftF (Repeat times block ())
done :: Free Command r
done = liftF Done
Here's my take using syb (as mentioned on Reddit):
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveDataTypeable #-}
module Main where
import Prelude hiding (repeat)
import Data.Data
import Control.Monad (forM_)
import Control.Monad.Free
import Control.Monad.Free.TH
import Data.Generics (everywhere, mkT)
data CommandF next = DisplayChar Char next
                   | DisplayString String next
                   | Repeat Int (Free CommandF ()) next
                   | Done
  deriving (Eq, Show, Functor, Data, Typeable)
makeFree ''CommandF
type Command = Free CommandF
execute :: Command () -> IO ()
execute = iterM handle
  where
    handle = \case
      DisplayChar ch next -> putChar ch >> next
      DisplayString str next -> putStr str >> next
      Repeat n block next -> forM_ [1 .. n] (\_ -> execute block) >> next
      Done -> return ()
optimize :: Command () -> Command ()
optimize = optimize' . optimize'
  where
    optimize' = everywhere (mkT inner)
    inner :: Command () -> Command ()
    -- char + char becomes string
    inner (Free (DisplayChar c1 (Free (DisplayChar c2 next)))) = do
      displayString [c1, c2]
      next
    -- char + string becomes string
    inner (Free (DisplayChar c (Free (DisplayString s next)))) = do
      displayString $ c : s
      next
    -- string + string becomes string
    inner (Free (DisplayString s1 (Free (DisplayString s2 next)))) = do
      displayString $ s1 ++ s2
      next
    -- Loop unrolling
    inner f@(Free (Repeat n block next))
      | n < 5 = forM_ [1 .. n] (\_ -> block) >> next
      | otherwise = f
    inner a = a
prog :: Command ()
prog = do
  displayChar 'a'
  displayChar 'b'
  repeat 1 $ displayChar 'c' >> displayString "def"
  displayChar 'g'
  displayChar 'h'
  repeat 10 $ do
    displayChar 'i'
    displayChar 'j'
    displayString "klm"
  repeat 3 $ displayChar 'n'
main :: IO ()
main = do
  putStrLn "Original program:"
  print prog
  putStrLn "Evaluation of original program:"
  execute prog
  putStrLn "\n"
  let opt = optimize prog
  putStrLn "Optimized program:"
  print opt
  putStrLn "Evaluation of optimized program:"
  execute opt
  putStrLn ""
Output:
$ cabal exec runhaskell ast.hs
Original program:
Free (DisplayChar 'a' (Free (DisplayChar 'b' (Free (Repeat 1 (Free (DisplayChar 'c' (Free (DisplayString "def" (Pure ()))))) (Free (DisplayChar 'g' (Free (DisplayChar 'h' (Free (Repeat 10 (Free (DisplayChar 'i' (Free (DisplayChar 'j' (Free (DisplayString "klm" (Pure ()))))))) (Free (Repeat 3 (Free (DisplayChar 'n' (Pure ()))) (Pure ()))))))))))))))
Evaluation of original program:
abcdefghijklmijklmijklmijklmijklmijklmijklmijklmijklmijklmnnn
Optimized program:
Free (DisplayString "abcdefgh" (Free (Repeat 10 (Free (DisplayString "ijklm" (Pure ()))) (Free (DisplayString "nnn" (Pure ()))))))
Evaluation of optimized program:
abcdefghijklmijklmijklmijklmijklmijklmijklmijklmijklmijklmnnn
It might be possible to get rid of the explicit Free constructors using GHC 7.8's pattern synonyms, but for some reason the above code only works with GHC 7.6: the Data instance of Free seems to be missing. I should look into that...
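The pattern-synonym idea can be sketched roughly as follows. This is not the syb answer's code: it uses plain structural recursion instead of `everywhere`, so no `Data` instance is needed, and to stay self-contained it declares a bare local `Free` type in place of `Control.Monad.Free`. The synonym names `Chr`, `Str` and `Rep` and the helpers `step`, `rewrite` and `render` are hypothetical, and loop unrolling is left out. Because each continuation is optimized before its head is examined, runs of text collapse in a single pass.

```haskell
{-# LANGUAGE PatternSynonyms #-}

-- Bare local copy of the Free data type (assumption: stands in for
-- Control.Monad.Free; only its constructors are needed in this sketch).
data Free f a = Pure a | Free (f (Free f a))

data Command next
  = DisplayChar Char next
  | DisplayString String next
  | Repeat Int (Free Command ()) next
  | Done

-- Hypothetical synonyms that hide the Free wrapper on every level.
pattern Chr :: Char -> Free Command a -> Free Command a
pattern Chr c k = Free (DisplayChar c k)

pattern Str :: String -> Free Command a -> Free Command a
pattern Str s k = Free (DisplayString s k)

pattern Rep :: Int -> Free Command () -> Free Command a -> Free Command a
pattern Rep n b k = Free (Repeat n b k)

-- One rewrite at the head of a program, if a merge rule applies.
step :: Free Command a -> Maybe (Free Command a)
step (Chr x (Chr y k)) = Just (Str [x, y] k)
step (Chr x (Str s k)) = Just (Str (x : s) k)
step (Str s (Chr x k)) = Just (Str (s ++ [x]) k)
step (Str s (Str t k)) = Just (Str (s ++ t) k)
step _                 = Nothing

-- Apply 'step' at the head until no rule matches.
rewrite :: Free Command a -> Free Command a
rewrite p = maybe p rewrite (step p)

-- Optimize continuations (and loop bodies) first, then the head.
optimize :: Free Command a -> Free Command a
optimize (Chr c k)   = rewrite (Chr c (optimize k))
optimize (Str s k)   = rewrite (Str s (optimize k))
optimize (Rep n b k) = Rep n (optimize b) (optimize k)
optimize other       = other

-- Pure interpreter, handy for checking that the output is preserved.
render :: Free Command a -> String
render (Chr c k)   = c : render k
render (Str s k)   = s ++ render k
render (Rep n b k) = concat (replicate n (render b)) ++ render k
render _           = ""
```

Since the synonyms are bidirectional, they also work as constructors, so a test program can be written as `Chr 'a' (Chr 'b' (Pure ()))`.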
If your issue is with boilerplate, you won't get around it if you use Free! You will always be stuck with an extra constructor on each level.
But on the flip side, if you are using Free, you have a very easy way to generalize recursion over your data structure. You can write all of this from scratch, but I used the recursion-schemes package:
{-# LANGUAGE DeriveFunctor, TypeFamilies, TypeOperators #-}
import Data.Functor.Foldable
import Control.Applicative (Const (..))
import Control.Monad.Free
data (:+:) f g a = L (f a) | R (g a) deriving (Functor)
type instance Base (Free f a) = f :+: Const a
instance (Functor f) => Foldable (Free f a) where
  project (Free f) = L f
  project (Pure a) = R (Const a)
instance Functor f => Unfoldable (Free f a) where
  embed (L f) = Free f
  embed (R (Const a)) = Pure a
If you are unfamiliar with this, read the documentation; basically, all you need to know is that project takes some data, like Free f a, and "un-nests" it by one level, producing something like (f :+: Const a) (Free f a). Now you have given regular functions like fmap, Data.Foldable.foldMap, etc. access to the structure of your data, since the argument of the functor is the sub-tree.
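To make the "un-nest by one level" idea concrete without the recursion-schemes package, here is a minimal hand-rolled sketch. `FreeBase`, `projectF`, `embedF`, `cataF` and `countCommands` are hypothetical local names standing in for `Base`, `project`, `embed` and `cata`, and a local `Free` type replaces `Control.Monad.Free` so the block needs only `base`. Note how `cataF` is literally "project, `fmap` the recursive call over the sub-trees, then apply the algebra".

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Local Free data type (assumption: mirrors Control.Monad.Free.Free).
data Free f a = Pure a | Free (f (Free f a))

data Command next
  = DisplayChar Char next
  | DisplayString String next
  | Repeat Int (Free Command ()) next
  | Done
  deriving Functor

-- One layer of a Free f a. The recursive positions (sub-trees) become the
-- last type argument r; R plays the role of R (Const a) in the answer.
data FreeBase f a r = L (f r) | R a
  deriving Functor

-- Peel off exactly one level of nesting.
projectF :: Free f a -> FreeBase f a (Free f a)
projectF (Free fx) = L fx
projectF (Pure a)  = R a

-- Put one level back; embedF . projectF is the identity.
embedF :: FreeBase f a (Free f a) -> Free f a
embedF (L fx) = Free fx
embedF (R a)  = Pure a

-- Catamorphism: un-nest one level, fold every sub-tree, then combine.
cataF :: Functor f => (FreeBase f a r -> r) -> Free f a -> r
cataF alg = alg . fmap (cataF alg) . projectF

-- Example algebra: count command nodes, including those inside loop bodies.
countCommands :: Free Command a -> Int
countCommands = cataF alg
  where
    alg (L (DisplayChar _ n))   = 1 + n
    alg (L (DisplayString _ n)) = 1 + n
    alg (L (Repeat _ body n))   = 1 + countCommands body + n
    alg (L Done)                = 1
    alg (R _)                   = 0
```

For example, `countCommands` on `Free (DisplayChar 'A' (Free (Repeat 2 (Free (DisplayChar 'Z' (Pure ()))) (Free Done))))` gives 4: one character, one loop, one character inside the loop body, and `Done`.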
Executing is very simple, although not much more concise:
execute :: Free Command r -> IO ()
execute = cata go where
  go (L (DisplayChar ch next)) = putChar ch >> next
  go (L (DisplayString str next)) = putStr str >> next
  go (L (Repeat n block next)) = forM_ [1 .. n] (const $ execute block) >> next
  go (L Done) = return ()
  go (R _) = return ()
However, simplification becomes much easier. We can define simplification over all datatypes that have Foldable and Unfoldable instances:
reduce :: (Foldable t, Functor (Base t), Unfoldable t) => (t -> Maybe t) -> t -> t
reduce rule x =
  let y = embed $ fmap (reduce rule) $ project x in
  case rule y of
    Nothing -> y
    Just y' -> y'
The simplification rule only needs to simplify one level of the AST (namely, the top-most level). Then, if the simplification can apply to the substructure, it will perform it there too. Note that the above reduce works bottom-up; you can also have a top-down reduction:
reduceTD :: (Foldable t, Functor (Base t), Unfoldable t) => (t -> Maybe t) -> t -> t
reduceTD rule x = embed $ fmap (reduceTD rule) $ project y
  where
    y = case rule x of
      Nothing -> x
      Just x' -> x'
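The difference between the two strategies shows up with a rule that merges a `DisplayChar` into the text that follows it. The sketch below specializes both functions to `Free` (pattern matching in place of `project`/`embed`, so recursion-schemes is not needed) and uses a local `Free` type; `mergeChars` is a hypothetical rule. Bottom-up, the tail has already been merged by the time the head is examined, so `a`, `b`, `c` become one string; top-down, the rule fires once at the head and the merged node is not revisited.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Data.Maybe (fromMaybe)

-- Local Free data type (assumption: mirrors Control.Monad.Free.Free).
data Free f a = Pure a | Free (f (Free f a))

data Command next
  = DisplayChar Char next
  | DisplayString String next
  | Repeat Int (Free Command ()) next
  | Done
  deriving Functor

-- Bottom-up: rewrite the sub-trees first, then try the rule on the result.
reduce :: Functor f => (Free f a -> Maybe (Free f a)) -> Free f a -> Free f a
reduce rule x = fromMaybe y (rule y)
  where
    y = case x of
      Free fx -> Free (fmap (reduce rule) fx)
      leaf    -> leaf

-- Top-down: try the rule first, then rewrite the sub-trees of the result.
reduceTD :: Functor f => (Free f a -> Maybe (Free f a)) -> Free f a -> Free f a
reduceTD rule x = case fromMaybe x (rule x) of
  Free fy -> Free (fmap (reduceTD rule) fy)
  leaf    -> leaf

-- Hypothetical one-level rule: merge a character into the text after it.
mergeChars :: Free Command a -> Maybe (Free Command a)
mergeChars (Free (DisplayChar c (Free (DisplayChar d k))))   = Just (Free (DisplayString [c, d] k))
mergeChars (Free (DisplayChar c (Free (DisplayString s k)))) = Just (Free (DisplayString (c : s) k))
mergeChars _                                                 = Nothing
```

On the program `a`, `b`, `c`, `reduce mergeChars` yields a single `DisplayString "abc"`, while `reduceTD mergeChars` yields `DisplayString "ab"` followed by `DisplayChar 'c'`.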
Your example simplification rule can be written very simply:
getChrs :: (Command :+: Const ()) (Maybe String) -> Maybe String
getChrs (L (DisplayChar c n)) = liftA (c:) n
getChrs (L Done) = Just []
getChrs (R _) = Just []
getChrs _ = Nothing
optimize :: Free Command a -> Maybe (Free Command a)
optimize (Free (Repeat n dc next)) = do
  chrs <- cata getChrs dc
  -- repeat the whole block ("AB" -> "ABAB"), not each character
  return $ Free $ DisplayString (concat $ replicate n chrs) next
optimize _ = Nothing
Because of the way you've defined your datatype, you don't have access to the second argument of Repeat (fmap only reaches the next field), so for things like repeat' 5 (repeat' 3 (displayChar 'Z')) >> done, the inner repeat can't be simplified. If this is a situation you expect to deal with, you can either change your datatype and accept a lot more boilerplate, or write an exception:
reduceCmd rule (Free (Repeat n c r)) =
  let x = Free (Repeat n (reduceCmd rule c) (reduceCmd rule r)) in
  case rule x of
    Nothing -> x
    Just x' -> x'
reduceCmd rule x =
  let y = embed $ fmap (reduceCmd rule) $ project x in
  case rule y of
    Nothing -> y
    Just y' -> y'
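Here is a sketch of that exception in action, specialized to `Free` (pattern matching in place of `project`/`embed`) so it runs without recursion-schemes; a local `Free` type and `iter` stand in for `Control.Monad.Free`, which is an assumption made only for self-containedness. This version tries the rule at every node, and its `getChrs` also accepts `DisplayString`, so an already-collapsed inner loop still counts as plain text. The inner `repeat 3` collapses to `"ZZZ"` first, which in turn lets the outer `repeat 5` collapse.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Data.Maybe (fromMaybe)

-- Local Free data type and iter (assumption: they mirror Control.Monad.Free).
data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where
  fmap g (Pure a)  = Pure (g a)
  fmap g (Free fa) = Free (fmap (fmap g) fa)

iter :: Functor f => (f a -> a) -> Free f a -> a
iter _   (Pure a)  = a
iter phi (Free fa) = phi (fmap (iter phi) fa)

data Command next
  = DisplayChar Char next
  | DisplayString String next
  | Repeat Int (Free Command ()) next
  | Done
  deriving Functor

-- Text produced by a block, if it only displays characters and strings.
getChrs :: Command (Maybe String) -> Maybe String
getChrs (DisplayChar c n)   = (c :) <$> n
getChrs (DisplayString s n) = (s ++) <$> n
getChrs Done                = Just []
getChrs _                   = Nothing

-- Rule: a Repeat over plain text becomes a single DisplayString.
optimize :: Free Command a -> Maybe (Free Command a)
optimize (Free (Repeat n block next)) = do
  chrs <- iter getChrs (fmap (const (Just [])) block)
  return (Free (DisplayString (concat (replicate n chrs)) next))
optimize _ = Nothing

-- Bottom-up reduction that also descends into Repeat bodies.
reduceCmd :: (Free Command () -> Maybe (Free Command ()))
          -> Free Command () -> Free Command ()
reduceCmd rule x = fromMaybe y (rule y)
  where
    y = case x of
      Free (Repeat n body rest) -> Free (Repeat n (reduceCmd rule body) (reduceCmd rule rest))
      Free fx                   -> Free (fmap (reduceCmd rule) fx)
      leaf                      -> leaf
```

With `prog = Free (Repeat 5 (Free (Repeat 3 z (Pure ()))) (Free Done))`, where `z` displays `'Z'`, `reduceCmd optimize prog` should be a single `DisplayString` of fifteen `'Z'`s followed by `Done`.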
Using recursion-schemes or the like will probably make your code more easily extensible, but it isn't necessary by any means:
execute = iterM go where
  go (DisplayChar ch next) = putChar ch >> next
  go (DisplayString str next) = putStr str >> next
  go (Repeat n block next) = forM_ [1 .. n] (const $ execute block) >> next
  go Done = return ()
getChrs can't access Pure (iter only hands it Command layers), and your programs will be of the form Free Command (), so before you apply it, you have to replace () with Maybe String.
getChrs :: Command (Maybe String) -> Maybe String
getChrs (DisplayChar c n) = liftA (c:) n
getChrs (DisplayString s n) = liftA (s++) n
getChrs Done = Just []
getChrs _ = Nothing
optimize :: Free Command a -> Maybe (Free Command a)
optimize (Free (Repeat n dc next)) = do
  chrs <- iter getChrs $ fmap (const $ Just []) dc
  -- repeat the whole block ("AB" -> "ABAB"), not each character
  return $ Free $ DisplayString (concat $ replicate n chrs) next
optimize _ = Nothing
Note that reduce is almost exactly the same as before, except for two things: project and embed are replaced with pattern matching on Free and applying the Free constructor, respectively; and you need a separate case for Pure. This should tell you that Foldable and Unfoldable generalize things that "look like" Free.
reduce
  :: Functor f =>
     (Free f a -> Maybe (Free f a)) -> Free f a -> Free f a
reduce rule (Free x) =
  let y = Free $ fmap (reduce rule) $ x in
  case rule y of
    Nothing -> y
    Just y' -> y'
reduce rule a@(Pure _) =
  case rule a of
    Nothing -> a
    Just b -> b
All the other functions are modified similarly.
Please don't think of zippers, traversals, SYB or lens until you've taken advantage of the standard features of Free. Your execute, optimize and project are just standard free monad recursion schemes that are already available in the package:
optimize :: Free Command a -> Free Command a
optimize = iterM $ \f -> case f of
  c@(Repeat n block next) ->
    let charsToDisplay = project getDisplayChar block in
    if all isJust charsToDisplay then
      let chars = catMaybes charsToDisplay in
      displayString (concat $ replicate n chars) >> next
    else
      liftF c >> next
  DisplayChar ch next -> displayChar ch >> next
  DisplayString str next -> displayString str >> next
  Done -> done
getDisplayChar :: Command t -> Maybe Char
getDisplayChar (DisplayChar ch _) = Just ch
getDisplayChar _ = Nothing
project' :: (Command [u] -> u) -> Free Command [u] -> [u]
project' f = iter $ \c -> f c : case c of
  DisplayChar _ next -> next
  DisplayString _ next -> next
  Repeat _ _ next -> next
  Done -> []
project :: (Command [u] -> u) -> Free Command a -> [u]
project f = project' f . fmap (const [])
execute :: Free Command () -> IO ()
execute = iterM $ \f -> case f of
  DisplayChar ch next -> putChar ch >> next
  DisplayString str next -> putStr str >> next
  Repeat n block next -> forM_ [1 .. n] (\_ -> execute block) >> next
  Done -> return ()
Since your components each have at most one continuation, you can probably find a clever way to get rid of all those >> next too.
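One option (a sketch, not necessarily what was meant) is to handle only the nodes you want to change and rebuild every other node unchanged: by the time `iterM` calls the algebra, the continuation inside the node has already been optimized, so `Free other` (or `wrap` from the free package's `MonadFree` class) is all the default case needs. To stay self-contained, the block defines a local `Free` monad with `iter` and `iterM` in place of `Control.Monad.Free`; `textOf` and `render` are hypothetical helpers.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Prelude hiding (repeat)
import Control.Monad (ap)

-- Local Free monad with iter/iterM (assumption: mirrors Control.Monad.Free).
data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where
  fmap g (Pure a)  = Pure (g a)
  fmap g (Free fa) = Free (fmap (fmap g) fa)

instance Functor f => Applicative (Free f) where
  pure  = Pure
  (<*>) = ap

instance Functor f => Monad (Free f) where
  Pure a  >>= k = k a
  Free fa >>= k = Free (fmap (>>= k) fa)

liftF :: Functor f => f a -> Free f a
liftF = Free . fmap Pure

iter :: Functor f => (f a -> a) -> Free f a -> a
iter _   (Pure a)  = a
iter phi (Free fa) = phi (fmap (iter phi) fa)

iterM :: (Monad m, Functor f) => (f (m a) -> m a) -> Free f a -> m a
iterM _   (Pure a)  = return a
iterM phi (Free fa) = phi (fmap (iterM phi) fa)

data Command next
  = DisplayChar Char next
  | DisplayString String next
  | Repeat Int (Free Command ()) next
  | Done
  deriving Functor

displayChar :: Char -> Free Command ()
displayChar c = liftF (DisplayChar c ())

displayString :: String -> Free Command ()
displayString s = liftF (DisplayString s ())

repeat :: Int -> Free Command () -> Free Command ()
repeat n block = liftF (Repeat n block ())

-- Text of a block that only displays characters/strings (hypothetical helper).
textOf :: Free Command a -> Maybe String
textOf = iter step . fmap (const (Just ""))
  where
    step (DisplayChar c k)   = (c :) <$> k
    step (DisplayString s k) = (s ++) <$> k
    step Done                = Just ""
    step _                   = Nothing

-- Only the interesting case is spelled out; every other node is rebuilt
-- unchanged with Free, so no "displayChar ch >> next" lines are needed.
optimize :: Free Command a -> Free Command a
optimize = iterM alg
  where
    alg (Repeat n block next)
      | Just s <- textOf block = Free (DisplayString (concat (replicate n s)) next)
    alg other = Free other

-- Pure interpreter written with iter, handy for checking results.
render :: Free Command a -> String
render = iter step . fmap (const "")
  where
    step (DisplayChar c k)   = c : k
    step (DisplayString s k) = s ++ k
    step (Repeat n b k)      = concat (replicate n (render b)) ++ k
    step Done                = ""
```

A loop whose body is not plain text (for instance `repeat 2 (repeat 2 (displayChar 'z'))`) falls through to `Free other` and is kept as-is, just like in the `iterM` version above.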
You can certainly do this more easily. There's still some work to be done, because it won't perform a full optimization in the first pass, but after two passes it fully optimizes your example program. I'll leave that exercise up to you, but otherwise you can do this very simply by pattern matching on the optimizations you want to make. It's still a bit repetitive, but it removes a lot of the complication you had:
optimize (Free (Repeat n block next)) = optimize (replicateM n block >> next)
optimize (Free (DisplayChar ch1 (Free (DisplayChar ch2 next)))) = optimize (displayString [ch1, ch2] >> next)
optimize (Free (DisplayChar ch (Free (DisplayString str next)))) = optimize (displayString (ch:str) >> next)
optimize (Free (DisplayString s1 (Free (DisplayString s2 next)))) = optimize (displayString (s1 ++ s2) >> next)
optimize (Free (DisplayString s (Free (DisplayChar ch next)))) = optimize (displayString (s ++ [ch]) >> next)
optimize (Free (DisplayChar ch next)) = displayChar ch >> optimize next
optimize (Free (DisplayString str next)) = displayString str >> optimize next
optimize (Free Done) = done
optimize c@(Pure r) = c
All I did was pattern match on repeat n block, displayChar c1 >> displayChar c2, displayChar c >> displayString s, displayString s >> displayChar c, and displayString s1 >> displayString s2. There are other optimizations that can be done, but this was pretty easy and doesn't depend on scanning anything else; it just steps over the AST, recursively optimizing as it goes.
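One way to get the full optimization in a single pass, sketched here as an assumption rather than the answer's code: the second pass is needed because `displayString str >> optimize next` never looks at what `optimize next` starts with (for example, a `Repeat` in the tail that unrolls into text). Optimizing the continuation first and then merging the current node into the head of the result avoids that. A local `Free` monad stands in for `Control.Monad.Free`, and `prepend` is a hypothetical helper.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Prelude hiding (repeat)
import Control.Monad (ap, replicateM_)

-- Local Free monad (assumption: mirrors Control.Monad.Free).
data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where
  fmap g (Pure a)  = Pure (g a)
  fmap g (Free fa) = Free (fmap (fmap g) fa)

instance Functor f => Applicative (Free f) where
  pure  = Pure
  (<*>) = ap

instance Functor f => Monad (Free f) where
  Pure a  >>= k = k a
  Free fa >>= k = Free (fmap (>>= k) fa)

liftF :: Functor f => f a -> Free f a
liftF = Free . fmap Pure

data Command next
  = DisplayChar Char next
  | DisplayString String next
  | Repeat Int (Free Command ()) next
  | Done
  deriving Functor

displayChar :: Char -> Free Command ()
displayChar c = liftF (DisplayChar c ())

displayString :: String -> Free Command ()
displayString s = liftF (DisplayString s ())

repeat :: Int -> Free Command () -> Free Command ()
repeat n block = liftF (Repeat n block ())

done :: Free Command a
done = liftF Done

-- Put text in front of an already-optimized program, merging it with a
-- leading DisplayChar or DisplayString when there is one.
prepend :: String -> Free Command a -> Free Command a
prepend s (Free (DisplayString t k)) = Free (DisplayString (s ++ t) k)
prepend s (Free (DisplayChar c k))   = Free (DisplayString (s ++ [c]) k)
prepend [c] k                        = Free (DisplayChar c k)
prepend s k                          = Free (DisplayString s k)

-- Single pass: unroll loops, optimize the continuation first, then merge
-- the current node into the head of the optimized continuation.
optimize :: Free Command a -> Free Command a
optimize (Free (Repeat n block next))  = optimize (replicateM_ n block >> next)
optimize (Free (DisplayChar c next))   = prepend [c] (optimize next)
optimize (Free (DisplayString s next)) = prepend s (optimize next)
optimize other                         = other
```

For the question's `prog`, a single call gives `Free (DisplayString "AabcZZZZZ\n" (Free Done))`.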