A direct cut and paste of the following algorithm:
    def msort[T](less: (T, T) => Boolean)
                (xs: List[T]): List[T] = {
      def merge(xs: List[T], ys: List[T]): List[T] =
        (xs, ys) match {
          case (Nil, _) => ys
          case (_, Nil) => xs
          case (x :: xs1, y :: ys1) =>
            if (less(x, y)) x :: merge(xs1, ys)
            else y :: merge(xs, ys1)
        }
      val n = xs.length / 2
      if (n == 0) xs
      else {
        val (ys, zs) = xs splitAt n
        merge(msort(less)(ys), msort(less)(zs))
      }
    }
causes a StackOverflowError on lists of 5000 elements.
Is there any way to optimize this so that this doesn't occur?
Just playing around with Scala's TailCalls (trampolining support), which I suspect wasn't around when this question was originally posed. Here's a recursive immutable version of the merge in Rex's answer.

It runs just as fast as the mutable version on big List[Long]s on Scala 2.9.1 on 64-bit OpenJDK (Debian/Squeeze amd64 on an i7).

Just in case Daniel's solutions didn't make it clear enough, the problem is that merge's recursion is as deep as the length of the list, and it's not tail recursion, so it can't be converted into iteration.
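For reference, a trampolined merge along the lines of the TailCalls approach mentioned above might look like the following. It is a sketch, not the code from that answer: the accumulator parameter acc and the wrapper object name are assumptions made for the example, and only tailcall, done, and result from scala.util.control.TailCalls are used.

```scala
import scala.util.control.TailCalls._

object TrampolinedMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // Each step returns a TailRec value instead of calling itself directly;
    // `result` then runs those steps in a loop, so the stack stays shallow.
    // `acc` (hypothetical name) holds the merged prefix in reverse order.
    def merge(xs: List[T], ys: List[T], acc: List[T]): TailRec[List[T]] =
      (xs, ys) match {
        case (Nil, _) => done(acc.reverse ::: ys)
        case (_, Nil) => done(acc.reverse ::: xs)
        case (x :: xs1, y :: ys1) =>
          if (less(x, y)) tailcall(merge(xs1, ys, x :: acc))
          else tailcall(merge(xs, ys1, y :: acc))
      }

    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      merge(msort(less)(ys), msort(less)(zs), Nil).result
    }
  }
}
```

The recursion in msort itself is only about log2(n) deep, so only merge needs the trampoline.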
Scala can convert Daniel's tail-recursive merge solution into something approximately equivalent to a while loop over mutable variables, but it keeps track of all the variables for you.
(A tail-recursive method is one where the method only calls itself to get a complete answer to pass back; it never calls itself and then does something with the result before passing it back. Also, the tail-call optimization can't be applied if the method might be overridden, so it generally only works for methods in objects, for private or final methods, and for local methods.)
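As a rough illustration of the loop form described above, merge could be written by hand with mutable local variables. This is a sketch rather than the compiler's actual output; the names LoopMerge, left, right, and acc are made up for the example.

```scala
object LoopMerge {
  def merge[T](less: (T, T) => Boolean)(xs: List[T], ys: List[T]): List[T] = {
    var left = xs             // remaining elements of the first list
    var right = ys            // remaining elements of the second list
    var acc: List[T] = Nil    // merged output so far, in reverse order

    // One loop iteration replaces one recursive call, so the stack never grows.
    while (left.nonEmpty && right.nonEmpty) {
      if (less(left.head, right.head)) {
        acc = left.head :: acc
        left = left.tail
      } else {
        acc = right.head :: acc
        right = right.tail
      }
    }

    // At most one of the two lists still has elements; append it as-is.
    acc.reverse ::: (if (left.isEmpty) right else left)
  }
}
```

The three vars are exactly the state that a tail-recursive version would pass along as parameters.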
It overflows the stack because merge isn't tail-recursive. You can fix this either by using a non-strict collection or by making it tail-recursive.
The latter solution goes like this:
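A minimal sketch of a tail-recursive merge, assuming an extra accumulator parameter (acc, a name chosen for this example) that collects the output in reverse order; it may differ in detail from the code originally posted. The @tailrec annotation makes the compiler reject the method if the recursive calls are not in tail position.

```scala
import scala.annotation.tailrec

object TailRecMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // Nothing happens after the recursive call returns, so the compiler can
    // turn it into a loop. `acc` holds the merged prefix in reverse order.
    @tailrec
    def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
      (xs, ys) match {
        case (Nil, _) => acc.reverse ::: ys
        case (_, Nil) => acc.reverse ::: xs
        case (x :: xs1, y :: ys1) =>
          if (less(x, y)) merge(xs1, ys, x :: acc)
          else merge(xs, ys1, y :: acc)
      }

    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      merge(msort(less)(ys), msort(less)(zs), Nil)
    }
  }
}
```

For example, `TailRecMergeSort.msort[Int](_ < _)(List(3, 1, 2))` returns `List(1, 2, 3)`.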
Using non-strictness involves either passing parameters by-name, or using non-strict collections such as Stream. The following code uses Stream just to prevent stack overflow, and List elsewhere:
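A sketch of how such a Stream-based merge could look, not necessarily the code originally posted. It assumes Stream.cons, whose second argument is passed by-name, and the wrapper object name is made up for the example. Note that Stream is deprecated in favor of LazyList as of Scala 2.13, so newer compilers will emit deprecation warnings here.

```scala
object StreamMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // Stream.cons takes its tail by-name, so each call to merge returns after
    // producing a single cell; the nested merge call runs only when that
    // cell's tail is demanded, which keeps the stack depth constant.
    def merge(left: Stream[T], right: Stream[T]): Stream[T] =
      if (left.isEmpty) right
      else if (right.isEmpty) left
      else if (less(left.head, right.head))
        Stream.cons(left.head, merge(left.tail, right))
      else
        Stream.cons(right.head, merge(left, right.tail))

    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      // Streams are used only inside merge; the result goes back to a List.
      merge(msort(less)(ys).toStream, msort(less)(zs).toStream).toList
    }
  }
}
```

Converting back with toList walks the stream iteratively, forcing one merge step per element.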