How to update pyspark dataframe metadata on Spark

Posted 2019-08-18 21:15

Question:

I'm facing an issue with SparkML's OneHotEncoder: it reads the dataframe's metadata in order to determine the value range it should assign for the sparse vector object it's creating.

More specifically, I'm encoding an "hour" field using a training set containing all individual values between 0 and 23.

Now I'm scoring a single-row dataframe using the "transform" method of the Pipeline.

Unfortunately, this leads to a differently encoded sparse vector object from the OneHotEncoder:

(24,[5],[1.0]) vs. (11,[10],[1.0])
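For reference, here is a stripped-down sketch that reproduces this kind of mismatch (the column and variable names are illustrative, and it assumes the Spark 2.x behaviour where OneHotEncoder is a plain Transformer that, finding no "ml_attr" metadata on the input column, infers the number of categories from the largest value it sees):

from pyspark.ml.feature import OneHotEncoder

encoder = OneHotEncoder(inputCol="hour", outputCol="hour_vec", dropLast=False)

# Training-style frame containing every hour 0..23 -> size-24 vectors such as (24,[5],[1.0])
train = spark.createDataFrame([(h,) for h in range(24)], ["hour"])
encoder.transform(train).show(3, truncate=False)

# Single-row frame without metadata: the encoder only sees max(hour) + 1 = 11 categories
# and produces (11,[10],[1.0]) instead
single = spark.createDataFrame([(10,)], ["hour"])
encoder.transform(single).show(truncate=False)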

I've documented this here, but it was identified as a duplicate. In that thread a solution is posted for updating the dataframe's metadata to reflect the real range of the "hour" field:

from pyspark.sql.functions import col

meta = {"ml_attr": {
    "vals": [str(x) for x in range(6)],   # Provide a set of levels
    "type": "nominal", 
    "name": "class"}}

# "loaded" here is the fitted pipeline model loaded for scoring
loaded.transform(
    df.withColumn("class", col("class").alias("class", metadata=meta)))

Unfortunately, I get this error:

TypeError: alias() got an unexpected keyword argument 'metadata'

Answer 1:

In PySpark 2.1 the alias method has no metadata argument (docs); this only became available in Spark 2.2. Nevertheless, it is still possible to modify column metadata in PySpark < 2.2, thanks to the incredible Spark Gotchas, maintained by @eliasah and @zero323:

import json

from pyspark import SparkContext
from pyspark.sql import Column
from pyspark.sql.functions import col

spark.version
# u'2.1.1'

df = sc.parallelize((
        (0, "x", 2.0),
        (1, "y", 3.0),
        (2, "x", -1.0)
        )).toDF(["label", "x1", "x2"])

df.show()
# +-----+---+----+ 
# |label| x1|  x2|
# +-----+---+----+
# |    0|  x| 2.0|
# |    1|  y| 3.0|
# |    2|  x|-1.0|
# +-----+---+----+

Suppose we want to declare that our label values can range from 0 to 5, even though in our dataframe they only range from 0 to 2. Here is how we should modify the column metadata:

def withMeta(self, alias, meta):
    # Build a JVM Metadata object from the Python dict and call the JVM
    # Column.as(alias, metadata) overload, which PySpark 2.1 does not expose
    sc = SparkContext._active_spark_context
    jmeta = sc._gateway.jvm.org.apache.spark.sql.types.Metadata
    return Column(getattr(self._jc, "as")(alias, jmeta.fromJson(json.dumps(meta))))

# Monkey-patch the helper onto Column so it can be used like a built-in method
Column.withMeta = withMeta

# new metadata:
meta = {"ml_attr": {"name": "label_with_meta",
                    "type": "nominal",
                    "vals": [str(x) for x in range(6)]}}

df_with_meta = df.withColumn("label_with_meta", col("label").withMeta("", meta))
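A quick way to check that the trick worked, still on the toy dataframe above (a sketch assuming the Spark 2.x OneHotEncoder Transformer; dropLast=False keeps all declared levels): the new column now carries the six declared levels in its metadata, so the encoder should emit size-6 vectors regardless of which labels actually occur in the data.

from pyspark.ml.feature import OneHotEncoder

# The metadata is now attached to the new column's schema
print(df_with_meta.select("label_with_meta").schema.fields[0].metadata)
# {'ml_attr': {'name': 'label_with_meta', 'type': 'nominal', 'vals': ['0', '1', '2', '3', '4', '5']}}

# OneHotEncoder sizes its output from the six declared levels, not from the
# three label values present in this particular dataframe
encoder = OneHotEncoder(inputCol="label_with_meta", outputCol="label_vec", dropLast=False)
encoder.transform(df_with_meta).select("label_with_meta", "label_vec").show()
# Every row should now carry a size-6 sparse vector, e.g. (6,[2],[1.0]) for label 2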

Kudos also to this answer by zero323!
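For completeness: from Spark 2.2 onwards the monkey-patch is unnecessary, because Column.alias accepts a metadata keyword argument, so the syntax from the question works directly:

# Spark >= 2.2 only: attach the same metadata via the alias keyword argument
df_with_meta = df.withColumn(
    "label_with_meta",
    col("label").alias("label_with_meta", metadata=meta))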