Newbie Scala question about simple math array operations

Posted 2020-06-13 05:25

Question:

Newbie Scala Question:

Say I want to do this [Java code] in Scala:

public static double[] abs(double[] r, double[] im) {
  double t[] = new double[r.length];
  for (int i = 0; i < t.length; ++i) {
    t[i] = Math.sqrt(r[i] * r[i] + im[i] * im[i]);
  }
  return t;
}  

and also make it generic (since I have read that Scala handles generic primitives efficiently). Relying only on the core language (no library objects/classes, methods, etc.), how would one do this? Truthfully, I don't see how to do it at all, so I guess that's just a pure bonus-point question.

I ran into sooo many problems trying to do this simple thing that I have given up on Scala for the moment. Hopefully once I see the Scala way I will have an 'aha' moment.

UPDATE: After discussing this with others, the best answer I have found so far is:

def abs[T](r: Iterable[T], im: Iterable[T])(implicit n: Numeric[T]) = {
   import n.mkNumericOps                                                   
   r zip(im) map(t => math.sqrt((t._1 * t._1 + t._2 * t._2).toDouble))          
}
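For comparison, here is a sketch of the same function written with `scala.math.Numeric.Implicits._`, which adds infix operators (`+`, `*`) and `toDouble` to any `T` that has a `Numeric[T]` in scope. This lets the explicit `n` parameter become a context bound. The name `abs` is kept from the UPDATE; I have only reasoned about this for Scala 2.12/2.13.

```scala
import scala.math.Numeric.Implicits._

// Same idea as the UPDATE above: the sum of squares is computed in T,
// then converted to Double so that math.sqrt can be applied.
def abs[T: Numeric](r: Iterable[T], im: Iterable[T]): Iterable[Double] =
  r.zip(im).map { case (a, b) => math.sqrt((a * a + b * b).toDouble) }
```

Like the original, `zip` stops at the shorter of the two inputs.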

Answer 1:

Writing generic yet performant code over primitives in Scala involves two related mechanisms that Scala uses to avoid boxing/unboxing (e.g. wrapping an int in a java.lang.Integer and vice versa):

  • @specialized type annotations
  • Using Manifest with arrays

@specialized is an annotation that tells the Scala compiler to create "primitive" versions of code (akin to C++ templates, so I am told). Check out the type declaration of Tuple2 (which is specialized) compared with List (which isn't). It was added in 2.8 and means that, for example, code like CC[Int].map(f : Int => Int) is executed without ever boxing any ints (assuming CC is specialized, of course!).
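A minimal sketch of what specializing your own class looks like. `Box` is a hypothetical class invented for illustration. Under Scala 2, `@specialized(Int, Double)` asks scalac to also emit Int- and Double-specific variants that store the field as a raw primitive. As far as I know, Scala 3 accepts the annotation on user classes but does not perform this specialization.

```scala
// Hypothetical container. For T = Int or T = Double, Scala 2 generates
// variants in which `value` is an unboxed int/double field.
class Box[@specialized(Int, Double) T](val value: T) {
  // Specializing R as well means Box[Int].map(Int => Int) can avoid boxing,
  // because Function1 itself is specialized for Int and Double.
  def map[@specialized(Int, Double) R](f: T => R): Box[R] = new Box[R](f(value))
}
```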

Manifests are a way of getting reified types in Scala, working around the JVM's type erasure. This is particularly useful when you want a method that is generic in some type T and that creates an array of T (i.e. T[]) inside the method. In Java this is not possible, because new T[] is illegal. In Scala it is possible using Manifests. In this case in particular, it allows us to construct a primitive T-array, like double[] or int[]. (This is awesome, in case you were wondering.)
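A small sketch of generic array creation. It uses `ClassTag`, which has been the recommended context bound for this since Scala 2.10 (a `Manifest` is also a `ClassTag`, so `T: Manifest` would satisfy the same requirement). The method name `filled` is hypothetical.

```scala
import scala.reflect.ClassTag

// The ClassTag carries T's runtime class, so `new Array[T](n)` allocates a
// real primitive array (int[], double[], ...) when T is a primitive type.
// Without the context bound, `new Array[T](n)` does not compile.
def filled[T: ClassTag](n: Int, x: T): Array[T] = {
  val arr = new Array[T](n)
  var i = 0
  while (i < n) {
    arr(i) = x
    i += 1
  }
  arr
}
```

Calling `filled(2, 7)` should give back an `Array[Int]`, i.e. an unboxed JVM `int[]`.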

Boxing matters from a performance perspective because it creates garbage, unless all of your ints fall within the range -128 to 127 that java.lang.Integer caches. It also, obviously, adds a level of indirection in the form of extra processing steps, method calls, etc. But consider that you probably don't give a hoot unless you are absolutely positively sure that you definitely do (i.e. most code does not need such micro-optimization).
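To make the caching point concrete: Scala boxes an `Int` via `java.lang.Integer.valueOf`, which by default returns shared cached instances only for -128..127. A quick check, assuming a JVM started without `-XX:AutoBoxCacheMax` (which can enlarge the cache):

```scala
// 127 is inside the default Integer cache: both calls return the same object.
val small1 = java.lang.Integer.valueOf(127)
val small2 = java.lang.Integer.valueOf(127)

// 1000 is outside the default cache: each call allocates a new Integer.
val big1 = java.lang.Integer.valueOf(1000)
val big2 = java.lang.Integer.valueOf(1000)

println(small1 eq small2) // true
println(big1 eq big2)     // false on a default JVM configuration
```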


So, back to the question: to do this with no boxing/unboxing, you must use Array (List is not specialized yet, and would be more object-hungry anyway, even if it were!). The zipped method on a pair of collections lets you map a two-argument function over corresponding elements without first building an intermediate collection of Tuple2s.

To do this generically (i.e. across various numeric types), you need context bounds on your type parameter requiring that it is Numeric and that a Manifest can be found (required for array creation). So I started along the lines of...
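For readers new to context bounds: `[T : Numeric : Manifest]` is shorthand for an extra implicit parameter list. The sketch below shows both spellings side by side. It substitutes `ClassTag` for `Manifest` (my substitution; a `ClassTag` is all that array creation needs), and the names `squares1`/`squares2` are hypothetical.

```scala
import scala.reflect.ClassTag

// Context-bound form: the evidence values are anonymous, so Numeric[T]
// is fetched with implicitly. The ClassTag lets xs.map build an Array[T].
def squares1[T: Numeric: ClassTag](xs: Array[T]): Array[T] = {
  val num = implicitly[Numeric[T]]
  xs.map(x => num.times(x, x))
}

// Roughly what the compiler turns the above into: named implicit parameters.
def squares2[T](xs: Array[T])(implicit num: Numeric[T], ct: ClassTag[T]): Array[T] =
  xs.map(x => num.times(x, x))
```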

def abs[T : Numeric : Manifest](rs : Array[T], ims : Array[T]) : Array[T] = {
    import math._
    val num = implicitly[Numeric[T]]
    (rs, ims).zipped.map { (r, i) => sqrt(num.plus(num.times(r,r), num.times(i,i))) }
    //                               ^^^^ no SQRT function for Numeric
}

...but it doesn't quite work. The reason is that a "generic" Numeric value does not have an operation like sqrt, so you could only do this at the point where you know you have a Double. For example:
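One way around the missing sqrt, sketched under the assumption that a `Double` result is acceptable: keep the sum of squares generic in `T`, then convert each value with `Numeric#toDouble` before taking the square root. The result type becomes `Array[Double]` instead of `Array[T]`, so no `Manifest`/`ClassTag` is needed. The name `absViaDouble` is hypothetical.

```scala
// Generic part: r*r + i*i is computed in T using Numeric's operations.
// Only the final sqrt needs a Double, so each sum is converted first.
def absViaDouble[T](rs: Array[T], ims: Array[T])(implicit num: Numeric[T]): Array[Double] = {
  val n = math.min(rs.length, ims.length) // guard against unequal lengths
  val out = new Array[Double](n)
  var i = 0
  while (i < n) {
    val sq = num.plus(num.times(rs(i), rs(i)), num.times(ims(i), ims(i)))
    out(i) = math.sqrt(num.toDouble(sq))
    i += 1
  }
  out
}
```

This works for `Int`, `Double`, `BigDecimal` and any other type with a `Numeric` instance, at the cost of losing precision beyond `Double` for types like `BigDecimal`.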

scala> def almostAbs[T : Manifest : Numeric](rs : Array[T], ims : Array[T]) : Array[T] = {
 | import math._
 | val num = implicitly[Numeric[T]]
 | (rs, ims).zipped.map { (r, i) => num.plus(num.times(r,r), num.times(i,i)) }
 | }
almostAbs: [T](rs: Array[T],ims: Array[T])(implicit evidence$1: Manifest[T],implicit evidence$2: Numeric[T])Array[T]

Excellent - now see this purely generic method do some stuff!

scala> val rs = Array(1.2, 3.4, 5.6); val is = Array(6.5, 4.3, 2.1)
rs: Array[Double] = Array(1.2, 3.4, 5.6)
is: Array[Double] = Array(6.5, 4.3, 2.1)

scala> almostAbs(rs, is)
res0: Array[Double] = Array(43.69, 30.049999999999997, 35.769999999999996)

Now we can sqrt the result, because we have an Array[Double]:

scala> res0.map(math.sqrt(_))
res1: Array[Double] = Array(6.609841147864296, 5.481788029466298, 5.980802621722272)

And to prove that this would work even with another Numeric type:

scala> import math._
import math._
scala> val rs = Array(BigDecimal(1.2), BigDecimal(3.4), BigDecimal(5.6)); val is = Array(BigDecimal(6.5), BigDecimal(4.3), BigDecimal(2.1))
rs: Array[scala.math.BigDecimal] = Array(1.2, 3.4, 5.6)
is: Array[scala.math.BigDecimal] = Array(6.5, 4.3, 2.1)

scala> almostAbs(rs, is)
res6: Array[scala.math.BigDecimal] = Array(43.69, 30.05, 35.77)

scala> res6.map(d => math.sqrt(d.toDouble))
res7: Array[Double] = Array(6.609841147864296, 5.481788029466299, 5.9808026217222725)


Answer 2:

Use zip and map:

scala> val reals = List(1.0, 2.0, 3.0)
reals: List[Double] = List(1.0, 2.0, 3.0)

scala> val imags = List(1.5, 2.5, 3.5)
imags: List[Double] = List(1.5, 2.5, 3.5)

scala> reals zip imags
res0: List[(Double, Double)] = List((1.0,1.5), (2.0,2.5), (3.0,3.5))

scala> (reals zip imags).map {z => math.sqrt(z._1*z._1 + z._2*z._2)}
res2: List[Double] = List(1.8027756377319946, 3.2015621187164243, 4.6097722286464435)

scala> def abs(reals: List[Double], imags: List[Double]): List[Double] =
     | (reals zip imags).map {z => math.sqrt(z._1*z._1 + z._2*z._2)}
abs: (reals: List[Double],imags: List[Double])List[Double]

scala> abs(reals, imags)
res3: List[Double] = List(1.8027756377319946, 3.2015621187164243, 4.6097722286464435)

UPDATE

It is better to use zipped because it avoids creating a temporary collection:

scala> def abs(reals: List[Double], imags: List[Double]): List[Double] =
     | (reals, imags).zipped.map {(x, y) => math.sqrt(x*x + y*y)}
abs: (reals: List[Double],imags: List[Double])List[Double]

scala> abs(reals, imags)
res7: List[Double] = List(1.8027756377319946, 3.2015621187164243, 4.6097722286464435)
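Note for newer Scala versions: in Scala 2.13, `zipped` is deprecated in favour of `lazyZip`, which likewise avoids building the temporary list of tuples. A sketch of the same function, assuming Scala 2.13 or later:

```scala
// lazyZip pairs the two lists lazily; map then builds the List[Double]
// result directly, without an intermediate List[(Double, Double)].
def abs(reals: List[Double], imags: List[Double]): List[Double] =
  reals.lazyZip(imags).map((x, y) => math.sqrt(x * x + y * y))
```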


Answer 3:

There isn't an easy way in Java to create generic numeric computational code; the libraries aren't there, as you can see from oxbow's answer. Collections are also designed to hold arbitrary types, which means there is overhead when using them with primitives. So the fastest code (without careful bounds checking) is either:

def abs(re: Array[Double], im: Array[Double]) = {
  val a = new Array[Double](re.length)
  var i = 0
  while (i < a.length) {
    a(i) = math.sqrt(re(i)*re(i) + im(i)*im(i))
    i += 1
  }
  a
}

or, tail-recursively:

import scala.annotation.tailrec

def abs(re: Array[Double], im: Array[Double]) = {
  // @tailrec makes the compiler check that recurse really compiles to a loop
  @tailrec
  def recurse(a: Array[Double], i: Int = 0): Array[Double] = {
    if (i < a.length) {
      a(i) = math.sqrt(re(i)*re(i) + im(i)*im(i))
      recurse(a, i+1)
    }
    else a
  }
  recurse(new Array[Double](re.length))
}

So, unfortunately, this code ends up not looking super-nice; the niceness comes once you package it in a handy complex number array library.
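As a rough idea of what such packaging could look like, here is a sketch of a hypothetical `ComplexArray` class (not an existing library). It keeps real and imaginary parts in two primitive `Array[Double]`s, so nothing is boxed, checks the lengths once up front, and hides the while loop behind a method.

```scala
// Hypothetical wrapper: two parallel primitive arrays, no per-element objects.
class ComplexArray(val re: Array[Double], val im: Array[Double]) {
  // The length check the loops above skip is done once, at construction.
  require(re.length == im.length, "re and im must have the same length")

  def length: Int = re.length

  // Magnitudes, computed with the same fast while loop as above.
  def abs: Array[Double] = {
    val out = new Array[Double](length)
    var i = 0
    while (i < out.length) {
      out(i) = math.sqrt(re(i) * re(i) + im(i) * im(i))
      i += 1
    }
    out
  }
}
```

Callers then just write `new ComplexArray(re, im).abs`, and the loop stays in one place.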

If it turns out that you don't actually need highly efficient code, then

def abs(re: Array[Double], im: Array[Double]) = {
  (re,im).zipped.map((i,j) => math.sqrt(i*i + j*j))
}

will do the trick compactly and conceptually clearly (once you understand how zipped works). The penalty in my hands is that this is about 2x slower. (Using List makes it 7x slower than while or tail recursion in my hands; List with zip makes it 20x slower; generics with arrays are 3x slower even without computing the square root.)

(Edit: fixed timings to reflect a more typical use case.)



Answer 4:

After Edit:

OK, I have got what I wanted running. It takes two Lists of any numeric type and returns an Array of Doubles.

def abs[A](r: List[A], im: List[A])(implicit numeric: Numeric[A]): Array[Double] = {
  val t = new Array[Double](r.length)
  // Note: r(i) and im(i) walk the List from the head each time,
  // so this loop is O(n^2) overall for long lists.
  for (i <- r.indices) {
    val x = numeric.toDouble(r(i))
    val y = numeric.toDouble(im(i))
    t(i) = math.sqrt(x * x + y * y)
  }
  t
}
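Because indexing a `List` is linear, a single-pass variant may scale better. A sketch, with the hypothetical name `absZipped`, that walks both lists once via iterators (stopping at the shorter one) and still accepts any `Numeric` type:

```scala
// Both lists are traversed once, in step; no r(i) indexing.
def absZipped[A](r: List[A], im: List[A])(implicit numeric: Numeric[A]): Array[Double] =
  r.iterator.zip(im.iterator).map { case (a, b) =>
    val x = numeric.toDouble(a)
    val y = numeric.toDouble(b)
    math.sqrt(x * x + y * y)
  }.toArray
```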


Tags: scala