Filter an array column based on a provided list

Published 2019-08-04 11:09

Question:

I have the following types in a dataframe:

 root
 |-- id: string (nullable = true)
 |-- items: array (nullable = true)
 |    |-- element: string (containsNull = true)

input:

val rawData = Seq(("id1",Array("item1","item2","item3","item4")),("id2",Array("item1","item2","item3")))
val data = spark.createDataFrame(rawData)

and a list of items:

 val filter_list = List("item1", "item2")

I would like to filter out items that are not in filter_list, similar to how array_contains works, but array_contains only accepts a single value, not a provided list of strings.

so the output would look like this:

val rawData = Seq(("id1",Array("item1","item2")),("id2",Array("item1","item2")))
val data = spark.createDataFrame(rawData)

I tried solving this with the following UDF, but I am probably mixing up types between Scala and Spark:

def filterItems(flist: List[String]) = udf {
  (recs: List[String]) => recs.filter(item => flist.contains(item))
}

I'm using Spark 2.2

thanks!

Answer 1:

Your code is almost right. All you have to do is replace List with Seq in the UDF's parameter type, since Spark deserializes array columns as Seq:

def filterItems(flist: List[String]) = udf {
  (recs: Seq[String]) => recs.filter(item => flist.contains(item))
}
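For completeness, a usage sketch might look like the following (column names `id`/`items` are assumed when renaming the tuple columns; this follows the question's setup):

```scala
import org.apache.spark.sql.functions.col

// Rename the tuple columns (_1, _2) to id/items, then apply the UDF,
// keeping only the items that appear in filter_list.
val result = data.toDF("id", "items")
  .withColumn("items", filterItems(filter_list)(col("items")))

result.show(false)
```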

It would also make sense to change the signature from List[String] => UserDefinedFunction to Seq[String] => UserDefinedFunction, but it is not required.

Reference: SQL Programming Guide - Data Types.