PySpark: check if an element is in collect_list [d

Posted 2020-07-10 09:43


I am working on a dataframe df, for instance the following dataframe:


|keys|values|
|  aa| apple|
|  bb|orange|
|  bb|  desk|
|  bb|orange|
|  bb|  desk|
|  aa|   pen|
|  bb|pencil|
|  aa| chair|

I use collect_set to aggregate the values for each key into a set with duplicate elements eliminated (or collect_list to get a list that keeps the duplicates).

from pyspark.sql.functions import collect_set
df_new = df.groupby('keys').agg(collect_set(df.values).alias('collectedSet_values'))

The resulting dataframe is then as follows:


|keys|collectedSet_values   |
|bb  |[orange, pencil, desk]|
|aa  |[apple, pen, chair]   |
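For intuition only, here is a plain-Python sketch (not Spark code) of what the two aggregations compute on the example rows. The variable names are made up for illustration. Note that in Spark the order of elements inside the collected array is not guaranteed.

```python
from collections import defaultdict

# Same rows as the example dataframe: (keys, values).
rows = [
    ("aa", "apple"), ("bb", "orange"), ("bb", "desk"), ("bb", "orange"),
    ("bb", "desk"), ("aa", "pen"), ("bb", "pencil"), ("aa", "chair"),
]

collected_set = defaultdict(set)    # models collect_set: duplicates dropped
collected_list = defaultdict(list)  # models collect_list: duplicates kept

for key, value in rows:
    collected_set[key].add(value)
    collected_list[key].append(value)

# 'bb' has 5 rows but only 3 distinct values.
print(len(collected_list["bb"]), len(collected_set["bb"]))
```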

I am struggling to find a way to check whether a specific keyword (like 'chair') is in the resulting set of objects (in the column collectedSet_values). I would prefer not to use a UDF.

Please comment your solutions/ideas.

Kind Regards.


There is a built-in function, array_contains, that does exactly this. It works on the array column produced by collect_set the same way it works on any other array column. To check whether the word 'chair' exists in each set of objects, we can simply do the following:

from pyspark.sql.functions import array_contains
df_new.withColumn('contains_chair', array_contains(df_new.collectedSet_values, 'chair')).show()


|keys|collectedSet_values   |contains_chair|
|bb  |[orange, pencil, desk]|false         |
|aa  |[apple, pen, chair]   |true          |

The same applies to the result of collect_list.