Scala Pattern Matching with Sets

Posted 2020-08-09 11:45

Question:

The following doesn't work.

object Foo {
    def union(s: Set[Int], t: Set[Int]): Set[Int] = t match {
        case isEmpty => s
        case (x:xs)  => union(s + x, xs)
        case _       => throw new Error("bad input")
    }
}

error: not found: type xs

How can I pattern match over a set?

Answer 1:

Well, x:xs means x of type xs, so it wouldn't work. But, alas, you can't pattern match sets, because sets do not have a defined order. Or, more pragmatically, because there's no extractor on Set.

You can always define your own, though:

object SetExtractor {
  // Exposes the Set's elements as a Seq, so Seq-style vararg patterns apply.
  def unapplySeq[T](s: Set[T]): Option[Seq[T]] = Some(s.toSeq)
}

For example:

scala> Set(1, 2, 3) match {
     |   case SetExtractor(x, xs @ _*) => println(s"x: $x\nxs: $xs")
     | }
x: 1
xs: ArrayBuffer(2, 3)
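Applied to the union function from the question, the same extractor could be used roughly like this. This is a sketch, not part of the original answer: the empty pattern SetExtractor() matches when the Set has no elements, and the vararg pattern peels off one element per call. Which element comes first is unspecified, since a Set has no defined order, but that does not affect a union.

```scala
// Same extractor as above: exposes a Set's elements as a Seq
// (in whatever order the Set happens to iterate them).
object SetExtractor {
  def unapplySeq[T](s: Set[T]): Option[Seq[T]] = Some(s.toSeq)
}

object Foo {
  def union(s: Set[Int], t: Set[Int]): Set[Int] = t match {
    // No elements left in t: everything has already been added to s.
    case SetExtractor()           => s
    // Take one element x, add it to s, and recurse on the remaining elements.
    case SetExtractor(x, xs @ _*) => union(s + x, xs.toSet)
  }
}
```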


Answer 2:

Set is not a case class and doesn't have an unapply method.

Together, these mean that you cannot pattern match directly on a Set.
(Update: unless you define your own extractor for Set, as Daniel correctly shows in his answer.)

You should find an alternative; I'd suggest using a fold:

def union(s: Set[Int], t: Set[Int]): Set[Int] = 
    (s foldLeft t) {case (acc: Set[Int], x: Int) => acc + x}

or, avoiding most explicit type annotations:

def union(s: Set[Int], t: Set[Int]): Set[Int] =
  (s foldLeft t)( (union, element) => union + element )

or even shorter

def union(s: Set[Int], t: Set[Int]): Set[Int] =
  (s foldLeft t)(_ + _)

This accumulates the elements of s into t, adding them one at a time.
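A minimal sketch that spells out the accumulator with a named parameter and, for comparison, the standard library's own Set.union (also available as ++ or |). The object name FoldUnion is illustrative only. Because adding to a Set gives the same result regardless of order, the fold is safe even though a Set's iteration order is unspecified.

```scala
object FoldUnion {
  // Start from t and add each element of s to the accumulator in turn.
  def union(s: Set[Int], t: Set[Int]): Set[Int] =
    s.foldLeft(t)((acc, x) => acc + x)

  // The standard library already provides this operation directly.
  def builtIn(s: Set[Int], t: Set[Int]): Set[Int] = s union t
}
```

In practice, s union t (or s ++ t) is the idiomatic choice; the fold is mainly useful to show how the accumulation works.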



Here are the docs for the fold operation, if needed for reference:

foldLeft[B](z: B)(op: (B, A) ⇒ B): B

Applies a binary operator to a start value and all elements of this set, going left to right.

Note: might return different results for different runs, unless the underlying collection type is ordered or the operator is associative and commutative.

B: the result type of the binary operator.
z: the start value.
op: the binary operator.
returns: the result of inserting op between consecutive elements of this set, going left to right with the start value z on the left:

op(...op(z, x_1), x_2, ..., x_n)
where x_1, ..., x_n are the elements of this set.
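To make the note about ordering concrete, here is a small sketch (names are illustrative). String concatenation is associative but not commutative, so its fold result depends on the visiting order; on a List that order is fixed, while on a Set only the characters, not their order, are guaranteed. Integer addition is associative and commutative, so folding it over a Set is always deterministic.

```scala
object FoldOrder {
  // On a List the order is fixed: (("" + "a") + "b") + "c".
  val onList: String = List("a", "b", "c").foldLeft("")(_ + _)

  // On a Set the visiting order is unspecified, so the same three
  // letters are concatenated, but their order is not guaranteed.
  val onSet: String = Set("a", "b", "c").foldLeft("")(_ + _)

  // Addition is associative and commutative: any visiting order gives 6.
  val summed: Int = Set(1, 2, 3).foldLeft(0)(_ + _)
}
```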


Answer 3:

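First of all, your isEmpty will match every Set, because in this context it is a variable. A simple identifier in a pattern is treated as a constant only if it starts with an uppercase letter (or is enclosed in backticks); a lowercase name is a variable pattern that binds whatever value it is matched against. So the lowercase isEmpty binds any Set, and the cases after it are never reached. (Were you looking for EmptySet?)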
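A short sketch of the difference. The uppercase constant Empty is a hypothetical name introduced only for this illustration; writing a lowercase name in backticks in the pattern would have the same effect.

```scala
object PatternPitfall {
  // A lowercase name in a pattern is a variable pattern: it matches
  // ANY value and binds it, just like isEmpty in the question.
  def lowercase(t: Set[Int]): String = t match {
    case isEmpty => s"matched $isEmpty"
  }

  // An uppercase stable identifier is compared with the scrutinee using ==.
  val Empty: Set[Int] = Set.empty

  def uppercase(t: Set[Int]): String = t match {
    case Empty => "empty"
    case _     => "non-empty"
  }
}
```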

As the other answers show, pattern matching isn't well suited to Sets. You should probably convert the Set explicitly to a List or Seq first (toList / toSeq):

object Foo {
    // Convert to a List, which has the Nil and :: extractors.
    def union(s: Set[Int], t: Set[Int]): Set[Int] = t.toList match {
        case Nil     => s
        case x :: xs => union(s + x, xs.toSet)
        // No default case: a List is always either Nil or x :: xs.
    }
}
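A possible refinement, sketched under the assumption that you only need the result: recurse over the List itself with a local helper instead of converting the tail back to a Set on every call. The names ListUnion and loop are illustrative; @tailrec simply asks the compiler to confirm the recursion is a tail call.

```scala
import scala.annotation.tailrec

object ListUnion {
  def union(s: Set[Int], t: Set[Int]): Set[Int] = {
    // Walk the List directly rather than rebuilding a Set from the
    // tail on every call (the xs.toSet step in the version above).
    @tailrec
    def loop(acc: Set[Int], rest: List[Int]): Set[Int] = rest match {
      case Nil     => acc
      case x :: xs => loop(acc + x, xs)
    }
    loop(s, t.toList)
  }
}
```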


Answer 4:

This is what I can come up with:

object Contains {
  class Unapplier[T](val t: T) {
    def unapply(s: Set[T]): Option[Boolean] = Some(s contains t)
  }
  def apply[T](t: T) = new Unapplier(t)
}

object SET {
  class Unapplier[T](val set: Set[T]) {
    // A Boolean unapply suits a pattern with no sub-patterns, e.g. SET123().
    def unapply(s: Set[T]): Boolean = set == s
  }
  def apply[T](ts: T*) = new Unapplier(ts.toSet)
}

val Contains2 = Contains(2)
val SET123 = SET(1, 2, 3)

Set(1, 2, 3) match {
  case SET123()         => println("123")
  case Contains2(true)  => println("jippy")
  case Contains2(false) => println("ohh noo")
}


Answer 5:

    t match {
      case s if s.nonEmpty => // non-empty 
      case _ => // empty
    }
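Applied to the union function from the question, the guard approach might look like the following sketch (the object name GuardUnion is illustrative). Since a Set has no defined order, head and tail simply pick some element and the remaining elements, which is all a union needs.

```scala
object GuardUnion {
  def union(s: Set[Int], t: Set[Int]): Set[Int] = t match {
    // Guard on the bound value: only non-empty sets take this branch.
    case rest if rest.nonEmpty => union(s + rest.head, rest.tail)
    // Otherwise t is empty, so s already holds every element.
    case _                     => s
  }
}
```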