A direct cut and paste of the following algorithm:
def msort[T](less: (T, T) => Boolean)
    (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs))
  }
}
causes a StackOverflowError on lists of 5000 elements.
Is there any way to optimize this so that it doesn't happen?
This happens because merge isn't tail-recursive, so its recursion depth grows with the length of the lists being merged. You can fix it either by using a non-strict collection or by making merge tail-recursive.
The latter solution goes like this:
def msort[T](less: (T, T) => Boolean)
    (xs: List[T]): List[T] = {
  // acc collects the merged output in reverse order, so every recursive
  // call is in tail position and the compiler can turn merge into a loop.
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys.reverse ::: acc
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) merge(xs1, ys, x :: acc)
        else merge(xs, ys1, y :: acc)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    // Reverse once at the end to restore ascending order.
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  }
}
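Scala already compiles the self-calls in merge above into a loop, but nothing stops a later edit from silently breaking that. A sketch, assuming Scala 2.13 or 3: annotating merge with `@tailrec` (from `scala.annotation`) makes compilation fail unless every recursive call is in tail position. The object name `TailRecMergeSort` is only illustrative.

```scala
import scala.annotation.tailrec

object TailRecMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // @tailrec: compilation fails if any recursive call below is not in
    // tail position, so merge can never again need one frame per element.
    @tailrec
    def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
      (xs, ys) match {
        case (Nil, _) => ys.reverse ::: acc // acc holds the output reversed
        case (_, Nil) => xs.reverse ::: acc
        case (x :: xs1, y :: ys1) =>
          if (less(x, y)) merge(xs1, ys, x :: acc)
          else merge(xs, ys1, y :: acc)
      }

    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      // msort itself still recurses, but only about log2(n) levels deep.
      merge(msort(less)(ys), msort(less)(zs), Nil).reverse
    }
  }
}
```

For example, `TailRecMergeSort.msort[Int](_ < _)((100000 to 1 by -1).toList)` should return 1 to 100000 in order without overflowing the stack. The explicit `[Int]` is needed because `T` can't be inferred from the `less` lambda alone.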
Using non-strictness involves either passing parameters by name or using non-strict collections such as Stream. The following code uses Stream just to prevent stack overflow, and List elsewhere:
def msort[T](less: (T, T) => Boolean)
    (xs: List[T]): List[T] = {
  // Stream.cons takes its tail by name, so each recursive merge call is
  // deferred until toList forces it instead of piling up on the stack.
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs)).toList
  }
}
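Scala 2.13 deprecated Stream in favour of LazyList. A sketch of the same approach ported to LazyList, assuming Scala 2.13 or later: `#::` takes its right-hand side by name, so it plays the role of `Stream.cons` here, and `LazyList.from` replaces `toStream`. The object name `LazyListMergeSort` is illustrative.

```scala
object LazyListMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // The right operand of #:: is evaluated only when that tail is forced,
    // so toList walks the result one element at a time with constant stack.
    def merge(left: List[T], right: List[T]): LazyList[T] = (left, right) match {
      case (x :: xs, y :: _) if less(x, y) => x #:: merge(xs, right)
      case (_ :: _, y :: ys)               => y #:: merge(left, ys)
      case _ => LazyList.from(if (left.isEmpty) right else left)
    }

    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      merge(msort(less)(ys), msort(less)(zs)).toList
    }
  }
}
```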
Just playing around with Scala's TailCalls (trampolining support), which I suspect wasn't around when this question was originally posed. Here's a recursive immutable version of the merge in Rex's answer.
import scala.util.control.TailCalls._

def merge[T <% Ordered[T]](x: List[T], y: List[T]): List[T] = {
  def build(s: List[T], a: List[T], b: List[T]): TailRec[List[T]] = {
    if (a.isEmpty) {
      done(b.reverse ::: s)
    } else if (b.isEmpty) {
      done(a.reverse ::: s)
    } else if (a.head < b.head) {
      tailcall(build(a.head :: s, a.tail, b))
    } else {
      tailcall(build(b.head :: s, a, b.tail))
    }
  }
  build(List(), x, y).result.reverse
}
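The answer shows only the merge. A sketch of a complete sort built around it, assuming Scala 2.13 or later, where view bounds such as `T <% Ordered[T]` are deprecated: it takes an implicit `Ordering[T]` instead and compares with `ord.lt`. The `msort` wrapper and the object name `TrampolinedMergeSort` are illustrative additions, not part of the original answer.

```scala
import scala.util.control.TailCalls._

object TrampolinedMergeSort {
  // Same trampolined merge as above, with Ordering in place of the view bound.
  def merge[T](x: List[T], y: List[T])(implicit ord: Ordering[T]): List[T] = {
    def build(s: List[T], a: List[T], b: List[T]): TailRec[List[T]] =
      if (a.isEmpty) done(b.reverse ::: s)
      else if (b.isEmpty) done(a.reverse ::: s)
      else if (ord.lt(a.head, b.head)) tailcall(build(a.head :: s, a.tail, b))
      else tailcall(build(b.head :: s, a, b.tail))
    // result runs the trampoline in a loop, so the stack stays flat.
    build(Nil, x, y).result.reverse
  }

  // Hypothetical top-down merge sort wired around the merge above.
  def msort[T](xs: List[T])(implicit ord: Ordering[T]): List[T] = {
    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (left, right) = xs.splitAt(n)
      merge(msort(left), msort(right))
    }
  }
}
```

Because the ordering is an implicit parameter, any type with an `Ordering` instance works without a comparison function, e.g. `TrampolinedMergeSort.msort(List("pear", "apple", "fig"))`.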
Runs just as fast as the mutable version on big List[Long]s on Scala 2.9.1 on 64-bit OpenJDK (Debian/Squeeze amd64 on an i7).
In case Daniel's solutions didn't make it clear enough: the problem is that merge's recursion can be as deep as the combined length of the lists being merged, and because it isn't tail recursion, it can't be converted into iteration.
Scala can convert Daniel's tail-recursive merge solution into something approximately equivalent to this:
// Assumes it sits inside msort, where T and less are in scope.
def merge(xs: List[T], ys: List[T]): List[T] = {
  var acc: List[T] = Nil
  var decx = xs
  var decy = ys
  while (!decx.isEmpty || !decy.isEmpty) {
    (decx, decy) match {
      case (Nil, _) => { acc = decy.reverse ::: acc ; decy = Nil }
      case (_, Nil) => { acc = decx.reverse ::: acc ; decx = Nil }
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) { acc = x :: acc ; decx = xs1 }
        else { acc = y :: acc ; decy = ys1 }
    }
  }
  acc.reverse
}
but it keeps track of all the variables for you.
(A tail-recursive method is one where the method only calls itself to get a complete answer to pass back; it never calls itself and then does something with the result before passing it back. Also, Scala can't apply this optimization if the method might be overridden in a subclass, so it only works for methods that are final or private, local to another method, or defined in an object.)
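To make that definition concrete, here is a small illustration with hypothetical names (`TailRecDemo`, `Summer`), assuming Scala 2.13 or 3. `sumNaive` still has to add after its recursive call returns, so it is not tail-recursive; `sumAcc` carries a running total instead, and `@tailrec` asks the compiler to confirm the call becomes a jump. In the class, `sum` has to be `final` (or `private`): removing `final` turns the `@tailrec` annotation into a compile error, because a subclass could override the method.

```scala
import scala.annotation.tailrec

object TailRecDemo {
  // Not tail-recursive: "+" runs after the recursive call returns,
  // so every element costs one stack frame.
  def sumNaive(xs: List[Long]): Long = xs match {
    case Nil    => 0L
    case h :: t => h + sumNaive(t)
  }

  // Tail-recursive: the recursive call is the last thing evaluated, and a
  // method in an object can't be overridden, so it compiles to a loop.
  @tailrec
  def sumAcc(xs: List[Long], acc: Long = 0L): Long = xs match {
    case Nil    => acc
    case h :: t => sumAcc(t, acc + h)
  }
}

class Summer {
  // Without "final", @tailrec would be rejected: a subclass could override sum.
  @tailrec
  final def sum(xs: List[Long], acc: Long = 0L): Long = xs match {
    case Nil    => acc
    case h :: t => sum(t, acc + h)
  }
}
```

`TailRecDemo.sumAcc((1L to 1000000L).toList)` returns 500000500000, while `sumNaive` on the same list is likely to overflow a default-sized thread stack.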