Merge sort from “Programming Scala” causes stack overflow

Posted 2019-01-13 16:47

A direct cut and paste of the following algorithm:

def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs))
  }
}

causes a StackOverflowError on lists of 5000 elements.

Is there any way to optimize this so that this doesn't occur?

3 answers
虎瘦雄心在
#2 · 2019-01-13 17:19

Just playing around with Scala's TailCalls (trampolining support), which I suspect wasn't around when this question was originally posed. Here's a recursive, immutable version of the merge in Rex's answer.

import scala.util.control.TailCalls._

def merge[T <% Ordered[T]](x:List[T],y:List[T]):List[T] = {

  def build(s:List[T],a:List[T],b:List[T]):TailRec[List[T]] = {
    if (a.isEmpty) {
      done(b.reverse ::: s)
    } else if (b.isEmpty) {
      done(a.reverse ::: s)
    } else if (a.head<b.head) {
      tailcall(build(a.head::s,a.tail,b))
    } else {
      tailcall(build(b.head::s,a,b.tail))
    }
  }

  build(List(),x,y).result.reverse
}

It runs just as fast as the mutable version on big List[Long]s on Scala 2.9.1 on 64-bit OpenJDK (Debian/Squeeze amd64 on an i7).
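For completeness, here is a rough sketch of how a trampolined merge like the one above might be plugged into a full merge sort. The object name `TrampolinedMergeSort` is made up for illustration, and the implicit `Ordering[T]` is an assumption that replaces the `<%` view bound, which newer compilers deprecate.

```scala
import scala.util.control.TailCalls._

// Hypothetical wrapper object, used only to keep the sketch self-contained.
object TrampolinedMergeSort {

  // Merges two sorted lists. Each step returns a TailRec value instead of
  // recursing on the JVM stack; `result` then runs the steps in a loop.
  def merge[T](x: List[T], y: List[T])(implicit ord: Ordering[T]): List[T] = {
    // `s` accumulates the merged elements in reverse (largest first).
    def build(s: List[T], a: List[T], b: List[T]): TailRec[List[T]] =
      (a, b) match {
        case (Nil, _) => done(b.reverse ::: s)
        case (_, Nil) => done(a.reverse ::: s)
        case (ah :: at, bh :: _) if ord.lt(ah, bh) =>
          tailcall(build(ah :: s, at, b))
        case (_, bh :: bt) =>
          tailcall(build(bh :: s, a, bt))
      }
    build(Nil, x, y).result.reverse
  }

  // The outer recursion is only about log2(n) deep, so it is left as is.
  def msort[T](xs: List[T])(implicit ord: Ordering[T]): List[T] = {
    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      merge(msort(ys), msort(zs))
    }
  }
}
```

Only merge needs the trampoline, since its recursion depth grows with the list length while msort's depth grows with the logarithm of it.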

小情绪 Triste *
#3 · 2019-01-13 17:38

In case Daniel's solutions didn't make it clear enough: the problem is that merge recurses as deeply as the list is long, and because the recursive call is not in tail position, the compiler can't convert it into iteration.

Scala can convert Daniel's tail-recursive merge solution into something approximately equivalent to this:

// T and less come from the enclosing msort.
def merge(xs: List[T], ys: List[T]): List[T] = {
  var acc:List[T] = Nil
  var decx = xs
  var decy = ys
  while (!decx.isEmpty || !decy.isEmpty) {
    (decx, decy) match { 
      case (Nil, _) => { acc = decy.reverse ::: acc ; decy = Nil }
      case (_, Nil) => { acc = decx.reverse ::: acc ; decx = Nil }
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) { acc = x :: acc ; decx = xs1 }
        else { acc = y :: acc ; decy = ys1 }
    }
  }
  acc.reverse
}

but it keeps track of all the variables for you.

(A tail-recursive method is one that calls itself only as its final action and returns the result of that call unchanged; it never calls itself and then does something with the result before returning. Also, the compiler can't apply this optimization to a method that might be overridden, so it generally works only for methods in objects, methods marked final or private, and local functions.)
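To make the distinction concrete, here is a small sketch that has nothing to do with sorting. The names `TailRecDemo`, `sumNaive`, and `sum` are invented for this example. The `@tailrec` annotation from the standard library asks the compiler to reject the method if it can't turn the recursion into a loop.

```scala
import scala.annotation.tailrec

// Hypothetical demo object; the names are illustrative only.
object TailRecDemo {

  // Not tail-recursive: the addition happens after the recursive call
  // returns, so every element needs its own stack frame.
  def sumNaive(xs: List[Int]): Long = xs match {
    case Nil    => 0L
    case h :: t => h + sumNaive(t)
  }

  // Tail-recursive: the recursive call is the last action, so the compiler
  // rewrites it as a loop. The running total travels in `acc`.
  def sum(xs: List[Int]): Long = {
    @tailrec
    def loop(rest: List[Int], acc: Long): Long = rest match {
      case Nil    => acc
      case h :: t => loop(t, acc + h)
    }
    loop(xs, 0L)
  }
}
```

Calling `sum` on a list of a million elements is fine, whereas `sumNaive` is likely to overflow on the same input with a default-sized stack.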

Luminary・发光体
#4 · 2019-01-13 17:40

It overflows because merge isn't tail-recursive. You can fix this either by using a non-strict collection or by making merge tail-recursive.

The latter solution goes like this:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] = 
    (xs, ys) match { 
      case (Nil, _) => ys.reverse ::: acc 
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) merge(xs1, ys, x :: acc) 
        else merge(xs, ys1, y :: acc) 
    } 
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  } 
} 
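As a sanity check, the same idea can be written with `@tailrec` on the inner merge, so the compiler confirms that the recursion really becomes a loop. This is only a sketch; the wrapper object `TailRecMergeSort` is a made-up name so the snippet stands on its own.

```scala
import scala.annotation.tailrec

// Hypothetical wrapper object, used only to keep the sketch self-contained.
object TailRecMergeSort {

  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // `acc` holds the merged prefix in reverse order. @tailrec makes the
    // compiler fail the build if merge can't be compiled into a loop.
    @tailrec
    def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
      (xs, ys) match {
        case (Nil, _) => ys.reverse ::: acc
        case (_, Nil) => xs.reverse ::: acc
        case (x :: xs1, y :: ys1) =>
          if (less(x, y)) merge(xs1, ys, x :: acc)
          else merge(xs, ys1, y :: acc)
      }

    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      // merge returns its result reversed, hence the final reverse.
      merge(msort(less)(ys), msort(less)(zs), Nil).reverse
    }
  }
}
```

A call such as `TailRecMergeSort.msort((a: Int, b: Int) => a < b)(bigList)` should then cope with lists far longer than 5000 elements.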

Using non-strictness involves either passing parameters by-name, or using non-strict collections such as Stream. The following code uses Stream just to prevent stack overflow, and List elsewhere:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs)).toList
  } 
}
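Stream has been deprecated since Scala 2.13 in favour of LazyList, so on a newer compiler the non-strict variant might look roughly like the sketch below. It assumes Scala 2.13 or later, and the object name `LazyMergeSort` is invented for illustration.

```scala
// Hypothetical wrapper object; assumes Scala 2.13+ where LazyList exists.
object LazyMergeSort {

  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // The tail of `#::` is evaluated lazily, so merge returns after one
    // step instead of recursing to the end of the lists.
    def merge(left: List[T], right: List[T]): LazyList[T] =
      (left, right) match {
        case (x :: xs1, y :: _) if less(x, y) => x #:: merge(xs1, right)
        case (_ :: _, y :: ys1)               => y #:: merge(left, ys1)
        case _ =>
          if (left.isEmpty) LazyList.from(right) else LazyList.from(left)
      }

    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      // toList walks the lazy list iteratively, forcing one cell at a time.
      merge(msort(less)(ys), msort(less)(zs)).toList
    }
  }
}
```

The structure is the same as in the Stream version; only the collection type and the cons operator change.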