I'm training a neural network in Keras with the TensorFlow backend. The data set does not fit in RAM, so I store it in a MongoDB database and retrieve batches using a subclass of keras.utils.Sequence.
Everything works fine if I run model.fit_generator() with use_multiprocessing=False.
When I turn on multiprocessing, I get errors either while the workers are being spawned or when connecting to the database.
If I create the connection in __init__, I get an exception whose text says something about errors in pickling lock objects (sorry, I don't remember it exactly), and the training does not even start.
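That pickling error makes sense in hindsight: with spawn, Keras has to pickle the Sequence object to send it to each worker, and a MongoClient holds locks that cannot be pickled. A minimal repro sketch outside Keras (the exact exception wording is from memory and may differ between Python/PyMongo versions):

import pickle

import pymongo

client = pymongo.MongoClient()
pickle.dumps(client)  # TypeError: cannot pickle '_thread.lock' object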
If I create a connection in __getitem__, the training starts and runs for some epochs, then I get errors: [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted.
According to the PyMongo manual, the client is not fork-safe, and each child process must create its own connection to the database. I use Windows, which does not fork but spawns processes instead; however, the difference does not matter here, IMHO. This explains why it is impossible to connect in __init__.
Here is one more quote from the docs:
Create this client once for each process, and reuse it for all operations. It is a common mistake to create a new client for each request, which is very inefficient.
This explains the errors in __getitem__.
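For reference, the usage pattern the docs have in mind looks roughly like this (a sketch; the address and names are placeholders):

import pymongo

# One client per process, created once and reused for every request.
client = pymongo.MongoClient("mongodb://localhost:27017/")
collection = client["database"]["full_set"]

def fetch_sample(oid):
    # Reuses the process-wide client instead of opening a new connection.
    return collection.find_one({"_id": oid})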
However, it is unclear how my class can detect that Keras has created a new process.
Here is the pseudocode of the latest variant of my Sequence implementation (a new connection on each request):
import pickle

import numpy as np
import pymongo
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical

# IMAGE_HEIGHT, IMAGE_WIDTH and IMAGE_CHANNELS are module-level constants

class MongoSequence(Sequence):
    def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
        self._train_set = train_set
        self._server = server
        self._db = database
        self._collection = collection
        self._batch_size = batch_size
        query = {}  # placeholder: the real criteria select the samples of one train_set entry
        self._object_ids = [smp["_id"]
                            for uid in train_set
                            for smp in self._connect().find(query, {'_id': True})]

    def _connect(self):
        # A new client, and therefore a new connection, on every call.
        client = pymongo.MongoClient(self._server)
        return client[self._db][self._collection]

    def __len__(self):
        return int(np.ceil(len(self._object_ids) / float(self._batch_size)))

    def __getitem__(self, item):
        oids = self._object_ids[item * self._batch_size: (item + 1) * self._batch_size]
        X = np.empty((len(oids), IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS), dtype=np.float32)
        y = np.empty((len(oids), 2), dtype=np.float32)
        for i, oid in enumerate(oids):
            smp = self._connect().find_one({'_id': oid})
            X[i, :, :, :] = pickle.loads(smp['frame']).astype(np.float32)
            y[i] = to_categorical(not smp['result'], 2)
        return X, y
That is, on object construction I retrieve all relevant ObjectIds forming the train set according to the criteria. The actual objects are retrieved from the database in calls to __getitem__; their ObjectIds are determined from a list slice.
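For example, with 25 ObjectIds and batch_size=10 (numbers made up for illustration), __len__() returns ceil(25 / 10) == 3, and the three batches take these slices:

oids = self._object_ids[0:10]   # item == 0
oids = self._object_ids[10:20]  # item == 1
oids = self._object_ids[20:25]  # item == 2, the last batch is shorter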
The code that calls model.fit_generator(generator=MongoSequence(train_ids, batch_size=10), ... ) spawns several Python processes, each of which initializes the TensorFlow backend according to the log messages, and the training starts.
But eventually an exception is thrown from a function called connect, somewhere deep inside pymongo. Unfortunately, I haven't saved the call stack; the error is the one I quoted above: [WinError 10048] Only one usage of each socket address (protocol/network address/port) is normally permitted.
My assumption is that this code creates too many connections to the server; therefore, connecting in __getitem__ is wrong. Connecting in the constructor is also wrong, since it is performed in the main process, and the Mongo docs explicitly advise against that.
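My mental model of the failure (an assumption on my part, since the stack trace is gone): every MongoClient opens fresh sockets, nothing ever closes them, and on Windows even closed sockets linger in TIME_WAIT for a while, so a long run of one-client-per-request eventually exhausts the ephemeral port range. Schematically:

import pymongo

# Roughly what my __getitem__ amounts to over a long training run:
for item in range(100000):
    collection = pymongo.MongoClient()["database"]["full_set"]  # new socket(s) every time
    collection.find_one({})
    # no client.close() anywhere; sooner or later a new connection
    # cannot get a local port and fails with WinError 10048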
There is one more overridable method in the Sequence class, on_epoch_end. But I need a connection at epoch begin, not at epoch end.
A quote from the Keras docs:
If you want to modify your dataset between epochs you may implement on_epoch_end
So, are there any recommendations? The docs are not very specific here.
Create your connection in on_epoch_end(), and make an explicit call to on_epoch_end() from the __init__() method. This makes on_epoch_end() work, in practice, as if it were "on epoch begin". (The end of each epoch is the beginning of the next one; the first epoch doesn't have an epoch before it, hence the explicit call in the initialization.)
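A minimal sketch of that suggestion (condensed: the query logic is omitted, and _samples is my name for the collection handle; __len__ and __getitem__ stay as in the question):

import pymongo
from keras.utils import Sequence

class MongoSequence(Sequence):
    def __init__(self, train_set, batch_size, server=None):
        self._server = server
        self._batch_size = batch_size
        # The first epoch has no epoch end before it, so trigger the hook once here.
        self.on_epoch_end()

    def on_epoch_end(self):
        # Runs at the end of every epoch, i.e. effectively "on epoch begin"
        # for the epoch that follows it.
        self._samples = self._connect()

    def _connect(self):
        return pymongo.MongoClient(self._server)["database"]["full_set"]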
Looks like I've found a solution. The solution is to track the process id and reconnect when it changes.
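A sketch of what I mean, grafted onto the pseudocode above (the _pid and _client attributes and the __getstate__ override are additions of mine for this illustration; __getstate__ drops the live client so the object can still be pickled into spawned workers):

import os

import pymongo
from keras.utils import Sequence

class MongoSequence(Sequence):
    def __init__(self, train_set, batch_size, server=None, database="database", collection="full_set"):
        self._server = server
        self._db = database
        self._collection = collection
        self._batch_size = batch_size
        self._client = None
        self._pid = None  # pid of the process that owns self._client
        # ... build self._object_ids as in the question ...

    def _connect(self):
        # Keras spawns fresh worker processes; detect that by comparing pids,
        # and build a new client the first time we run inside a new process.
        if self._client is None or self._pid != os.getpid():
            self._client = pymongo.MongoClient(self._server)
            self._pid = os.getpid()
        return self._client[self._db][self._collection]

    def __getstate__(self):
        # Don't ship a live client to the workers -- it is not picklable;
        # each worker lazily reconnects via the pid check above.
        state = self.__dict__.copy()
        state["_client"] = None
        return state

__len__ and __getitem__ stay as before; the only change is that _connect() now reuses one client per process instead of creating one per request.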