Preserving numpy view when pickling

2019-03-19 16:44发布

By default, pickling a numpy view array loses the view relationship, even if the array base is pickled too. My situation is that I have some complex container objects which are pickled, and in some cases some of the contained data are views on others. Saving an independent array for each view not only wastes space, but the reloaded data also lose the view relationship.

A simple example would be (but in my case the containers are more complex than a dictionary):

import numpy as np
import cPickle

# Minimal reproduction (Python 2 syntax: uses cPickle and the print statement).
tmp = np.zeros(2)
d1 = dict(a=tmp,b=tmp[:])    # d1 to be saved: b is a view on a

# Round-trip the whole container through pickle.
pickled = cPickle.dumps(d1)
d2 = cPickle.loads(pickled)  # d2 reloaded copy of d1 container

# In the original container, writing through the view 'b' also changes 'a'.
print 'd1 before:', d1
d1['b'][:] = 1
print 'd1 after: ', d1

# In the reloaded container, the same write leaves 'a' untouched:
# 'b' came back as an independent array rather than a view on 'a'.
print 'd2 before:', d2
d2['b'][:] = 1
print 'd2 after: ', d2

which would print:

d1 before: {'a': array([ 0.,  0.]), 'b': array([ 0.,  0.])}
d1 after:  {'a': array([ 1.,  1.]), 'b': array([ 1.,  1.])}
d2 before: {'a': array([ 0.,  0.]), 'b': array([ 0.,  0.])}
d2 after:  {'a': array([ 0.,  0.]), 'b': array([ 1.,  1.])} # not a view anymore

My question:

(1) Is there a way to preserve it? (2) (Even better) is there a way to do it only if the base is also pickled?

For (1), I think there may be some way by changing the __setstate__, __reduce_ex__, etc. of the view array, but I don't feel confident with these for now. For (2), I have no idea.

1条回答
别忘想泡老子
2楼-- · 2019-03-19 17:32

This isn't done in NumPy proper, because it doesn't always make sense to pickle the base array, and pickle does not expose the ability to check if another object is also being pickled as part of its API.

But this sort of check can be done in a custom container for NumPy arrays. For example:

import numpy as np
import pickle

def byte_offset(array, source):
    """Return the offset, in bytes, of ``array``'s first element inside ``source``.

    The offset is measured from the lowest memory address occupied by
    ``source`` (the low end of its byte bounds) to the address of the first
    element of ``array``.

    Parameters
    ----------
    array : numpy.ndarray
        The view whose starting position is wanted.
    source : numpy.ndarray
        The array whose memory ``array`` is expected to live in.

    Returns
    -------
    int
        ``address(array[0, ...]) - lowest_address(source)``.
    """
    # ``np.byte_bounds`` was removed from the top-level namespace in
    # NumPy 2.0 and now lives in ``numpy.lib.array_utils``; older releases
    # only provide the top-level name, so support both.
    try:
        from numpy.lib.array_utils import byte_bounds
    except ImportError:
        byte_bounds = np.byte_bounds
    # __array_interface__['data'][0] is the address of the first element,
    # which for negatively-strided views is not the lowest address.
    return array.__array_interface__['data'][0] - byte_bounds(source)[0]

class SharedPickleList(object):
    """Container of NumPy arrays that keeps view relationships across pickling.

    An array is pickled as a lightweight view description (shape, dtype,
    base id, byte offset, strides) when its ``base`` is itself one of the
    arrays held by this container; every other array is pickled in full.
    On unpickling, the described views are rebuilt on top of the restored
    base arrays so that they share memory again.
    """

    def __init__(self, arrays):
        self.arrays = list(arrays)

    def __getstate__(self):
        """Return ``(source_arrays, view_tuples, order)`` for pickling.

        ``source_arrays`` maps ``id`` -> array pickled in full,
        ``view_tuples`` maps ``id`` -> view description, and ``order`` lists
        the ids in the container's original order.
        """
        members = self.arrays
        pickled_ids = set(map(id, members))
        source_arrays = {}
        view_tuples = {}
        for member in members:
            key = id(member)
            parent = member.base
            if parent is not None and id(parent) in pickled_ids:
                # The base travels with us, so describing the view suffices.
                view_tuples[key] = (
                    member.shape,
                    member.dtype,
                    id(parent),
                    byte_offset(member, parent),
                    member.strides,
                )
                continue
            # Either a real owner of data or a view whose base is absent.
            source_arrays[key] = member
        return (source_arrays, view_tuples, list(map(id, members)))

    def __setstate__(self, state):
        """Rebuild ``self.arrays`` from the tuple made by ``__getstate__``."""
        source_arrays, view_tuples, order = state
        rebuilt_views = {}
        for key, description in view_tuples.items():
            shape, dtype, source_id, offset, strides = description
            # Re-create the view directly on the restored base's memory.
            rebuilt_views[key] = np.ndarray(
                shape,
                dtype=dtype,
                buffer=source_arrays[source_id].data,
                offset=offset,
                strides=strides,
            )
        restored = []
        for key in order:
            if key in source_arrays:
                restored.append(source_arrays[key])
            else:
                restored.append(rebuilt_views[key])
        self.arrays = restored

# unit tests
def check_roundtrip(arrays):
    """Assert that ``arrays`` survive a SharedPickleList pickle round-trip.

    Checks that the same number of arrays comes back, that each restored
    array has the same shape and element values as its original, and that
    every non-empty array whose ``base`` is also in ``arrays`` is restored
    as an array overlapping the restored base's memory.

    Parameters
    ----------
    arrays : iterable of numpy.ndarray
        Arrays to round-trip; typically a base array followed by views on it.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    arrays = list(arrays)
    unpickled_arrays = pickle.loads(pickle.dumps(
        SharedPickleList(arrays))).arrays
    # zip() below stops at the shorter input, so a dropped array would
    # otherwise go unnoticed.
    assert len(unpickled_arrays) == len(arrays)
    assert all(a.shape == b.shape and (a == b).all()
               for a, b in zip(arrays, unpickled_arrays))
    # Value equality alone cannot tell a view from a copy: also verify that
    # restored views still live inside their restored base array.
    restored_by_id = {id(a): b for a, b in zip(arrays, unpickled_arrays)}
    for a, b in zip(arrays, unpickled_arrays):
        restored_base = restored_by_id.get(id(a.base))
        # Empty arrays occupy no memory, so overlap is not meaningful.
        if a.base is not None and restored_base is not None and a.size:
            assert np.may_share_memory(b, restored_base)

# Index expressions exercised against each source array: integer, new axis,
# full slice, leading slice, trailing-trimmed slice, reversed and strided.
indexers = [
    0,
    None,
    slice(None),
    slice(2),
    slice(None, -1),
    slice(None, None, -1),
    slice(None, 6, 2),
]

# 1-D source: one array per indexer.
source0 = np.random.randint(100, size=10)
arrays0 = []
for k1 in indexers:
    arrays0.append(np.asarray(source0[k1]))
check_roundtrip([source0] + arrays0)

# 2-D source: one array per (row indexer, column indexer) pair.
source1 = np.random.randint(100, size=(8, 10))
arrays1 = []
for k1 in indexers:
    for k2 in indexers:
        arrays1.append(np.asarray(source1[k1, k2]))
check_roundtrip([source1] + arrays1)

This results in significant space savings:

# One 1000-element array followed by 99 tail slices of it.
source = np.random.rand(1000)
arrays = [source]
arrays.extend(source[n:] for n in range(99))

# Size when the list is pickled directly.
plain_size = len(pickle.dumps(arrays, protocol=-1))
print(plain_size)
# 766372

# Size when the same arrays are wrapped in SharedPickleList.
shared_size = len(pickle.dumps(SharedPickleList(arrays), protocol=-1))
print(shared_size)
# 11833
查看更多
登录 后发表回答