Apache Spark reading from S3: can't pickle thread.lock objects

Posted 2020-02-10 07:10

Question:

So I want my Spark app to read some text from Amazon S3. I wrote the following simple script:

import boto3
s3_client = boto3.client('s3')
text_keys = ["key1.txt", "key2.txt"]
data = sc.parallelize(text_keys).flatMap(lambda key: s3_client.get_object(Bucket="my_bucket", Key=key)['Body'].read().decode('utf-8'))

When I call data.collect() I get the following error:

TypeError: can't pickle thread.lock objects

and I can't seem to find any help online. Has anyone managed to solve this?

Answer 1:

Your s3_client isn't serialisable.

Instead of flatMap, use mapPartitions and initialise s3_client inside the function you pass to it, as shown in the sketch below. That will:

  1. initialise s3_client on each worker, so nothing needs to be pickled and shipped from the driver
  2. reduce initialisation overhead, since the client is created once per partition rather than once per record
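
A minimal sketch of that approach, reusing the bucket and keys from the question (fetch_objects is just an illustrative name, and the client here picks up credentials from the default boto3 configuration):

import boto3

def fetch_objects(keys):
    # The client is created on the worker, once per partition,
    # instead of being pickled and shipped from the driver.
    s3_client = boto3.client('s3')
    for key in keys:
        obj = s3_client.get_object(Bucket="my_bucket", Key=key)
        yield obj['Body'].read().decode('utf-8')

data = sc.parallelize(text_keys).mapPartitions(fetch_objects)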


Answer 2:

Here's how you can use mapPartitions and initialize the s3_client inside the function passed to it, to avoid the serialization problem and the overhead.

The parallelized approach to pulling S3 data below was inspired by this article: How NOT to pull from S3 using Apache Spark

Note: Credit for the get_matching_s3_objects(..) method and get_matching_s3_keys(..) method goes to Alex Chan, here: Listing S3 Keys

There may be a simpler/better way to list the keys and parallelize them, but this is what worked for me. Also, I strongly recommend that you NOT embed your AWS_SECRET or AWS_ACCESS_KEY_ID in plain text as in this simplified example. There is good documentation on how to properly secure your credentials for accessing AWS via Boto3 here: Boto 3 Docs - Configuration and Credentials
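
For example, if your credentials live in environment variables, ~/.aws/credentials, or an attached IAM role, you can drop the explicit keys from the code entirely (a minimal sketch of that configuration):

import boto3

# With no explicit keys, boto3 resolves credentials from environment
# variables, the ~/.aws/credentials file, or an attached IAM role.
session = boto3.session.Session()
s3_client = session.client('s3')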

First, the imports and string variables:

import boto3
import pyspark
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

AWS_ACCESS_KEY_ID = 'DONT_DO_THIS_ESPECIALLY-IN-PRODUCTION'
AWS_SECRET = 'ALSO_DONT_DO_THIS_ESPECIALLY-IN-PRODUCTION'
bucket_name = 'my-super-s3-bucket-example-name'
appName = 'mySuperAppExample'

Then, the two listing methods credited above (from Alex Chan's Listing S3 Keys post):

def get_matching_s3_objects(s3, bucket, prefix='', suffix=''):
    """
    Generate objects in an S3 bucket.

    :param s3: A boto3 S3 client.
    :param bucket: Name of the S3 bucket.
    :param prefix: Only fetch objects whose key starts with
        this prefix (optional).
    :param suffix: Only fetch objects whose keys end with
        this suffix (optional).
    """
    kwargs = {'Bucket': bucket}

    # If the prefix is a single string (not a tuple of strings), we can
    # do the filtering directly in the S3 API.
    if isinstance(prefix, str):
        kwargs['Prefix'] = prefix

    while True:

        # The S3 API response is a large blob of metadata.
        # 'Contents' contains information about the listed objects.
        resp = s3.list_objects_v2(**kwargs)

        try:
            contents = resp['Contents']
        except KeyError:
            return

        for obj in contents:
            key = obj['Key']
            if key.startswith(prefix) and key.endswith(suffix):
                yield obj

        # The S3 API is paginated, returning up to 1000 keys at a time.
        # Pass the continuation token into the next response, until we
        # reach the final page (when this field is missing).
        try:
            kwargs['ContinuationToken'] = resp['NextContinuationToken']
        except KeyError:
            break


def get_matching_s3_keys(s3, bucket, prefix='', suffix=''):
    """
    Generate the keys in an S3 bucket.

    :param s3: A boto3 S3 client.
    :param bucket: Name of the S3 bucket.
    :param prefix: Only fetch keys that start with this prefix (optional).
    :param suffix: Only fetch keys that end with this suffix (optional).
    """
    for obj in get_matching_s3_objects(s3, bucket, prefix, suffix):
        yield obj['Key']
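
For example, to list only the .txt keys under a given prefix (the prefix and suffix values here are just placeholders), you could call it like this:

session = boto3.session.Session(AWS_ACCESS_KEY_ID, AWS_SECRET)
s3_client = session.client('s3')
for key in get_matching_s3_keys(s3_client, bucket_name, prefix='some/prefix/', suffix='.txt'):
    print(key)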

Then, a method that I wrote to construct a function with a closure that is compatible with .mapPartitions(..):

# Again, please don't transmit your keys in plain text. 
#   I did this here just for the sake of completeness of the example 
#   so that the code actually works.

def getObjsFromMatchingS3Keys(AWS_ACCESS_KEY_ID, AWS_SECRET, bucket_name):
    def getObjs(s3Keys):
        # Create the session and client once per partition, not once per key.
        session = boto3.session.Session(AWS_ACCESS_KEY_ID, AWS_SECRET)
        s3_client = session.client('s3')
        for key in s3Keys:
            body = s3_client.get_object(Bucket=bucket_name, Key=key)['Body'].read().decode('utf-8')
            yield body
    return getObjs
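
Because getObjs only closes over plain strings (the credentials and the bucket name), it pickles cleanly; the boto3 session and client are created on the worker while each partition is processed, which is what avoids the thread.lock pickling error from the question.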

Then, set up the SparkContext and get the list of S3 object keys:

conf = SparkConf().setAppName(appName)
sc = SparkContext(conf=conf)
spark = SparkSession(sc)
session = boto3.session.Session(AWS_ACCESS_KEY_ID, AWS_SECRET) 
# For the third time, please don't transmit your credentials in plain text like this. 
#  Hopefully you won't need another warning. 
s3_client = session.client('s3')
func = getObjsFromMatchingS3Keys(AWS_ACCESS_KEY_ID, AWS_SECRET, bucket_name)

myFileObjs = list(get_matching_s3_keys(s3_client, bucket_name))

Side note: We needed to construct a SparkSession so that .toDF() would be available on the PipelinedRDD type (it is attached via a monkey patch), as explained here: PipelinedRDD object has no attribute toDF in PySpark

Finally, parallelize the S3 object keys with .mapPartitions(..) and the function we constructed:

pathToSave = r'absolute_path_to_your_desired_file.json'
sc.parallelize(myFileObjs) \
    .mapPartitions(func) \
    .map(lambda x: (x, )) \
    .toDF() \
    .toPandas() \
    .to_json(pathToSave)

There might be a more concise way to write to the target output file, but this code works. Also, the purpose of map(lambda x: (x, )) was to force schema inference, as mentioned here: Create Spark DataFrame - Cannot infer schema for type. Forcing schema inference in this way may not be the best approach for all situations, but it was sufficient for this example.
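
If you'd rather not rely on inference at all, one alternative (just a sketch, reusing the same RDD of text bodies; the column name 'body' is arbitrary) is to pass an explicit schema to createDataFrame:

from pyspark.sql.types import StructType, StructField, StringType

# One string column to hold each object's body.
schema = StructType([StructField("body", StringType(), True)])

rows = sc.parallelize(myFileObjs) \
    .mapPartitions(func) \
    .map(lambda x: (x, ))
spark.createDataFrame(rows, schema) \
    .toPandas() \
    .to_json(pathToSave)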