So I want my Spark app to read some text from Amazon's S3. I wrote the following simple script:
import boto3
s3_client = boto3.client('s3')
text_keys = ["key1.txt", "key2.txt"]
data = sc.parallelize(text_keys).flatMap(lambda key: s3_client.get_object(Bucket="my_bucket", Key=key)['Body'].read().decode('utf-8'))
When I call data.collect()
I get the following error:
TypeError: can't pickle thread.lock objects
and I can't seem to find any help online. Has anyone perhaps managed to solve the above?
Your s3_client isn't serialisable.
Instead of flatMap, use mapPartitions, and initialise s3_client inside the function body to avoid overhead. That will:
- init s3_client on each worker
- reduce initialisation overhead
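A minimal sketch of that approach, reusing my_bucket and the key names from the question (the client is created inside the function that mapPartitions calls, so nothing non-serialisable is shipped from the driver):
import boto3

def read_keys(keys):
    # Created on the worker, once per partition, so it never has to be pickled.
    s3_client = boto3.client('s3')
    for key in keys:
        yield s3_client.get_object(Bucket="my_bucket", Key=key)['Body'].read().decode('utf-8')

text_keys = ["key1.txt", "key2.txt"]
data = sc.parallelize(text_keys).mapPartitions(read_keys)
data.collect()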
Here's how you can use mapPartitions and initialize the s3_client inside the lambda body to avoid overhead.
The parallelized approach to pulling S3 data below was inspired by this article: How NOT to pull from S3 using Apache Spark
Note: Credit for the get_matching_s3_objects(..) and get_matching_s3_keys(..) methods goes to Alex Chan, here: Listing S3 Keys
There may be a simpler/better way to list the keys and parallelize them, but this is what worked for me.
Also, I strongly recommend that you NOT transmit your AWS_SECRET or AWS_ACCESS_KEY_ID in plain text like in this simplified example. There is good documentation on how to properly secure your code (to access AWS via Boto3) here:
Boto 3 Docs - Configuration and Credentials
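For example, a sketch of a safer pattern is simply to let boto3's default credential chain resolve the keys (from environment variables, the shared credentials file, or an attached IAM role) instead of hard-coding them:
import boto3

# No keys in the source code: boto3 falls back to environment variables,
# ~/.aws/credentials, or an instance/role profile.
session = boto3.session.Session()
s3_client = session.client('s3')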
First, the imports and string variables:
import boto3
import pyspark
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
AWS_ACCESS_KEY_ID = 'DONT_DO_THIS_ESPECIALLY-IN-PRODUCTION'
AWS_SECRET = 'ALSO_DONT_DO_THIS_ESPECIALLY-IN-PRODUCTION'
bucket_name = 'my-super-s3-bucket-example-name'
appName = 'mySuperAppExample'
Then, the two helper methods from the Listing S3 Keys link mentioned above:
def get_matching_s3_objects(s3, bucket, prefix='', suffix=''):
    """
    Generate objects in an S3 bucket.

    :param s3: A boto3 S3 client.
    :param bucket: Name of the S3 bucket.
    :param prefix: Only fetch objects whose key starts with
        this prefix (optional).
    :param suffix: Only fetch objects whose keys end with
        this suffix (optional).
    """
    kwargs = {'Bucket': bucket}

    # If the prefix is a single string (not a tuple of strings), we can
    # do the filtering directly in the S3 API.
    if isinstance(prefix, str):
        kwargs['Prefix'] = prefix

    while True:
        # The S3 API response is a large blob of metadata.
        # 'Contents' contains information about the listed objects.
        resp = s3.list_objects_v2(**kwargs)

        try:
            contents = resp['Contents']
        except KeyError:
            return

        for obj in contents:
            key = obj['Key']
            if key.startswith(prefix) and key.endswith(suffix):
                yield obj

        # The S3 API is paginated, returning up to 1000 keys at a time.
        # Pass the continuation token into the next request, until we
        # reach the final page (when this field is missing).
        try:
            kwargs['ContinuationToken'] = resp['NextContinuationToken']
        except KeyError:
            break
def get_matching_s3_keys(s3, bucket, prefix='', suffix=''):
    """
    Generate the keys in an S3 bucket.

    :param s3: A boto3 S3 client.
    :param bucket: Name of the S3 bucket.
    :param prefix: Only fetch keys that start with this prefix (optional).
    :param suffix: Only fetch keys that end with this suffix (optional).
    """
    for obj in get_matching_s3_objects(s3, bucket, prefix, suffix):
        yield obj['Key']
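As a quick local check (a sketch only; it assumes the bucket actually contains some .txt objects), you can drive the generator directly before involving Spark:
s3 = boto3.session.Session(AWS_ACCESS_KEY_ID, AWS_SECRET).client('s3')
for key in get_matching_s3_keys(s3, bucket_name, suffix='.txt'):
    print(key)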
Then, a method that I wrote to construct a function with a closure that is compatible with .mapPartitions(..):
# Again, please don't transmit your keys in plain text.
# I did this here just for the sake of completeness of the example
# so that the code actually works.
def getObjsFromMatchingS3Keys(AWS_ACCESS_KEY_ID, AWS_SECRET, bucket_name):
    def getObjs(s3Keys):
        # Build the session and client once per partition rather than once per
        # key, and only on the worker, so the client never has to be pickled.
        session = boto3.session.Session(AWS_ACCESS_KEY_ID, AWS_SECRET)
        s3_client = session.client('s3')
        for key in s3Keys:
            body = s3_client.get_object(Bucket=bucket_name, Key=key)['Body'].read().decode('utf-8')
            yield body
    return getObjs
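As a sanity check outside of Spark (purely illustrative; 'some_key.txt' is a made-up key name), the returned generator can be driven directly:
func = getObjsFromMatchingS3Keys(AWS_ACCESS_KEY_ID, AWS_SECRET, bucket_name)
for body in func(['some_key.txt']):  # hypothetical key, just to exercise the closure
    print(body[:100])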
Then, set up the SparkContext and get the list of S3 object keys:
conf = SparkConf().setAppName(appName)
sc = SparkContext(conf=conf)
spark = SparkSession(sc)
session = boto3.session.Session(AWS_ACCESS_KEY_ID, AWS_SECRET)
# For the third time, please don't transmit your credentials in plain text like this.
# Hopefully you won't need another warning.
s3_client = session.client('s3')
func = getObjsFromMatchingS3Keys(AWS_ACCESS_KEY_ID, AWS_SECRET, bucket_name)
myFileObjs = []
for fName in get_matching_s3_keys(s3_client, bucket_name):
    myFileObjs.append(fName)
Side note: We needed to construct a SparkSession so that .toDF() would be available to the PipelinedRDD type due to the monkey patch, as explained here: PipelinedRDD object has no attribute toDF in PySpark
Finally, parallelize the S3 object keys and map over the partitions with the function we constructed:
pathToSave = r'absolute_path_to_your_desired_file.json'
sc.parallelize(myFileObjs) \
    .mapPartitions(lambda keys: func(keys)) \
    .map(lambda x: (x, )) \
    .toDF() \
    .toPandas() \
    .to_json(pathToSave)
There might be a more concise way to write to the target output file, but this code works. Also, map(lambda x: (x, )) is used to force schema inference, as mentioned here: Create Spark DataFrame - Cannot infer schema for type
Forcing schema inference in this way may not be the best approach for all situations, but it was sufficient for this example.
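If you would rather not rely on inference at all, one alternative (a sketch only, assuming the same myFileObjs, func and pathToSave as above; the column name body is arbitrary) is to supply an explicit schema:
from pyspark.sql.types import StringType, StructField, StructType

schema = StructType([StructField('body', StringType(), True)])
rows = sc.parallelize(myFileObjs) \
    .mapPartitions(func) \
    .map(lambda x: (x, ))
spark.createDataFrame(rows, schema) \
    .toPandas() \
    .to_json(pathToSave)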