Why must we use the State monad instead of passing state directly?

Posted 2019-03-19 05:57

Question:

Can someone show a simple example where the State monad is better than passing state directly?

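import Control.Monad.State

newtype Foo = Foo Int  -- assumed definition; the question does not show Foo

bar1 :: Foo -> Foo
bar1 (Foo x) = Foo (x + 1)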

vs

bar2 :: State Foo Foo
bar2 = do
  modify (\(Foo x) -> Foo (x + 1))
  get
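For a single step like this the two are interchangeable: a value of type `State Foo Foo` is essentially a wrapped function `Foo -> (Foo, Foo)`. A minimal sketch, assuming `Foo` is a newtype around `Int` (the question does not show its definition) and using the hypothetical name `bar2'`, builds the monadic version directly from `bar1` with mtl's `state`:

```haskell
import Control.Monad.State (State, runState, state)

newtype Foo = Foo Int deriving (Show, Eq)  -- assumed definition of Foo

bar1 :: Foo -> Foo
bar1 (Foo x) = Foo (x + 1)

-- State Foo Foo wraps a function Foo -> (result, new state); here the
-- new state produced by bar1 is also returned as the result, like bar2.
bar2' :: State Foo Foo
bar2' = state (\s -> let s' = bar1 s in (s', s'))
```

Here `runState bar2' (Foo 0)` evaluates to `(Foo 1, Foo 1)`, the same as `bar2`. The monad only starts to pay off once many such steps are chained, which is what the answers below illustrate.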

Answer 1:

State passing is often tedious, error-prone, and hinders refactoring. For example, try labeling a binary tree or rose tree in postorder:

data RoseTree a = Node a [RoseTree a] deriving (Show, Eq)  -- Eq is needed for the == checks below

postLabel :: RoseTree a -> RoseTree Int
postLabel = fst . go 0 where
  go i (Node _ ts) = (Node i' ts', i' + 1) where

    (ts', i') = gots i ts

    gots i []     = ([], i)
    gots i (t:ts) = (t':ts', i'') where
      (t', i')   = go i t
      (ts', i'') = gots i' ts

Here I had to name the intermediate states by hand, pass the correct state along at each step, and make sure that both the labels and the child nodes end up in the right order in the result (note that a naive foldr or foldl over the child nodes could easily have led to incorrect behavior).
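To make the parenthetical concrete, here is a hypothetical variant (`postLabelFoldr`, not from the answer) that threads the counter through the children with `foldr`. It type-checks and keeps the children in their original order, but `foldr` labels the last child first, so the labels come out right to left:

```haskell
data RoseTree a = Node a [RoseTree a] deriving (Show, Eq)

-- Looks like postLabel, but threads the counter with foldr, which
-- hands the counter from each child to its left sibling.
postLabelFoldr :: RoseTree a -> RoseTree Int
postLabelFoldr = fst . go 0
  where
    go i (Node _ ts) = (Node i' ts', i' + 1)
      where
        (ts', i') = foldr step ([], i) ts
        -- Label one child with the counter left over by its right sibling.
        step t (acc, j) = let (t', j') = go j t in (t' : acc, j')
```

On the example `tree` defined further down, this yields `Node 4 [Node 3 [Node 2 [],Node 1 []],Node 0 []]` instead of the correct `Node 4 [Node 2 [Node 0 [],Node 1 []],Node 3 []]`: same shape, wrong labels.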

Also, if I try to change the code to preorder, I have to make changes that are easy to get wrong:

preLabel :: RoseTree a -> RoseTree Int
preLabel = fst . go 0 where
  go i (Node _ ts) = (Node i ts', i') where -- first change

    (ts', i') = gots (i + 1) ts -- second change

    gots i []     = ([], i)
    gots i (t:ts) = (t':ts', i'') where
      (t', i')   = go i t
      (ts', i'') = gots i' ts

Examples:

branch = Node ()
nil  = branch []
tree = branch [branch [nil, nil], nil]
preLabel tree == Node 0 [Node 1 [Node 2 [],Node 3 []],Node 4 []]
postLabel tree == Node 4 [Node 2 [Node 0 [],Node 1 []],Node 3 []]

Contrast this with the State monad solution:

import Control.Monad.State
import Control.Applicative

postLabel' :: RoseTree a -> RoseTree Int
postLabel' = (`evalState` 0) . go where
  go (Node _ ts) = do
    ts' <- traverse go ts
    i   <- get <* modify (+1)
    pure (Node i ts')

preLabel' :: RoseTree a -> RoseTree Int
preLabel' = (`evalState` 0) . go where
  go (Node _ ts) = do
    i   <- get <* modify (+1)
    ts' <- traverse go ts
    pure (Node i ts')

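Not only is this code more succinct and easier to write correctly, but the logic that determines pre- or postorder labeling is also far more transparent: the only difference is whether the label is taken before or after the children are traversed.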
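To see where the state threading went, here is a minimal hand-rolled state monad. It is an illustrative sketch only; `St`, `runSt`, `fresh` and `postLabelSt` are hypothetical names, not part of mtl. The pair-passing bookkeeping from `go`/`gots` now lives once, in `<*>` and `>>=`, and the labeller itself is the same do block as `postLabel'`:

```haskell
-- A hand-rolled state monad: a function from an input state to a
-- (result, new state) pair, the same shape as go and gots above.
newtype St s a = St { runSt :: s -> (a, s) }

instance Functor (St s) where
  fmap f (St g) = St $ \s -> let (a, s') = g s in (f a, s')

instance Applicative (St s) where
  pure a = St $ \s -> (a, s)
  -- Run the left action, then feed its final state to the right action.
  St f <*> St g = St $ \s ->
    let (h, s')  = f s
        (a, s'') = g s'
    in (h a, s'')

instance Monad (St s) where
  St g >>= k = St $ \s -> let (a, s') = g s in runSt (k a) s'

-- Return the current counter and increment it, like get <* modify (+1).
fresh :: St Int Int
fresh = St $ \i -> (i, i + 1)

data RoseTree a = Node a [RoseTree a] deriving (Show, Eq)

-- Same shape as postLabel', with all state threading hidden in St.
postLabelSt :: RoseTree a -> RoseTree Int
postLabelSt t = fst (runSt (go t) 0)
  where
    go (Node _ ts) = do
      ts' <- traverse go ts
      i   <- fresh
      pure (Node i ts')
```

Once this plumbing is written, swapping the two lines of `go` switches between post- and preorder without touching any state bookkeeping.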


PS: bonus applicative style:

postLabel' :: RoseTree a -> RoseTree Int
postLabel' = (`evalState` 0) . go where
  go (Node _ ts) =
    flip Node <$> traverse go ts <*> (get <* modify (+1))

preLabel' :: RoseTree a -> RoseTree Int
preLabel' = (`evalState` 0) . go where
  go (Node _ ts) =
    Node <$> (get <* modify (+1)) <*> traverse go ts
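Going one step further, here is a sketch assuming GHC's `DeriveTraversable` extension (`fresh` and `preLabelT` are hypothetical names). Because the derived `traverse` for `RoseTree` visits a node's own value before its children, preorder labelling falls out of `traverse` with a single stateful counter action and no hand-written recursion:

```haskell
{-# LANGUAGE DeriveTraversable #-}
import Control.Monad.State (State, evalState, state)

data RoseTree a = Node a [RoseTree a]
  deriving (Show, Eq, Functor, Foldable, Traversable)

-- Return the current counter and increment it (same as get <* modify (+1)).
fresh :: State Int Int
fresh = state (\i -> (i, i + 1))

-- The derived traverse visits a node's own value before its children,
-- so numbering every element in traversal order yields preorder labels.
preLabelT :: RoseTree a -> RoseTree Int
preLabelT t = evalState (traverse (const fresh) t) 0
```

Postorder has no such shortcut with the derived instance, since it always visits fields in declaration order; it would need the explicit `go` above or a hand-written instance.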


Answer 2:

As an example of my comment above, you can write code using the State monad like this:

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

import Data.Text (Text)
import qualified Data.Text as Text
import Control.Lens  -- makeLenses, &, +~, %~ (and +=, %= below)
import Control.Monad.State

data MyState = MyState
    { _count :: Int
    , _messages :: [Text]
    } deriving (Eq, Show)
makeLenses ''MyState

type App = State MyState

incrCnt :: App ()
incrCnt = modify (\my -> my & count +~ 1)

logMsg :: Text -> App ()
logMsg msg = modify (\my -> my & messages %~ (++ [msg]))

logAndIncr :: Text -> App ()
logAndIncr msg = do
    incrCnt
    logMsg msg

app :: App ()
app = do
    logAndIncr "First step"
    logAndIncr "Second step"
    logAndIncr "Third step"
    logAndIncr "Fourth step"
    logAndIncr "Fifth step"

Note that using additional operators from Control.Lens also lets you write incrCnt and logMsg as

incrCnt = count += 1

logMsg msg = messages %= (++ [msg])

which is another benefit of using State in combination with the lens library, but for the sake of comparison I'm not using them in this example. Written with plain argument passing instead, the equivalent code would look more like this:

incrCnt :: MyState -> MyState
incrCnt my = my & count +~ 1

logMsg :: MyState -> Text -> MyState
logMsg my msg = my & messages %~ (++ [msg])

logAndIncr :: MyState -> Text -> MyState
logAndIncr my msg =
    let incremented = incrCnt my
        logged = logMsg incremented msg
    in logged

At this point it isn't too bad, but once we get to the next step I think you'll see where the code duplication really comes in:

app :: MyState -> MyState
app initial =
    let first_step  = logAndIncr initial     "First step"
        second_step = logAndIncr first_step  "Second step"
        third_step  = logAndIncr second_step "Third step"
        fourth_step = logAndIncr third_step  "Fourth step"
        fifth_step  = logAndIncr fourth_step "Fifth step"
    in fifth_step

Another benefit of wrapping this up in a Monad instance is that you can use the full power of Control.Monad and Control.Applicative with it:

app = mapM_ logAndIncr [
    "First step",
    "Second step",
    "Third step",
    "Fourth step",
    "Fifth step"
    ]

This allows much more flexibility when the values to process are calculated at runtime rather than fixed in the source.
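Because the lens imports are an extra dependency, here is a lens-free sketch of the same `State` code using plain record updates (the field names `count`/`messages`, `String` instead of `Text`, and the helper `runSteps` are assumptions, not the answer's). The list of steps is an ordinary argument, so it can just as well come from input at runtime:

```haskell
import Control.Monad.State (State, execState, modify)

-- Lens-free stand-in for MyState (field names assumed).
data MyState = MyState
  { count    :: Int
  , messages :: [String]
  } deriving (Eq, Show)

type App = State MyState

incrCnt :: App ()
incrCnt = modify (\s -> s { count = count s + 1 })

logMsg :: String -> App ()
logMsg msg = modify (\s -> s { messages = messages s ++ [msg] })

logAndIncr :: String -> App ()
logAndIncr msg = do
  incrCnt
  logMsg msg

-- Hypothetical helper: the steps are an ordinary runtime value.
runSteps :: [String] -> MyState
runSteps steps = execState (mapM_ logAndIncr steps) (MyState 0 [])
```

For example, `runSteps ["First step", "Second step"]` gives `MyState 2 ["First step", "Second step"]`.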

The difference between manual state passing and using the State monad is simply that the State monad is an abstraction over the manual process. It also happens to fit several other widely used, more general abstractions, such as Monad, Applicative, Functor, and a few others. If you also use the StateT transformer, you can compose these operations with other monads, such as IO. Can you do all of this without State and StateT? Of course you can, and no one is stopping you, but the point is that State abstracts this pattern out and gives you access to a huge toolbox of more general tools. Also, a small modification to the types above makes the same functions work in multiple contexts:

incrCnt :: MonadState MyState m => m ()
logMsg :: MonadState MyState m => Text -> m ()
logAndIncr :: MonadState MyState m => Text -> m ()

These will now work with App, with StateT MyState IO, or with any other monad stack that has a MonadState MyState instance. That makes them significantly more reusable than functions written with simple argument passing, and this reuse is only possible through the abstraction that StateT provides.
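A sketch of that generalisation, again without lens (plain record updates and `String` instead of `Text`; `pureRun` and `ioRun` are hypothetical names). The same three functions run in plain `State` and inside `StateT MyState IO`, where they can be interleaved with `liftIO`:

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State (MonadState, execState, execStateT, gets, modify)

data MyState = MyState
  { count    :: Int
  , messages :: [String]
  } deriving (Eq, Show)

incrCnt :: MonadState MyState m => m ()
incrCnt = modify (\s -> s { count = count s + 1 })

logMsg :: MonadState MyState m => String -> m ()
logMsg msg = modify (\s -> s { messages = messages s ++ [msg] })

logAndIncr :: MonadState MyState m => String -> m ()
logAndIncr msg = incrCnt >> logMsg msg

-- Run in plain State: no IO involved.
pureRun :: MyState
pureRun = execState (mapM_ logAndIncr ["First step", "Second step"]) (MyState 0 [])

-- Run the very same functions in StateT over IO, mixing in real IO.
ioRun :: IO MyState
ioRun = flip execStateT (MyState 0 []) $ do
  logAndIncr "First step"
  n <- gets count
  liftIO (putStrLn ("count after first step: " ++ show n))
  logAndIncr "Second step"
```

Both runs end in the same final state; only the monad they are instantiated at differs.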



Answer 3:

In my experience, the point of many Monads doesn't really click until you get into larger examples, so here is an example use of State (well, StateT ... IO) to parse an incoming request to a web service.

The pattern is that this web service can be called with a bunch of options of different types, though all but one of the options have decent defaults. If I get an incoming JSON request with an unknown key, I should abort with an appropriate message. I use the state to keep track of what the current config is and what remains of the JSON request, along with a bunch of accessor functions.

(Based on code currently in production, with the names of everything changed and the details of what this service actually does obscured)

{-# LANGUAGE OverloadedStrings #-}

module XmpConfig where

import Data.IORef
import Control.Arrow (first)
import Control.Monad
import qualified Data.Text as T
import Data.Aeson hiding ((.=))
import qualified Data.HashMap.Strict as MS
import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.State (execStateT, StateT, gets, modify)
import qualified Data.Foldable as DF
import Data.Maybe (fromJust, isJust)

data Taggy = UseTags Bool | NoTags
newtype Locale = Locale String

data MyServiceConfig = MyServiceConfig {
    _mscTagStatus :: Taggy
  , _mscFlipResult :: Bool
  , _mscWasteTime :: Bool
  , _mscLocale :: Locale
  , _mscFormatVersion :: Int
  , _mscJobs :: [String]
  }

baseWebConfig :: IO (IORef [String], IORef [String], MyServiceConfig)
baseWebConfig = do
  infoRef <- newIORef []
  warningRef <- newIORef []
  let cfg = MyServiceConfig {
        _mscTagStatus = NoTags
        , _mscFlipResult = False
        , _mscWasteTime = False
        , _mscLocale = Locale "en-US"
        , _mscFormatVersion = 1
        , _mscJobs = []
        }
  return (infoRef, warningRef, cfg)

parseLocale :: T.Text -> Maybe Locale
parseLocale = Just . Locale . T.unpack  -- The real thing does more

parseJSONReq :: MS.HashMap T.Text Value ->
                IO (IORef [String], IORef [String], MyServiceConfig)
parseJSONReq m = liftM snd
                 (baseWebConfig >>= (\c -> execStateT parse' (m, c)))
  where
    parse' :: StateT (MS.HashMap T.Text Value,
                      (IORef [String], IORef [String], MyServiceConfig))
              IO ()
    parse' = do
      let addWarning s = do let snd3 (_, b, _) = b
                            r <- gets (snd3 . snd)
                            liftIO $ modifyIORef r (++ [s])
          -- These two functions suck a key/value off the input map and
          -- pass the value on to the handler "h"
          onKey      k h = onKeyMaybe k $ DF.mapM_ h
          onKeyMaybe k h = do myb <- gets fst
                              modify $ first $ MS.delete k
                              h (MS.lookup k myb)
          -- Apply a modification to the MyServiceConfig held in the state
          config setter = modify (\(a, (b, c, d)) -> (a, (b, c, setter d)))

      onKey "tags" $ \x -> case x of
        Bool True ->       config $ \c -> c {_mscTagStatus = UseTags False}
        String "true" ->   config $ \c -> c {_mscTagStatus = UseTags False}
        Bool False ->      config $ \c -> c {_mscTagStatus = NoTags}
        String "false" ->  config $ \c -> c {_mscTagStatus = NoTags}
        String "inline" -> config $ \c -> c {_mscTagStatus = UseTags True}
        q -> addWarning ("Bad value ignored for tags: " ++ show q)
      onKey "reverse" $ \x -> case x of
        Bool r -> config $ \c -> c {_mscFlipResult = r}
        q -> addWarning ("Bad value ignored for reverse: " ++ show q)
      onKey "spin" $ \x -> case x of
        Bool r -> config $ \c -> c {_mscWasteTime = r}
        q -> addWarning ("Bad value ignored for spin: " ++ show q)
      onKey "language" $ \x -> case x of
        String s | isJust (parseLocale s) ->
          config $ \c -> c {_mscLocale = fromJust $ parseLocale s}
        q -> addWarning ("Bad value ignored for language: " ++ show q)
      onKey "format" $ \x -> case x of
        Number 1 -> config $ \c -> c {_mscFormatVersion = 1}
        Number 2 -> config $ \c -> c {_mscFormatVersion = 2}
        q -> addWarning ("Bad value ignored for format: " ++ show q)
      onKeyMaybe "jobs" $ \p -> case p of
        Just (Array x) -> do q <- parseJobs x
                             config $ \c -> c {_mscJobs = q}
        Just (String "test") ->
          config $ \c -> c {_mscJobs = ["test1", "test2"]}
        Just other -> fail $ "Bad value for jobs: " ++ show other
        Nothing    -> fail "Missing value for jobs"
      m' <- gets fst
      unless (MS.null m') (fail $ "Unrecognized key(s): " ++ show (MS.keys m'))

    -- MonadFail (rather than Monad) is required for `fail` since GHC 8.8
    parseJobs :: (MonadFail m, DF.Foldable b) => b Value -> m [String]
    parseJobs = DF.foldrM (\a b -> liftM (:b) (parseJob a)) []
    parseJob :: (MonadFail m) => Value -> m String
    parseJob (String s) = return (T.unpack s)
    parseJob q = fail $ "Bad job value: " ++ show q
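The core of this pattern does not depend on aeson or IORefs. Here is a much-simplified sketch, assuming a `Data.Map` of string keys and values instead of JSON, `Either String` instead of `IO` as the base monad, and aborting on bad values instead of collecting warnings (`Config`, `Parser`, `abort` and `parseReq` are hypothetical names). Each handler consumes its key, and whatever is left at the end is reported as unrecognized:

```haskell
import Control.Monad.State (StateT, execStateT, gets, modify)
import Control.Monad.Trans.Class (lift)
import qualified Data.Map.Strict as M

-- Hypothetical, much-reduced config.
data Config = Config { flipResult :: Bool, jobs :: [String] }
  deriving (Eq, Show)

-- State: (request keys not yet consumed, config built so far);
-- Either String plays the role of failing the request with a message.
type Parser = StateT (M.Map String String, Config) (Either String)

-- Remove key k from the remaining input and pass its value (if any) to h.
onKeyMaybe :: String -> (Maybe String -> Parser ()) -> Parser ()
onKeyMaybe k h = do
  v <- gets (M.lookup k . fst)
  modify (\(m, c) -> (M.delete k m, c))
  h v

-- Apply a change to the config part of the state.
config :: (Config -> Config) -> Parser ()
config f = modify (\(m, c) -> (m, f c))

abort :: String -> Parser a
abort = lift . Left

parseReq :: M.Map String String -> Either String Config
parseReq req = snd <$> execStateT parse (req, Config False [])
  where
    parse = do
      onKeyMaybe "reverse" $ \v -> case v of
        Nothing      -> pure ()  -- keep the default
        Just "true"  -> config (\c -> c { flipResult = True })
        Just "false" -> config (\c -> c { flipResult = False })
        Just other   -> abort ("Bad value for reverse: " ++ other)
      onKeyMaybe "jobs" $ \v -> case v of
        Just js -> config (\c -> c { jobs = words js })  -- assumed format
        Nothing -> abort "Missing value for jobs"
      rest <- gets fst
      if M.null rest
        then pure ()
        else abort ("Unrecognized key(s): " ++ show (M.keys rest))
```

For example, `parseReq (M.fromList [("jobs", "a b"), ("reverse", "true")])` returns a `Right` config with `flipResult = True` and `jobs = ["a","b"]`, while adding an extra key such as `"color"` makes it return a `Left` that names that key.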