Spark's .count() function is different to the contents of the dataframe when filtering on the corrupt record field

Posted 2019-02-24 14:45

Question:

I have a Spark job, written in Python, which is getting odd behaviour when checking for errors in its data. A simplified version is below:

from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructType, StructField, DoubleType
from pyspark.sql.functions import col, lit

spark = SparkSession.builder.master("local[3]").appName("pyspark-unittest").getOrCreate()
spark.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")


SCHEMA = StructType([
    StructField("headerDouble", DoubleType(), False),
    StructField("ErrorField", StringType(), False)
])

# Read the CSV permissively, routing any corrupt records into the ErrorField column
dataframe = (
    spark.read
    .option("header", "true")
    .option("mode", "PERMISSIVE")
    .option("columnNameOfCorruptRecord", "ErrorField")
    .schema(SCHEMA).csv("./x.csv")
)

total_row_count = dataframe.count()
print("total_row_count = " + str(total_row_count))

errors = dataframe.filter(col("ErrorField").isNotNull())
errors.show()
error_count = errors.count()
print("errors count = " + str(error_count))

The CSV it is reading is simply:

headerDouble
wrong

The relevant output of this is:

total_row_count = 1
+------------+----------+
|headerDouble|ErrorField|
+------------+----------+
|        null|     wrong|
+------------+----------+

errors count = 0

Now how can this possibly happen? If the dataframe contains a record, how is it counted as 0? Is this a bug in the Spark infrastructure or am I missing something?

EDIT: Looks like this might be a known bug in Spark 2.2 which has been fixed in Spark 2.3 - https://issues.apache.org/jira/browse/SPARK-21610
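Since the fix only landed in Spark 2.3, it's worth confirming which version the job is actually running on; a quick check (using the standard spark.version property on the session) is:

print(spark.version)  # the behaviour above was seen on a 2.2.x build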

Answer 1:

Thanks @user6910411 - it does seem to be a bug. I've raised an issue in the Spark project's bug tracker.

I'm speculating that Spark is getting confused because ErrorField is present in the schema while also being specified as the corrupt-record column and being used to filter the dataframe.

Meanwhile I think I've found a workaround to count the dataframe rows at a reasonable speed:

def count_df_with_spark_bug_workaround(df):
    # Count by pulling rows to the driver with toLocalIterator()
    # instead of calling .count() on the dataframe
    return sum(1 for _ in df.toLocalIterator())
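For example, dropping it in where the original script called .count() (same variable names as above):

error_count = count_df_with_spark_bug_workaround(errors)
print("errors count = " + str(error_count))  # reports 1 for the sample CSV above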

Not quite sure why this gives the right answer when .count() doesn't work.
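Another option, assuming the parsed data is small enough to cache and if I've read the Spark 2.3 migration notes correctly, is to cache the dataframe before filtering on the corrupt-record column, so the full parse happens once and the filter and the count see the same rows (the cached variable name is just illustrative):

cached = (
    spark.read
    .option("header", "true")
    .option("mode", "PERMISSIVE")
    .option("columnNameOfCorruptRecord", "ErrorField")
    .schema(SCHEMA).csv("./x.csv")
    .cache()  # materialise the fully parsed rows on first use and reuse them afterwards
)
errors = cached.filter(col("ErrorField").isNotNull())
print("errors count = " + str(errors.count()))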

Jira ticket I raised: https://issues.apache.org/jira/browse/SPARK-24147

This turned out to be a duplicate of: https://issues.apache.org/jira/browse/SPARK-21610