I have a question regarding the usage of local variables in closures when accessing Spark RDDs. The problem I would like to solve looks as follows:
I have a list of textfiles that should be read into an RDD.
However, first I need to add additional information to an RDD that is created from a single textfile. This additional information is extracted from the filename. Then, the RDDs are put into one big RDD using union().
from pyspark import SparkConf, SparkContext
spark_conf = SparkConf().setAppName("SparkTest")
spark_context = SparkContext(conf=spark_conf)
list_of_filenames = ['file_from_Ernie.txt', 'file_from_Bert.txt']
rdd_list = []
for filename in list_of_filenames:
    tmp_rdd = spark_context.textFile(filename)
    # extract_file_info('file_from_Owner.txt') == 'Owner'
    file_owner = extract_file_info(filename)
    tmp_rdd = tmp_rdd.map(lambda x : (x, file_owner))
    rdd_list.append(tmp_rdd)
overall_content_rdd = spark_context.union(rdd_list)
# ...do something...
overall_content_rdd.collect()
# However, this does not work:
# The result is that Bert is always the owner, never Ernie.
The problem is that the lambda passed to map() inside the loop does not refer to the “correct” file_owner; instead, it refers to the latest value of file_owner. On my local machine, I managed to fix the problem by calling cache() on each individual RDD:
# ..
tmp_rdd = tmp_rdd.map(lambda x : (x, file_owner))
tmp_rdd.cache()
# ..
My Question: Is using cache() the correct solution for my problem? Are there any alternatives?
Many Thanks!
So the cache() call that you are using won't necessarily work 100% of the time: it only works as long as no nodes fail and no partitions need to be recomputed. A simple solution is to make a function that "captures" the current value of file_owner. Here is a quick illustration of a potential solution in the pyspark shell:
(pyspark shell welcome banner, Spark version 1.2.0-SNAPSHOT)
Using Python version 2.7.6 (default, Mar 22 2014 22:59:56)
SparkContext available as sc.
>>> hi = "hi"
>>> sc.parallelize(["panda"])
ParallelCollectionRDD[0] at parallelize at PythonRDD.scala:365
>>> r = sc.parallelize(["panda"])
>>> meeps = r.map(lambda x : x + hi)
>>> hi = "by"
>>> meeps.collect()
['pandaby']
>>> hi = "hi"
>>> def makeGreetFunction(param):
... return (lambda x: x + param)
...
>>> f = makeGreetFunction(hi)
>>> hi="by"
>>> meeps = r.map(f)
>>> meeps.collect()
['pandahi']
>>>
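Applied to the loop from your question, the same factory idea might look like this (just a sketch; it reuses spark_context, list_of_filenames and extract_file_info from your post, and make_owner_tagger is a helper name I made up):

def make_owner_tagger(owner):
    # owner is bound here, when the factory is called, not when the lambda runs
    return lambda line: (line, owner)

rdd_list = []
for filename in list_of_filenames:
    tmp_rdd = spark_context.textFile(filename)
    tmp_rdd = tmp_rdd.map(make_owner_tagger(extract_file_info(filename)))
    rdd_list.append(tmp_rdd)
overall_content_rdd = spark_context.union(rdd_list)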
This is not a Spark phenomenon, but a plain Python one.
>>> fns = []
>>> for i in range(3):
... fns.append(lambda: i)
...
>>> for fn in fns:
... print fn()
...
2
2
2
One way to avoid it is to declare functions with default arguments. The default value is evaluated at the time of declaration, not at call time.
>>> fns = []
>>> for i in range(3):
... def f(i=i):
... return i
... fns.append(f)
...
>>> for fn in fns:
... print fn()
...
0
1
2
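The same trick carries over to the Spark loop in the question: binding file_owner as a default argument freezes its value for each lambda. A sketch, using the names from the original post:

for filename in list_of_filenames:
    tmp_rdd = spark_context.textFile(filename)
    file_owner = extract_file_info(filename)
    # the default value is evaluated right here, so each lambda keeps its own owner
    tmp_rdd = tmp_rdd.map(lambda x, owner=file_owner: (x, owner))
    rdd_list.append(tmp_rdd)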
This comes up a lot, see these other questions:
- Lexical closures in Python
- What do (lambda) function closures capture?
You could make an array of file owners and use it in the map transformation; note that the loop index i must itself be bound early (for example as a default argument), otherwise the lambda would again pick up only the last value of i:
file_owner[i] = extract_file_info(filename)
# bind i as a default argument so each lambda keeps this iteration's index
tmp_rdd = tmp_rdd.map(lambda x, i=i: (x, file_owner[i]))
As others explained, the problem with your lambda function is that it evaluates file_owner when it is executed, not when it is defined. To force its evaluation during each iteration of your for loop, you have to build a wrapper function and call it immediately. Here is how to do it with lambdas:
# ...
file_owner = extract_file_info(filename)
tmp_rdd = tmp_rdd.map((lambda owner: lambda line: (line, owner))(file_owner))
# ...
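If you find the nested lambda hard to read, functools.partial from the standard library achieves the same early binding. This is only an alternative sketch: tag_with_owner is a helper I made up, and the other names come from the question.

from functools import partial

def tag_with_owner(owner, line):
    return (line, owner)

# ...
file_owner = extract_file_info(filename)
# partial binds file_owner now; map() then supplies each line as the second argument
tmp_rdd = tmp_rdd.map(partial(tag_with_owner, file_owner))
# ...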