How to model recursive function types?

Posted 2020-07-17 06:29

Question:

I'm curious how this code would be modelled in Scala.

This is Go code, and it defines a recursive function type:

type walkFn func(*int) walkFn

So the above is just the definition of the type: a walk function is a function that takes a pointer to an integer and returns a walk function.

An example implementation would be:

func walkForward(i *int) walkFn {
    *i += rand.Intn(6)
    return pickRandom(walkEqual, walkBackward)
}

func walkBackward(i *int) walkFn {
    *i += -rand.Intn(6)
    return pickRandom(walkEqual, walkForward)
}

You can run code like this here: http://play.golang.org/p/621lCnySmy

Is it possible to write something like this pattern in Scala?

Answer 1:

It's possible. You can use existential types to "cheat" Scala's cyclic reference restriction:

import scala.util.Random

type A[T <: A[_]] = Int => (Int, T)

lazy val walkEqual: A[A[_]] = (i: Int) => 
  (i + Random.nextInt(7) - 3, if (Random.nextBoolean) walkForward else walkBackward)

lazy val walkForward: A[A[_]] = (i: Int) =>  
  (i + Random.nextInt(6), if (Random.nextBoolean) walkEqual else walkBackward)

lazy val walkBackward: A[A[_]] = (i: Int) => 
  (i - Random.nextInt(6), if (Random.nextBoolean) walkEqual else walkForward)

def doWalk(count: Int, walkFn: A[_] = walkEqual, progress: Int = 0): Unit =
  if (count > 0) {
    val (nextProgress, nextStep: A[_] @unchecked) = walkFn(progress)
    println(nextProgress)
    doWalk(count - 1, nextStep, nextProgress)
  }

Result:

scala> doWalk(10)
2
5
2
0
-3
-5
-4
-8
-8
-11

Or, following @Travis Brown's addition:

val locations = Stream.iterate[(Int,A[_] @unchecked)](walkEqual(0)) {
   case (x: Int, f: A[_]) => f(x)
}.map(_._1)

scala> locations.take(20).toList
res151: List[Int] = List(-1, 1, 1, 4, 1, -2, 0, 1, 0, 1, 4, -1, -2, -4, -2, -1, 2, 1, -1, -2)


Answer 2:

One of the sad facts of Scala is that recursive type aliases are not supported. Trying to define one in the REPL yields this error:

scala> type WalkFn = Function[Int, WalkFn]
<console>:7: error: illegal cyclic reference involving type WalkFn
       type WalkFn = Function[Int, WalkFn]
                                   ^
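The error is specific to type aliases: a class is allowed to mention itself in its own definition. As a minimal sketch (not part of the original answer; the names Walk, step, Walks, forward, backward and run are hypothetical), a one-field case class can therefore stand in for Go's walkFn:

```scala
import scala.util.Random

// Unlike a type alias, a case class may refer to itself, so this
// plays the role of Go's `type walkFn func(*int) walkFn`.
final case class Walk(step: Int => (Int, Walk))

object Walks {
  // Each walker moves the position and hands back the other walker.
  // The references inside the lambdas are only followed when a step runs.
  lazy val forward: Walk = Walk(i => (i + Random.nextInt(6), backward))
  lazy val backward: Walk = Walk(i => (i - Random.nextInt(6), forward))

  // Takes n steps from `start`, returning every position visited.
  def run(n: Int, start: Int = 0, walk: Walk = forward): List[Int] =
    if (n <= 0) Nil
    else {
      val (next, nextWalk) = walk.step(start)
      next :: run(n - 1, next, nextWalk)
    }
}
```

For example, Walks.run(10) returns ten positions, alternating between a forward step and a backward step.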

Another note is that Scala has no equivalent of Go's *int: you cannot pass an Int by reference and modify it in place (something generally frowned upon, nay, loathed entirely, in the functional-programming paradigm).

However, don't despair! There are other options. Traits can be self-referential, and Scala function types such as Int => T are just ordinary traits (Function1[Int, T]) that a trait or object can extend. So we can model the recursive WalkFn with a trait. We can also embrace immutable values and have our function return the next progress, rather than mutate a parameter by reference.

Since the following code contains cyclic references (WalkForward -> WalkBackward, WalkBackward -> WalkForward, etc.), you'll need to type :paste into the Scala REPL before running it, so that the compiler compiles all three Walk{Forward,Backward,Equal} implementations in one step.

First:

$ scala
Welcome to Scala version 2.11.1 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_05).
Type in expressions to have them evaluated.
Type :help for more information.

scala> :paste
// Entering paste mode (ctrl-D to finish)

Now, paste the code:

import scala.util.Random

object Helpers {
  def pickRandom[A](items: A*) =
    items(Random.nextInt(items.length))
}

// A trait may refer to itself, unlike a type alias
trait WalkFn extends (Int => (Int, WalkFn))

object WalkForward extends WalkFn {
  def apply(i: Int): (Int, WalkFn) =
    ( i + Random.nextInt(6),
      Helpers.pickRandom(WalkEqual, WalkBackward) )
}

object WalkEqual extends WalkFn {
  def apply(i: Int): (Int, WalkFn) =
    ( i + (Random.nextInt(7) - 3),
      Helpers.pickRandom(WalkForward, WalkBackward) )
}

object WalkBackward extends WalkFn {
  def apply(i: Int): (Int, WalkFn) =
    // step back by 0 to 5 from the current position
    ( i - Random.nextInt(6),
      Helpers.pickRandom(WalkEqual, WalkForward) )
}

def doWalk(count: Int, walkFn: WalkFn = WalkEqual, progress: Int = 0): Unit =
  if (count > 0) {
    val (nextProgress, nextStep) = walkFn(progress)
    println(nextProgress)
    doWalk(count - 1, nextStep, nextProgress)
  }

doWalk(20)

Then, per instructions, hit ctrl-D.

Enjoy the functional drunken stagger!
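If you really want the Go-style in-place update rather than returning the next progress, one option is to pass a small mutable cell in place of the pointer. This is an illustrative sketch, not part of the original answer; IntBox, MutWalk, Forward, Backward and MutDemo are hypothetical names:

```scala
import scala.util.Random

// A mutable cell standing in for Go's *int (hypothetical helper).
final class IntBox(var value: Int)

// Self-referential trait mirroring `type walkFn func(*int) walkFn`.
trait MutWalk extends (IntBox => MutWalk)

object Forward extends MutWalk {
  def apply(box: IntBox): MutWalk = {
    box.value += Random.nextInt(6) // mutate in place, like *i += ...
    Backward
  }
}

object Backward extends MutWalk {
  def apply(box: IntBox): MutWalk = {
    box.value -= Random.nextInt(6)
    Forward
  }
}

object MutDemo {
  // Runs `steps` steps, mutating `box`, and returns the next walker to use.
  def walk(box: IntBox, steps: Int, start: MutWalk = Forward): MutWalk = {
    var fn = start
    for (_ <- 1 to steps) fn = fn(box)
    fn
  }
}
```

The cost is the one the answer hints at: anything holding the same IntBox sees every change, which the immutable version above avoids.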



Answer 3:

I'd say it's more idiomatic in Scala to factor out the iteration. For example, we can define a state machine:

import scala.util.Random

sealed trait Walker {
  def i: Int
  def advance: Walker
}

case class WalkEqual(i: Int) extends Walker {
  def advance = {
    val next = i + Random.nextInt(7) - 3
    if (Random.nextBoolean) WalkForward(next) else WalkBackward(next)
  }
}

case class WalkForward(i: Int) extends Walker {
  def advance = {
    val next = i + Random.nextInt(6)
    if (Random.nextBoolean) WalkEqual(next) else WalkBackward(next)
  }
}

case class WalkBackward(i: Int) extends Walker {
  def advance = {
    val next = i - Random.nextInt(6)
    if (Random.nextBoolean) WalkEqual(next) else WalkForward(next)
  }
}

And then we can write the following:

val locations = Stream.iterate[Walker](WalkEqual(0))(_.advance).map(_.i)

This is an infinite stream of locations that our walker visits. We can use it like this:

scala> locations.take(10).foreach(println)
0
0
-1
2
1
0
-5
-5
-10
-6

We could also take a finite number of these and collect them in a concretely-realized collection (such as a list) by writing locations.take(100).toList.
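Note that Stream has been deprecated since Scala 2.13 in favour of LazyList. Below is a sketch of the same state machine for Scala 2.13+; it is an adaptation rather than the answer's own code, and the rng parameter and the RandomWalk.locations(seed) helper are assumptions added so that a run can be reproduced from a fixed seed:

```scala
import scala.util.Random

sealed trait Walker {
  def i: Int
  def advance(rng: Random): Walker
}

final case class WalkEqual(i: Int) extends Walker {
  def advance(rng: Random): Walker = {
    val next = i + rng.nextInt(7) - 3
    if (rng.nextBoolean()) WalkForward(next) else WalkBackward(next)
  }
}

final case class WalkForward(i: Int) extends Walker {
  def advance(rng: Random): Walker = {
    val next = i + rng.nextInt(6)
    if (rng.nextBoolean()) WalkEqual(next) else WalkBackward(next)
  }
}

final case class WalkBackward(i: Int) extends Walker {
  def advance(rng: Random): Walker = {
    val next = i - rng.nextInt(6)
    if (rng.nextBoolean()) WalkEqual(next) else WalkForward(next)
  }
}

object RandomWalk {
  // Infinite, memoised sequence of positions; elements are computed in
  // order, so a fixed seed gives the same walk on every call.
  def locations(seed: Long): LazyList[Int] = {
    val rng = new Random(seed)
    LazyList.iterate[Walker](WalkEqual(0))(_.advance(rng)).map(_.i)
  }
}
```

RandomWalk.locations(42).take(10).toList then returns the same ten positions each time it is called with the same seed, which makes the walk easier to test.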



Tags: scala