What are Scala continuations and why use them?

Posted 2020-01-27 08:58

I just finished Programming in Scala, and I've been looking into the changes between Scala 2.7 and 2.8. The one that seems to be the most important is the continuations plugin, but I don't understand what it's useful for or how it works. I've seen that it's good for asynchronous I/O, but I haven't been able to find out why. Some of the more popular resources on the subject are these:

And this question on Stack Overflow:

Unfortunately, none of these references try to define what continuations are for or what the shift/reset functions are supposed to do, and I haven't found any references that do. I haven't been able to guess how any of the examples in the linked articles work (or what they do), so one way to help me out could be to go line-by-line through one of those samples. Even this simple one from the third article:

reset {
    ...
    shift { k: (Int=>Int) =>  // The continuation k will be the '_ + 1' below.
        k(7)
    } + 1
}
// Result: 8

Why is the result 8? That would probably help me to get started.

7 answers
The star
Answer 2 · 2020-01-27 10:00

Scala Continuations via Meaningful Examples

Let us define from0to10 that expresses the idea of iteration from 0 to 10:

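import scala.util.continuations._ // requires the continuations compiler plugin (-P:continuations:enable)

def from0to10() = shift { (cont: Int => Unit) =>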
   for ( i <- 0 to 10 ) {
     cont(i)
   }
}

Now,

reset {
  val x = from0to10()
  print(s"$x ")
}
println()

prints:

0 1 2 3 4 5 6 7 8 9 10 

In fact, we do not need x:

reset {
  print(s"${from0to10()} ")
}
println()

prints the same result.

And

reset {
  print(s"(${from0to10()},${from0to10()}) ")
}
println()

prints all pairs:

(0,0) (0,1) (0,2) (0,3) (0,4) (0,5) (0,6) (0,7) (0,8) (0,9) (0,10) (1,0) (1,1) (1,2) (1,3) (1,4) (1,5) (1,6) (1,7) (1,8) (1,9) (1,10) (2,0) (2,1) (2,2) (2,3) (2,4) (2,5) (2,6) (2,7) (2,8) (2,9) (2,10) (3,0) (3,1) (3,2) (3,3) (3,4) (3,5) (3,6) (3,7) (3,8) (3,9) (3,10) (4,0) (4,1) (4,2) (4,3) (4,4) (4,5) (4,6) (4,7) (4,8) (4,9) (4,10) (5,0) (5,1) (5,2) (5,3) (5,4) (5,5) (5,6) (5,7) (5,8) (5,9) (5,10) (6,0) (6,1) (6,2) (6,3) (6,4) (6,5) (6,6) (6,7) (6,8) (6,9) (6,10) (7,0) (7,1) (7,2) (7,3) (7,4) (7,5) (7,6) (7,7) (7,8) (7,9) (7,10) (8,0) (8,1) (8,2) (8,3) (8,4) (8,5) (8,6) (8,7) (8,8) (8,9) (8,10) (9,0) (9,1) (9,2) (9,3) (9,4) (9,5) (9,6) (9,7) (9,8) (9,9) (9,10) (10,0) (10,1) (10,2) (10,3) (10,4) (10,5) (10,6) (10,7) (10,8) (10,9) (10,10) 

Now, how does that work?

There is the called code, from0to10, and the calling code. In this case, the calling code is the block that follows reset. One of the parameters passed to the called code is a return address that shows what part of the calling code has not yet been executed (**). That part of the calling code is the continuation. The called code can do whatever it decides with that parameter: pass control to it, ignore it, or call it multiple times. Here from0to10 calls that continuation for each integer in the range 0..10.

def from0to10() = shift { (cont: Int => Unit) =>
   for ( i <- 0 to 10 ) {
     cont(i) // call the continuation
   }
}
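To see what happens to the reset block, it can help to write the continuation by hand. The sketch below is a hand-written continuation-passing (CPS) version in plain Scala, not the plugin's actual output: from0to10 takes the continuation as an ordinary parameter, and the rest of each reset block becomes a lambda. The helpers singles and pairs are hypothetical names for this illustration, and the output is collected into a StringBuilder instead of being printed directly.

```scala
// Hand-written CPS sketch (no plugin): the continuation is an explicit parameter.
def from0to10(cont: Int => Unit): Unit =
  for (i <- 0 to 10) cont(i) // call the continuation once per integer

// reset { val x = from0to10(); print(s"$x ") } corresponds to:
def singles(): String = {
  val out = new StringBuilder
  // the lambda is "the rest of the reset block" after from0to10()
  from0to10 { x => out.append(s"$x "); () }
  out.toString
}

// reset { print(s"(${from0to10()},${from0to10()}) ") } corresponds to
// two nested continuations: the second call lives inside the first one's.
def pairs(): String = {
  val out = new StringBuilder
  from0to10 { a => from0to10 { b => out.append(s"($a,$b) "); () } }
  out.toString
}

println(singles()) // prints: 0 1 2 3 4 5 6 7 8 9 10
```

The nesting in pairs is why every value of the first from0to10 is combined with every value of the second: the second call is part of the first call's continuation.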

But where does the continuation end? This is important because the last return from the continuation returns control to the called code, from0to10. In Scala, it ends where the reset block ends (*).

Now, we see that the continuation is declared as cont: Int => Unit. Why? We invoke from0to10 as val x = from0to10(), and Int is the type of value that goes to x. Unit means that the block after reset must return no value (otherwise there will be a type error). In general, there are 4 type signatures: function input, continuation input, continuation result, function result. All four must match the invocation context.
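The types can be made concrete with a small typed model. The plugin declares shift[A, B, C](fun: (A => B) => C): A @cpsParam[B, C] and reset[A, C](ctx: => (A @cpsParam[A, C])): C. The sketch below replaces the annotation with a hypothetical case class Cps[A, B, C] (A is the value handed to the continuation, B the continuation's result, C the overall result) so it runs without the plugin. It also answers the question's example: the continuation k is `_ + 1`, so k(7) is 8.

```scala
// Hypothetical stand-in for the plugin's `A @cpsParam[B, C]`:
// a computation that feeds an A to a continuation A => B and produces a C.
final case class Cps[A, B, C](run: (A => B) => C) {
  // Apply a pure step (like `_ + 1`) to the value before the continuation sees it.
  def map[D](f: A => D): Cps[D, B, C] = Cps((k: D => B) => run(a => k(f(a))))
}

// shift hands the captured continuation k to `body`; body's result becomes reset's result.
def shift[A, B, C](body: (A => B) => C): Cps[A, B, C] = Cps(body)

// reset closes the block: the rest of the block returns A, so B must equal A.
def reset[A, C](block: Cps[A, A, C]): C = block.run(a => a)

// reset { shift { k: (Int => Int) => k(7) } + 1 }
val eight: Int = reset(shift((k: Int => Int) => k(7)).map(_ + 1))
// calling the continuation twice: k(k(7)) = k(8)
val nine: Int = reset(shift((k: Int => Int) => k(k(7))).map(_ + 1))
// ignoring the continuation: 42 becomes the value of the whole reset
val ignored: Int = reset(shift((k: Int => Int) => 42).map(_ + 1))
// the overall result type C may differ from the continuation result type B
val text: String = reset(shift((k: Int => Int) => s"result: ${k(7)}").map(_ + 1))
```

The last line shows why all the types must line up with the invocation context: the `+ 1` fixes A and B as Int, while the body of shift alone decides C.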

Above, we printed pairs of values. Let us print the multiplication table. But how do we output \n after each row?

The function back lets us specify what must be done when control returns back, from the continuation to the code that called it.

def back(action: => Unit) = shift { (cont: Unit => Unit) =>
  cont()
  action
}

back first calls its continuation, and then performs the action.

reset {
  val i = from0to10()
  back { println() }
  val j = from0to10()
  print(f"${i*j}%4d ") // printf-like formatted i*j
}

It prints:

   0    0    0    0    0    0    0    0    0    0    0 
   0    1    2    3    4    5    6    7    8    9   10 
   0    2    4    6    8   10   12   14   16   18   20 
   0    3    6    9   12   15   18   21   24   27   30 
   0    4    8   12   16   20   24   28   32   36   40 
   0    5   10   15   20   25   30   35   40   45   50 
   0    6   12   18   24   30   36   42   48   54   60 
   0    7   14   21   28   35   42   49   56   63   70 
   0    8   16   24   32   40   48   56   64   72   80 
   0    9   18   27   36   45   54   63   72   81   90 
   0   10   20   30   40   50   60   70   80   90  100 

Well, now it's time for some brain-twisters. There are two invocations of from0to10. What is the continuation for the first from0to10? It follows the invocation of from0to10 in the compiled code, but in the source code it also includes the assignment statement val i =. It ends where the reset block ends, but the end of the reset block does not return control to the first from0to10. The end of the reset block returns control to the second from0to10, which in turn eventually returns control to back, and it is back that returns control to the first invocation of from0to10. When the first (yes! the first!) from0to10 exits, the whole reset block is exited.

Such a method of returning control is called backtracking; it is a very old technique, known at least since the days of Prolog and AI-oriented Lisp derivatives.
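The backtracking control flow of the multiplication table can be written out as nested continuations. This is a sketch in plain Scala without the plugin: back takes its continuation as a second parameter list, and table is a hypothetical name that returns the text instead of printing it. The nesting mirrors the explanation above: the first from0to10's continuation contains back, and back's continuation contains the second from0to10.

```scala
// Hand-written CPS sketch of the multiplication-table reset block (no plugin).
def from0to10(cont: Int => Unit): Unit =
  for (i <- 0 to 10) cont(i)

// back: first run the rest of the block, then perform the action.
def back(action: => Unit)(cont: Unit => Unit): Unit = {
  cont(())
  action
}

def table(): String = {
  val out = new StringBuilder
  from0to10 { i =>                          // continuation of the 1st from0to10...
    back { out.append('\n'); () } { _ =>    // ...contains back, whose continuation...
      from0to10 { j =>                      // ...contains the 2nd from0to10
        out.append(f"${i * j}%4d "); ()
      }
    }                                       // back appends '\n' after the whole row
  }
  out.toString
}
```

When the innermost loop finishes, control returns to back, which writes the newline, and only then to the first from0to10, which moves on to the next row.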

The names reset and shift are misnomers; they would have been better left for bitwise operations. reset defines the continuation boundary, and shift takes the continuation from the call stack.

Note(s)

(*) In Scala, the continuation ends where the reset block ends. Another possible approach would be to let it end where the function ends.

(**) One of the parameters of the called code is a return address that shows what part of the calling code has not yet been executed. Well, in Scala, a sequence of return addresses is used for that. How many? All of the return addresses placed on the call stack since entering the reset block.


UPD Part 2. Discarding Continuations: Filtering

def onEven(x:Int) = shift { (cont: Unit => Unit) =>
  if ((x&1)==0) {
    cont() // call continuation only for even numbers
  }
}
reset {
  back { println() }
  val x = from0to10()
  onEven(x)
  print(s"$x ")
}

This prints:

0 2 4 6 8 10 
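In hand-written CPS, filtering is simply not calling the continuation. A minimal plain-Scala sketch of the onEven example, assuming the same explicit-continuation encoding as before (from0to10, back and onEven take their continuation as a parameter; evens is a hypothetical name):

```scala
def from0to10(cont: Int => Unit): Unit =
  for (i <- 0 to 10) cont(i)

def back(action: => Unit)(cont: Unit => Unit): Unit = { cont(()); action }

// For odd x the continuation is dropped, so control goes straight back
// to from0to10, which tries the next value.
def onEven(x: Int)(cont: Unit => Unit): Unit =
  if ((x & 1) == 0) cont(())

def evens(): String = {
  val out = new StringBuilder
  back { out.append('\n'); () } { _ =>
    from0to10 { x =>
      onEven(x) { _ => out.append(s"$x "); () }
    }
  }
  out.toString
}
```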

Let us factor out two important operations: discarding the continuation (fail()) and passing control on to it (succ()):

// fail: just discard the continuation, force control to return back
def fail() = shift { (cont: Unit => Unit) => }
// succ: does nothing (well, passes control to the continuation), but has a funny signature
def succ():Unit @cpsParam[Unit,Unit] = { }
// def succ() = shift { (cont: Unit => Unit) => cont() }

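Both versions of succ() above work. shift has a funny signature: its result type carries a CPS annotation (here Unit @cpsParam[Unit,Unit]), so although succ() does nothing, it must carry the same annotation for the types to balance, for example between the two branches of the if below.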

reset {
  back { println() }
  val x = from0to10()
  if ((x&1)==0) {
    succ()
  } else {
    fail()
  }
  print(s"$x ")
}

As expected, it prints

0 2 4 6 8 10
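In the same hand-written encoding, succ's funny signature becomes ordinary: fail and succ both take the continuation and differ only in whether they call it. This is a sketch without the plugin; filtered and rest are hypothetical names.

```scala
def from0to10(cont: Int => Unit): Unit =
  for (i <- 0 to 10) cont(i)

// fail: discard the continuation, so control returns to from0to10
def fail(cont: Unit => Unit): Unit = ()
// succ: pass control on to the continuation
def succ(cont: Unit => Unit): Unit = cont(())

def filtered(): String = {
  val out = new StringBuilder
  from0to10 { x =>
    // the rest of the reset block after the if/else
    val rest: Unit => Unit = _ => { out.append(s"$x "); () }
    if ((x & 1) == 0) succ(rest) else fail(rest)
  }
  out.toString
}
```

Both branches have the same type, (Unit => Unit) => Unit, which is the explicit counterpart of the matching CPS annotations the plugin requires.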

Within a function, succ() is not necessary:

def onTrue(b:Boolean) = {
  if(!b) {
    fail()
  }
}
reset {
  back { println() }
  val x = from0to10()
  onTrue ((x&1)==0)
  print(s"$x ")
}

Again, it prints

0 2 4 6 8 10

Now, let us define onOdd() via onEven():

// negation: the hard way
class ControlTransferException extends Exception {}
def onOdd(x:Int) = shift { (cont: Unit => Unit) =>
  try {
    reset {
      onEven(x)
      throw new ControlTransferException() // return is not allowed here
    }
    cont()
  } catch {
    case e: ControlTransferException =>
    case t: Throwable => throw t
  }
}
reset {
  back { println() }
  val x = from0to10()
  onOdd(x)
  print(s"$x ")
}

Above, if x is even, an exception is thrown and the continuation is not called; if x is odd, the exception is not thrown and the continuation is called. The above code prints:

1 3 5 7 9 
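With explicit continuations, onOdd can be derived from onEven without exceptions: run onEven with a probe continuation that only records whether it was called, which plays the role of the inner reset plus ControlTransferException. This plain-Scala sketch assumes onEven calls its continuation at most once; the flag evenWouldContinue and the helper odds are hypothetical.

```scala
def from0to10(cont: Int => Unit): Unit =
  for (i <- 0 to 10) cont(i)

def onEven(x: Int)(cont: Unit => Unit): Unit =
  if ((x & 1) == 0) cont(())

// Negation: call the real continuation only if onEven would NOT have called its own.
def onOdd(x: Int)(cont: Unit => Unit): Unit = {
  var evenWouldContinue = false
  onEven(x) { _ => evenWouldContinue = true } // probe: record the call, don't continue
  if (!evenWouldContinue) cont(())
}

def odds(): String = {
  val out = new StringBuilder
  from0to10 { x => onOdd(x) { _ => out.append(s"$x "); () } }
  out.toString
}
```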