PySpark replace Null with Array

Posted 2019-05-11 00:25

Question:

After a join by ID, my data frame looks as follows:

ID  |  Features  |  Vector
1   | (50,[...]  | Array[1.1,2.3,...]
2   | (50,[...]  | Null

I ended up with Null values for some IDs in the column 'Vector'. I would like to replace these Null values by an array of zeros with 300 dimensions (same format as non-null vector entries). df.fillna does not work here since it's an array I would like to insert. Any idea how to accomplish this in PySpark?

---edit---

Similar to this post, here is my current approach:

import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DoubleType

df_joined = id_feat_vec.join(new_vec_df, "id", how="left_outer")

fill_with_vector = udf(lambda x: x if x is not None else np.zeros(300),
                       ArrayType(DoubleType()))

df_new = df_joined.withColumn("vector", fill_with_vector("vector"))

Unfortunately, with little success:

org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 848.0 failed 4 times, most recent failure: Lost task 0.3 in stage 848.0 (TID 692199, 10.179.224.107, executor 16): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.core.multiarray._reconstruct)
---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-193-e55fed27fcd8> in <module>()
      5 a = df_joined.withColumn("vector", fill_with_vector("vector"))
      6 
----> 7 a.show()

/databricks/spark/python/pyspark/sql/dataframe.pyc in show(self, n, truncate)
    316         """
    317         if isinstance(truncate, bool) and truncate:
--> 318             print(self._jdf.showString(n, 20))
    319         else:
    320             print(self._jdf.showString(n, int(truncate)))

/databricks/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1131         answer = self.gateway_client.send_command(command)
   1132         return_value = get_return_value(
-> 1133             answer, self.gateway_client, self.target_id, self.name)
   1134 
   1135         for temp_arg in temp_args:
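
For context, this particular PickleException is usually a sign that the UDF returned a NumPy object; a UDF declared as ArrayType(DoubleType()) needs to return plain Python floats in a plain Python list. A minimal sketch of the same UDF with only that change, keeping the names and the 300-element size from above:

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DoubleType

# Return a plain Python list of floats instead of np.zeros(300),
# so the result can be serialized into an ArrayType(DoubleType()) column.
fill_with_vector = udf(lambda x: x if x is not None else [0.0] * 300,
                       ArrayType(DoubleType()))

df_new = df_joined.withColumn("vector", fill_with_vector("vector"))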

Answer 1:

Update: I couldn't get the SQL expression form to create an array of doubles; 'array(0.0, ...)' appears to create an array of Decimal types. Using the Python functions, however, you can get it to create an array of doubles properly.

The general idea is to use the when/otherwise functions to selectively update only the rows you want. You can define the literal value you want ahead of time as a column and then drop it into the "THEN" slot of the when clause.

from pyspark.sql.types import *
from pyspark.sql.functions import *

schema = StructType([StructField("f1", LongType()), StructField("f2", ArrayType(DoubleType(), False))])
data = [(1, [10.0, 11.0]), (2, None), (3, None)]

df = sqlContext.createDataFrame(sc.parallelize(data), schema)

# Create a column object storing the value you want in the NULL case
num_elements = 300
null_value = array([lit(0.0)] * num_elements)

# If you want a different type you can change it like this
# null_value = null_value.cast('array<float>')

# Keep the value when there is one, replace it when it's null
df2 = df.withColumn('f2', when(df['f2'].isNull(), null_value).otherwise(df['f2']))
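
As a quick sanity check on the toy frame above (size and coalesce both come from pyspark.sql.functions, already pulled in by the wildcard import), something like this should report a 300-element array for the formerly-null rows:

# Rows 2 and 3 should now report 300 elements
df2.select('f1', size('f2').alias('n_elements')).show()

# An equivalent one-liner using the same null_value column would be:
# df2 = df.withColumn('f2', coalesce(df['f2'], null_value))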


Answer 2:

You could try to run an update on your dataset with a where clause, replacing every NULL in the Vector column with an array. Are you using Spark SQL and DataFrames? See the sketch below.
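
Since DataFrames are immutable there is no in-place UPDATE, but the same idea can be written as a SELECT with coalesce. A rough sketch of that Spark SQL route, assuming Spark 2.x (where the spark session object exists) and the lowercase column names id, features and vector from the question's code:

# Register the joined DataFrame so it can be queried with SQL
df_joined.createOrReplaceTempView("joined")

# Build a 300-element array literal of doubles (the 'D' suffix forces DOUBLE, not DECIMAL)
zero_array = "array({})".format(", ".join(["0.0D"] * 300))

df_filled = spark.sql(
    "SELECT id, features, coalesce(vector, {}) AS vector FROM joined".format(zero_array)
)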