(How) can I make this monadic bind tail-recursive?

2020-06-03 04:22发布

问题:

I have this monad called Desync -

[<AutoOpen>]
module DesyncModule =

    /// The Desync monad. Lets the user write, in a sequential style, an operation that is spread
    /// across a bounded number of events. The span is bounded because this implementation is not
    /// tail-recursive (see the note about unbounded recursion in bind), and it is not yet known
    /// whether a tail-recursive implementation exists.
    ///
    /// A Desync wraps a function from a state 's to a new state paired with either a
    /// continuation awaiting the next event 'e (Left) or a finished result 'a (Right).
    type [<NoComparison; NoEquality>] Desync<'e, 's, 'a> =
        Desync of ('s -> 's * Either<'e -> Desync<'e, 's, 'a>, 'a>)

    /// Monadic return for the Desync monad: finishes immediately with the given value and
    /// leaves the state untouched.
    let internal returnM (a : 'a) : Desync<'e, 's, 'a> =
        Desync (fun state -> (state, Right a))

    /// Monadic bind for the Desync monad: steps 'm' and, once it yields a value, hands that
    /// value to 'cont' and continues with the resulting Desync.
    let rec internal bind (m : Desync<'e, 's, 'a>) (cont : 'a -> Desync<'e, 's, 'b>) : Desync<'e, 's, 'b> =
        Desync (fun state ->
            let (Desync stepM) = m
            // NOTE: unbounded recursion here - this call is not in tail position, and 'stepM'
            // can itself be the closure produced by another bind.
            match stepM state with
            | (state', Left resume) ->
                // Still waiting on an event: re-bind the resumed computation to the same continuation.
                (state', Left (fun e -> bind (resume e) cont))
            | (state', Right value) ->
                // Finished: run the continuation against the updated state.
                let (Desync stepCont) = cont value
                stepCont state')

    /// Computation expression builder for the Desync monad.
    type DesyncBuilder () =
        member self.Return op = returnM op
        member self.Bind (m, cont) = bind m cont

    /// The Desync builder instance, used as desync { ... }.
    let desync = DesyncBuilder ()

It allows game logic that executes across several game ticks to be written in a seemingly sequential style using computation expressions.

Unfortunately, when used for tasks that last for an unbounded number of game ticks, it crashes with a StackOverflowException. And even when it doesn't crash, it ends up with unwieldy stack traces like this -

InfinityRpg.exe!InfinityRpg.GameplayDispatcherModule.desync@525-20.Invoke(Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen> _arg10) Line 530   F#
Prime.exe!Prime.DesyncModule.bind@20<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit,Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>>.Invoke(Nu.SimulationModule.World s) Line 24    F#
Prime.exe!Prime.DesyncModule.bind@20<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>.Invoke(Nu.SimulationModule.World s) Line 21    F#
Prime.exe!Prime.DesyncModule.bind@20<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>.Invoke(Nu.SimulationModule.World s) Line 21    F#
Prime.exe!Prime.DesyncModule.bind@20<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>.Invoke(Nu.SimulationModule.World s) Line 21    F#
Prime.exe!Prime.DesyncModule.bind@20<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>.Invoke(Nu.SimulationModule.World s) Line 21    F#
Prime.exe!Prime.DesyncModule.bind@20<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>.Invoke(Nu.SimulationModule.World s) Line 21    F#
Prime.exe!Prime.DesyncModule.bind@20<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>.Invoke(Nu.SimulationModule.World s) Line 21    F#
Prime.exe!Prime.DesyncModule.bind@20<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>.Invoke(Nu.SimulationModule.World s) Line 21    F#
Prime.exe!Prime.Desync.step<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit>(Prime.DesyncModule.Desync<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit> m, Nu.SimulationModule.World s) Line 71 F#
Prime.exe!Prime.Desync.advanceDesync<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit>(Microsoft.FSharp.Core.FSharpFunc<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Prime.DesyncModule.Desync<Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>,Nu.SimulationModule.World,Microsoft.FSharp.Core.Unit>> m, Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen> e, Nu.SimulationModule.World s) Line 75 F#
Nu.exe!Nu.Desync.advance@98<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>.Invoke(Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen> event, Nu.SimulationModule.World world) Line 100 F#
Nu.exe!Nu.Desync.subscription@104-16<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>.Invoke(Nu.SimulationModule.Event<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen> event, Nu.SimulationModule.World world) Line 105    F#
Nu.exe!Nu.World.boxableSubscription@165<Prime.EitherModule.Either<Microsoft.FSharp.Core.Unit,Microsoft.FSharp.Core.Unit>,Nu.SimulationModule.Screen>.Invoke(object event, Nu.SimulationModule.World world) Line 166 F#

I am hoping to solve the problem by making the Left case of the bind function tail-recursive. However, I'm not sure of two things -

1) if it can be done at all, and 2) how it would actually be done.

If it's impossible to make bind tail-recursive here, is there some way to restructure my monad to allow it to become tail-recursive?

EDIT 3 (subsumes previous edits): Here is additional code that implements the desync combinators I will use to demonstrate the stack overflow -

module Desync =

    /// Get the state: completes immediately, returning the current state as the result and
    /// leaving it unchanged.
    let get : Desync<'e, 's, 's> =
        Desync (fun s -> (s, Right s))

    /// Set the state: completes immediately with unit, replacing the current state with 's'.
    let set s : Desync<'e, 's, unit> =
        Desync (fun _ -> (s, Right ()))

    /// Loop in a desynchronous context while 'pred' evaluates to true.
    /// 'i' is the current loop variable, 'next' advances it after each iteration, 'pred' is
    /// tested against the loop variable and the current state before each iteration, and 'm'
    /// builds the loop body for a given loop variable.
    let rec loop (i : 'i) (next : 'i -> 'i) (pred : 'i -> 's -> bool) (m : 'i -> Desync<'e, 's, unit>) =
        desync {
            let! s = get
            do! if pred i s then
                    desync {
                        do! m i
                        let i = next i
                        do! loop i next pred m }
                else returnM () }

    /// Repeat 'm' in a desynchronous context for as long as 'pred' holds for the current state.
    /// Implemented as a 'loop' whose loop variable is unit.
    let during (pred : 's -> bool) (m : Desync<'e, 's, unit>) =
        loop () id (fun _ -> pred) (fun _ -> m)

    /// Step once into a desync: applies the wrapped function to 's', yielding the new state
    /// together with either a continuation awaiting an event (Left) or the final value (Right).
    let step (m : Desync<'e, 's, 'a>) (s : 's) : 's * Either<'e -> Desync<'e, 's, 'a>, 'a> =
        match m with Desync f -> f s

    /// Run a desync to its end, providing e for all its steps.
    /// Returns the final state paired with the resulting value.
    let rec runDesync (m : Desync<'e, 's, 'a>) (e : 'e) (s : 's) : ('s * 'a) =
        match step m s with
        | (s', Left m') -> runDesync (m' e) e s' // suspended: resume with the same event
        | (s', Right v) -> (s', v)

Here is the Either implementation -

[<AutoOpen>]
module EitherModule =

    /// Haskell-style Either type: a value that is either a Right carrying an 'r or a Left
    /// carrying an 'l. The Desync monad uses Left for a suspended continuation and Right for
    /// a finished result.
    type Either<'l, 'r> =
        | Right of 'r
        | Left of 'l

And finally, here's a simple line of code that will yield a stack overflow -

open Desync
// Runs a loop whose predicate is always true; this is the call that overflows the stack.
runDesync (desync { do! during (fun _ -> true) (returnM ()) }) () () |> ignore

回答1:

It seems to me your monad is a State with error handling.

It's basically ErrorT< State<'s,Either<'e,'a>>>, but the error branch binds again, and it's not very clear to me why.

Anyway I was able to reproduce your Stack Overflow with a basic State monad:

/// A basic state monad: wraps a function from an input state 'S to a result 'A paired with
/// the next state.
type State<'S,'A> = State of ('S->('A * 'S))

module State =
    /// Unwraps a State, exposing its underlying state-transition function.
    let run (State f) : 's -> _ = f

    /// Reads the current state, returning it as the result and leaving it unchanged.
    let get() : State<'s,_> = State (fun s -> (s, s))

    /// Replaces the current state with 'x', producing unit.
    let put x : State<'s,_> = State (fun _ -> ((), x))

    /// Monadic return: produces 'a' without touching the state.
    let result a = State (fun s -> (a, s))

    /// Monadic bind: runs 'm', feeds its result to 'k', then runs the State that 'k' returns
    /// against the state left behind by 'm'.
    let bind (State m) k : State<'s,'b> =
        State (fun s0 ->
            let (a, s1) = m s0
            match k a with
            | State next -> next s1)

    /// Computation expression builder for the State monad.
    type StateBuilder() =
        member self.Return op = result op
        member self.Bind (m, cont) = bind m cont

    /// The State builder instance, used as state { ... }.
    let state = StateBuilder()

    /// Loops while 'pred' holds for the loop variable and the current state. 'next' advances
    /// the loop variable after each iteration and 'm' builds the body for a given loop variable.
    let rec loop (i: 'i) (next: 'i -> 'i) (pred: 'i -> 's -> bool) (m: 'i -> State<'s, unit>) =
        state {
            let! s = get()
            let iteration =
                if pred i s then
                    state {
                        do! m i
                        do! loop (next i) next pred m }
                else
                    result ()
            do! iteration }

    /// Repeats 'm' for as long as 'pred' holds for the current state.
    let during (pred : 's -> bool) (m : State<'s, unit>) =
        loop () id (fun _ -> pred) (fun _ -> m)

// test: an unbounded loop in the plain State monad
open State
run (state { do! during (fun _ -> true) (result ()) }) () |> ignore // boom

As stated in the comments, one way to solve this is to use a StateT<'s,Cont<'r,'a>>.

Here's an example of the solution. At the end there is a test with the zipIndex function which blows the stack as well when defined with a normal State monad.

Note that you don't need to use the Monad Transformers from FsControl; I use them because it's easier for me since I write less code, but you can always create your transformed monad by hand.