Using keras.utils.Sequence multiprocessing and dat

Posted 2020-07-23 05:52

Question:

I'm training a neural network with Keras on the TensorFlow backend. The data set does not fit in RAM, so I store it in a MongoDB database and retrieve batches using a subclass of keras.utils.Sequence.

Everything works fine if I run model.fit_generator() with use_multiprocessing=False.

When I turn on multiprocessing, I get errors either while the workers are being spawned or when connecting to the database.

If I create the connection in __init__, I get an exception whose text says something about errors in pickling lock objects. Sorry, I don't remember it exactly. The training does not even start.
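The pickling failure can be reproduced without MongoDB. The sketch below assumes the error comes from a lock held inside the client object (the exact exception text was not recorded above). `FakeClient`, `Holder` and `can_pickle` are hypothetical names used only for illustration; they are not part of pymongo or Keras.

```python
import pickle
import threading


class FakeClient:
    """Hypothetical stand-in for a database client that owns a lock."""

    def __init__(self):
        self._lock = threading.Lock()


class Holder:
    """Mimics a Sequence that creates its connection in __init__."""

    def __init__(self):
        self.client = FakeClient()


def can_pickle(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False
```

When processes are spawned rather than forked, the object handed to the workers has to be pickled, so any attribute that holds a lock would block the start of training in this way.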

If I create the connection in __getitem__, the training starts and runs for some epochs, then I get the error [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted.

According to the PyMongo manual, it is not fork-safe, and each child process must create its own connection to the database. I use Windows, which does not fork but spawns processes instead; however, the difference does not matter here, IMHO.

This explains why it is impossible to connect in __init__.

Here is one more quote from docs:

Create this client once for each process, and reuse it for all operations. It is a common mistake to create a new client for each request, which is very inefficient.

This explains the errors in __getitem__.
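To make the quoted advice concrete, here is a minimal sketch that counts how many clients get opened in each style. `CountingClient` is a hypothetical stand-in, not a pymongo class. With a real client, each instance would also open at least one socket, which is one plausible reading of why the per-request variant eventually runs out of socket addresses.

```python
class CountingClient:
    """Hypothetical stand-in for a client; counts how many were opened."""

    opened = 0

    def __init__(self):
        CountingClient.opened += 1

    def fetch(self, key):
        # Placeholder for a database lookup.
        return key * 2


def fetch_with_new_client(keys):
    # Anti-pattern: a fresh client for every request.
    return [CountingClient().fetch(k) for k in keys]


def fetch_with_shared_client(keys):
    # Pattern from the quoted docs: one client, reused for every request.
    client = CountingClient()
    return [client.fetch(k) for k in keys]
```

For a list of N keys, the first function opens N clients and the second opens exactly one.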

However, it is unclear how my class can tell that Keras has created a new process.

Here is the pseudocode of the last variant of my Sequence implementation (new connection on each request):

import pickle

import numpy as np
import pymongo
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical

# IMAGE_HEIGHT, IMAGE_WIDTH and IMAGE_CHANNELS are assumed to be defined elsewhere.

class MongoSequence(Sequence):
    def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
        self._train_set = train_set
        self._server = server
        self._db = database
        self._collection = collection
        self._batch_size = batch_size

        query = {}  # train_set query
        self._object_ids = [ smp["_id"] for uid in train_set for smp in self._connect().find(query, {'_id': True})]

    def _connect(self):
        # Opens a new client on every call (the per-request variant described above).
        client = pymongo.MongoClient(self._server)
        db = client[self._db]
        return db[self._collection]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._connect().find({'_id': oid}).next()
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y

That is, on object construction I retrieve the ObjectIDs of all documents that form the training set according to the criteria. The actual objects are retrieved from the database in calls to __getitem__; their ObjectIDs are taken from a slice of that list.
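The batch arithmetic used by __len__ and __getitem__ can be checked in isolation. This is a small sketch with plain integers standing in for ObjectIDs; `num_batches` and `batch_slice` are hypothetical helper names.

```python
import math


def num_batches(n_samples, batch_size):
    """Number of batches, counting a final partial batch."""
    return math.ceil(n_samples / batch_size)


def batch_slice(ids, index, batch_size):
    """IDs that belong to batch number `index`."""
    return ids[index * batch_size:(index + 1) * batch_size]


ids = list(range(25))                 # 25 placeholder IDs
print(num_batches(len(ids), 10))      # 3
print(batch_slice(ids, 2, 10))        # [20, 21, 22, 23, 24]
```

The last batch is shorter than batch_size, which is why X and y above are allocated with len(oids) rows rather than a fixed batch size.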

The code that calls model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... ) spawns several Python processes, each of which initializes the TensorFlow backend according to the log messages, and the training starts.

But eventually an exception is thrown from a function called connect, somewhere deep inside pymongo.

Unfortunately, I haven't saved the call stack. The error is the one described above; I repeat: [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted.

My assumption is that this code creates too many connections to the server; therefore, connecting in __getitem__ is wrong.

Connecting in the constructor is also wrong, since it is performed in the main process, and the Mongo docs explicitly advise against it.

There is one more method in the Sequence class, on_epoch_end. But I need a connection at the beginning of an epoch, not at the end.

Quote from Keras docs:

If you want to modify your dataset between epochs you may implement on_epoch_end

So, are there any recommendations? Docs are not very specific here.

Answer 1:

Looks like I've found a solution: track the process ID and reconnect when it changes.

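import os
import pickle

# The remaining imports (numpy, pymongo, keras) are the same as in the question.

class MongoSequence(Sequence):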
    def __init__(self, batch_size, train_set, query=None, server=None, database="database", collection="full_set"):
        self._server = server
        self._db = database
        self._collection_name = collection
        self._batch_size = batch_size
        self._query = query
        self._collection = self._connect()

        self._object_ids = [ smp["_id"] for uid in train_set for smp in self._collection.find(self._query, {'_id': True})]

        self._pid = os.getpid()
        # Drop the handle created in the main process; __getitem__ reconnects lazily.
        self._collection = None

    def _connect(self):
        client = pymongo.MongoClient(self._server)
        db = client[self._db]
        return db[self._collection_name]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        if self._collection is None or self._pid != os.getpid():
            self._collection = self._connect()
            self._pid = os.getpid()

        oids = self._object_ids[item * self._batch_size: (item+1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            # Reuse the per-process connection instead of opening a new one per sample.
            smp = self._collection.find({'_id': oid}).next()
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y
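A possible variation on the same idea, sketched under assumptions: keep the PID check, and additionally strip the live connection from the pickled state so the object can always be sent to a spawned worker. `FakeCollection` and `LazyConnection` are hypothetical names; in real use the factory would be something like the `_connect` method above.

```python
import os
import threading


class FakeCollection:
    """Hypothetical stand-in for a collection handle; holds an unpicklable lock."""

    def __init__(self):
        self._lock = threading.Lock()
        self.owner_pid = os.getpid()


class LazyConnection:
    """Opens the connection on first use and reopens it after a process change."""

    def __init__(self, factory=FakeCollection):
        self._factory = factory
        self._collection = None
        self._pid = None

    @property
    def collection(self):
        # Reconnect if nothing is open yet or if we are in a different process.
        if self._collection is None or self._pid != os.getpid():
            self._collection = self._factory()
            self._pid = os.getpid()
        return self._collection

    def __getstate__(self):
        # Exclude the live connection from the pickled state.
        state = self.__dict__.copy()
        state["_collection"] = None
        state["_pid"] = None
        return state
```

Within one process, repeated access to `collection` returns the same handle. A copy that arrives in another process starts with no handle and opens its own on first use.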


Answer 2:

Create your connection in on_epoch_end(), and make an explicit call to on_epoch_end() from the __init__() method. In practice, this makes on_epoch_end() work as if it were "on epoch begin". (The end of each epoch is the beginning of the next one. The first epoch has no epoch before it, hence the explicit call during initialization.)