Unpickling mid-stream (python)

Published 2019-07-20 19:47

Question:

I am writing scripts to process (very large) files by repeatedly unpickling objects until EOF. I would like to partition the file and have separate processes (in the cloud) unpickle and process separate parts.

However my partitioner is not intelligent, it does not know about the boundaries between pickled objects in the file (since those boundaries depend on the object types being pickled, etc.).

Is there a way to scan a file for a "start pickled object" sentinel? The naive way would be to attempt unpickling at successive byte offsets until an object is successfully unpickled, but that yields unexpected errors. It seems that for certain combinations of input, the unpickler falls out of sync and returns nothing for the rest of the file (see code below).

import cPickle
import os

def stream_unpickle(file_obj):
    while True:
        start_pos = file_obj.tell()
        try:
            yield cPickle.load(file_obj)
        except (EOFError, KeyboardInterrupt):
            break
        except (cPickle.UnpicklingError, ValueError, KeyError, TypeError, ImportError):
            file_obj.seek(start_pos+1, os.SEEK_SET)

if __name__ == '__main__':
    import random
    from StringIO import StringIO

    # create some data
    sio = StringIO()
    [cPickle.dump(random.random(), sio, cPickle.HIGHEST_PROTOCOL) for _ in xrange(1000)]
    sio.flush()

    # read from subsequent offsets and find discontinuous jumps in object count
    size = sio.tell()
    last_count = None
    for step in xrange(size):
        sio.seek(step, os.SEEK_SET)
        count = sum(1 for _ in stream_unpickle(sio))
        if last_count is None or count == last_count - 1:
            last_count = count
        elif count != last_count:
            # if successful, these should never print (but they do...)
            print '%d elements read from byte %d' % (count, step)
            print '(%d elements read from byte %d)' % (last_count, step-1)
            last_count = count

Answer 1:

The pickletools module has a dis function that shows the opcodes. It shows that there is a STOP opcode that you can scan for:

>>> import pickle, pickletools, StringIO
>>> s = StringIO.StringIO()
>>> pickle.dump('abc', s)
>>> p = s.getvalue()
>>> pickletools.dis(p)
    0: S    STRING     'abc'
    7: p    PUT        0
   10: .    STOP
highest protocol among opcodes = 0

Note that using the STOP opcode is a bit tricky because the opcodes' arguments are of variable length, but it may serve as a useful hint about where the cutoffs are.
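
As a rough sketch of that hint (candidate_starts is a hypothetical helper name, not part of pickletools): record the offset just after every '.' byte as a candidate start of the next pickle, then validate each candidate by attempting a load from it. A '.' inside an opcode's argument will produce a false positive, so validation is essential.

def candidate_starts(data):
    # yield the offset just after each '.' (STOP) byte -- possible pickle boundaries
    pos = data.find('.')
    while pos != -1:
        yield pos + 1
        pos = data.find('.', pos + 1)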

If you control the pickling step on the other end, then you can improve the situation by adding your own unambiguous alternative separator:

>>> sep = '\xDE\xAD\xBE\xEF'
>>> s = StringIO.StringIO()
>>> pickle.dump('abc', s)
>>> s.write(sep)
>>> pickle.dump([10, 20], s)
>>> s.write(sep)
>>> pickle.dump('def', s)
>>> s.write(sep)
>>> pickle.dump([30, 40], s)
>>> p = s.getvalue()

Before unpickling, split the data into separate pickles using the known separator:

>>> for pick in p.split(sep):
        print pickle.loads(pick)

abc
[10, 20]
def
[30, 40]
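
If the file is too large to split in memory, the same idea works incrementally. A minimal sketch (iter_pickles is a hypothetical name; it assumes, as the example above does, that the separator bytes never occur inside a pickle):

import pickle

def iter_pickles(f, sep='\xDE\xAD\xBE\xEF', chunk_size=1 << 20):
    # read the file in chunks, splitting on the separator as we go
    buf = ''
    while True:
        chunk = f.read(chunk_size)
        if not chunk:
            break
        buf += chunk
        pieces = buf.split(sep)
        buf = pieces.pop()            # the last piece may still be incomplete
        for piece in pieces:
            if piece:
                yield pickle.loads(piece)
    if buf:
        yield pickle.loads(buf)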


Answer 2:

In the pickled file, some opcodes have an argument -- a data value that follows the opcode. The data values vary in length, and can contain bytes identical to opcodes. Therefore, if you start reading the file from an arbitrary position, you have no way of knowing if you are looking at an opcode or in the middle of an argument. You must read the file from the beginning and parse the opcodes.
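
For instance (Python 2, default protocol 0), the pickle of a string containing a '.' holds a byte that is identical to the STOP opcode:

>>> import pickle
>>> pickle.dumps('a.b')
"S'a.b'\np0\n."

The '.' at offset 3 is part of the STRING argument; only the '.' at offset 10 is the real STOP, and nothing in the byte stream itself distinguishes the two without parsing from the start.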

I cooked up this function that skips one pickle from a file, i.e. it reads the data and parses the opcodes but does not construct the objects. It seems slightly faster than cPickle.loads on some files I have. You could rewrite this in C for more speed (after testing it properly).

Then, you can make one pass over the whole file to get the seek position of each pickle.

from pickletools import code2op, UP_TO_NEWLINE, TAKEN_FROM_ARGUMENT1, TAKEN_FROM_ARGUMENT4   
from marshal import loads as mloads

def skip_pickle(f):
    """Skip one pickle from file.

    'f' is a file-like object containing the pickle.

    """
    while True:
        code = f.read(1)
        if not code:
            raise EOFError
        opcode = code2op[code]
        if opcode.arg is not None:
            n = opcode.arg.n
            if n > 0:
                # fixed-length argument: skip exactly n bytes
                f.read(n)
            elif n == UP_TO_NEWLINE:
                # newline-terminated argument (text protocol)
                f.readline()
            elif n == TAKEN_FROM_ARGUMENT1:
                # length given by a 1-byte prefix
                n = ord(f.read(1))
                f.read(n)
            elif n == TAKEN_FROM_ARGUMENT4:
                # length given by a 4-byte little-endian prefix
                n = mloads('i' + f.read(4))
                f.read(n)
        if code == '.':
            # STOP opcode: end of this pickle
            break
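
To get the seek position of each pickle in one pass, as mentioned above, the skipper can simply be called in a loop (index_pickles is a hypothetical helper name, assuming skip_pickle as defined):

def index_pickles(f):
    # one full pass over the file, recording where each pickle starts
    offsets = []
    while True:
        pos = f.tell()
        try:
            skip_pickle(f)
        except EOFError:
            break
        offsets.append(pos)
    return offsets

Each worker can then be handed a slice of these offsets and seek directly to its share of the file.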


Answer 3:

Sorry to answer my own question, and thanks to @RaymondHettinger for the idea of adding sentinels.

Here's what worked for me. I created readers and writers that use a sentinel, '#S', followed by the data block length at the beginning of each 'record'. The writer has to take care to find any occurrences of '#' in the data being written and double them (into '##'). The reader then uses a negative look-behind regex to find sentinels that cannot be confused with escaped data in the original stream, and also verifies the number of bytes between one sentinel and the next against the stated block length.

RecordWriter is a context manager (so multiple calls to write() can be encapsulated into a single record if needed). RecordReader is a generator.

Not sure how this performs. Any faster or more elegant solutions are welcome.

import os
import re
import cPickle
from functools import partial
from cStringIO import StringIO

SENTINEL = '#S'

# when scanning look for #S, but NOT ##S
sentinel_pattern = '(?<!#)#S' # uses negative look-behind
sentinel_re = re.compile(sentinel_pattern)
find_sentinel = sentinel_re.search

# when writing replace single # with double ##
write_pattern = '#'
write_re = re.compile(write_pattern)
fix_write = partial(write_re.sub, '##')

# when reading, replace double ## with single #
read_pattern = '##'
read_re = re.compile(read_pattern)
fix_read = partial(read_re.sub, '#') 

class RecordWriter(object):
    def __init__(self, stream):
        self._stream = stream
        self._write_buffer = None

    def __enter__(self):
        self._write_buffer = StringIO()
        return self

    def __exit__(self, et, ex, tb):
        if self._write_buffer.tell():
            self._stream.write(SENTINEL) # start
            cPickle.dump(self._write_buffer.tell(), self._stream, cPickle.HIGHEST_PROTOCOL) # byte length of user's original data
            self._stream.write(fix_write(self._write_buffer.getvalue()))
            self._write_buffer = None
        return False

    def write(self, data):
        if not self._write_buffer:
            raise ValueError("Must use RecordWriter as a context manager")
        self._write_buffer.write(data)

class BadBlock(Exception): pass

def verify_length(block):
    fobj = StringIO(block)
    try:
        stated_length = cPickle.load(fobj)
    except (ValueError, IndexError, cPickle.UnpicklingError):
        raise BadBlock
    data = fobj.read()
    if len(data) != stated_length:
        raise BadBlock
    return data

def RecordReader(stream):
    ' Read one record '
    accum = StringIO()
    seen_sentinel = False
    data = ''
    while True:
        m = find_sentinel(data)
        if not m:
            if seen_sentinel:
                accum.write(data)
            data = stream.read(80)
            if not data:
                if accum.tell():
                    try: yield verify_length(fix_read(accum.getvalue()))
                    except BadBlock: pass
                return
        else:
            if seen_sentinel:
                accum.write(data[:m.start()])
                try: yield verify_length(fix_read(accum.getvalue()))
                except BadBlock: pass
                accum = StringIO()
            else:
                seen_sentinel = True
            data = data[m.end():] # toss

if __name__ == '__main__':
    import random

    stream = StringIO()
    data = [str(random.random()) for _ in xrange(3)]
    # test with a string containing sentinel and almost-sentinel
    data.append('abc12#jeoht38#SoSooihetS#')
    count = len(data)
    for i in data:
        with RecordWriter(stream) as r:
            r.write(i)

    size = stream.tell()
    start_pos = int(random.random() * size)
    stream.seek(start_pos, os.SEEK_SET)
    read_data = [s for s in RecordReader(stream)]
    print 'Original data: ', data
    print 'After seeking to %d, RecordReader returned: %s' % (start_pos, read_data)
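
For the original use case, each record would presumably carry a pickled object rather than a plain string. A rough sketch of that layering (write_pickled_records and read_pickled_objects are hypothetical names; note that RecordReader as written reads to end of file, so each worker still needs its own stop condition for its partition):

def write_pickled_records(stream, objects):
    # one pickled object per record; any '#' bytes in the pickle get escaped by the writer
    for obj in objects:
        with RecordWriter(stream) as r:
            r.write(cPickle.dumps(obj, cPickle.HIGHEST_PROTOCOL))

def read_pickled_objects(stream, start):
    # seek anywhere into the file; RecordReader resynchronizes at the next intact sentinel
    stream.seek(start)
    for block in RecordReader(stream):
        yield cPickle.loads(block)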