I think my entire training data is being stored inside the graph, which is hitting the 2 GB limit. How can I use feed_dict with the Estimator API? FYI, I am using the TensorFlow Estimator API down the line for training my model.
Input Function:
import tensorflow as tf

def input_fn(X_train, epochs, batch_size):
    '''X_train is a scipy sparse matrix with a large input dimension (200000 columns) and 600000 rows.'''
    # convert_sparse_matrix_to_sparse_tensor and shuffle_to_batch are defined elsewhere in the script.
    X_train_tf = tf.data.Dataset.from_tensor_slices(convert_sparse_matrix_to_sparse_tensor(X_train, tf.float32))
    X_train_tf = X_train_tf.apply(tf.data.experimental.shuffle_and_repeat(shuffle_to_batch * batch_size, epochs))
    X_train_tf = X_train_tf.batch(batch_size).prefetch(2)
    return X_train_tf
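A conversion like this typically builds one big tf.SparseTensor from the matrix's COO data, roughly as in the sketch below (hypothetical code, not my exact helper); those index/value arrays become constant nodes, which is why the whole training set ends up inside the GraphDef:

import numpy as np
import tensorflow as tf

def convert_sparse_matrix_to_sparse_tensor(X, dtype):
    # Hypothetical sketch: turn a scipy sparse matrix into a single tf.SparseTensor.
    coo = X.tocoo()
    indices = np.stack([coo.row, coo.col], axis=1).astype(np.int64)
    values = coo.data.astype(dtype.as_numpy_dtype)
    dense_shape = np.array(coo.shape, dtype=np.int64)
    # The indices, values and shape are baked into the graph as constants,
    # so the full dataset gets serialized along with the GraphDef.
    return tf.SparseTensor(indices, values, dense_shape)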
Error:
Traceback (most recent call last):
  File "/tmp/apprunner/.working/runtime/app/ae_python_tf.py", line 259, in <module>
    AE_Regressor.train(lambda: input_fn(X_train,epochs,batch_size), hooks=[time_hist, logging_hook])
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1205, in _train_model
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1352, in _train_model_distributed
    saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 631, in __init__
    h.begin()
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/basic_session_run_hooks.py", line 543, in begin
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer_cache.py", line 63, in get
    logdir, graph=ops.get_default_graph())
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 367, in __init__
    super(FileWriter, self).__init__(event_writer, graph, graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 83, in __init__
    self.add_graph(graph=graph, graph_def=graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 193, in add_graph
    true_graph_def = graph.as_graph_def(add_shapes=True)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3124, in as_graph_def
    result, _ = self._as_graph_def(from_version, add_shapes)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3082, in _as_graph_def
    c_api.TF_GraphToGraphDef(self._c_graph, buf)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot serialize protocol buffer of type tensorflow.GraphDef as the serialized size (2838040852 bytes) would be larger than the limit (2147483647 bytes)
I'm normally against quoting documentation verbatim, but this is explained word-for-word in the TF documentation and I can't find a way to put it better than they already do. In short: building the Dataset directly from your in-memory arrays embeds them in the graph as tf.constant() operations, which is exactly what runs into the 2 GB GraphDef limit; the fix is to define the Dataset in terms of tf.placeholder() tensors and feed the actual data only when you initialize an initializable iterator over the dataset.
(Code and text for this are both taken from the link above; I removed one assert in the code that wasn't relevant to the issue.)
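To make the pattern concrete, here is a minimal TF 1.x sketch of it (not the exact snippet from the guide; the array names, shapes and dataset transformations are made up for illustration):

import numpy as np
import tensorflow as tf

# Hypothetical in-memory training data; any large NumPy arrays would do.
features = np.random.rand(1000, 20).astype(np.float32)
labels = np.random.rand(1000, 1).astype(np.float32)

# Placeholders keep the data out of the GraphDef: only dtype and shape are serialized.
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
dataset = dataset.shuffle(1000).repeat().batch(128)

# An initializable iterator is fed the real arrays only at initialization time.
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer,
             feed_dict={features_placeholder: features,
                        labels_placeholder: labels})
    batch_features, batch_labels = sess.run(next_batch)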
Update

If you're trying to use this with the Estimator API, you're out of luck. From the same linked page, a few sections above the one quoted before, the documentation notes that one-shot iterators are currently the only iterator type that is easily usable with an Estimator.
This, as you noted in the comment, is because the Estimator API hides away the sess.run() calls where you need to pass the feed_dict for your iterator.
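To see why, here is a rough, self-contained sketch (toy data and model, nothing from your setup) of what training through an Estimator looks like from the caller's side; the MonitoredTrainingSession and every sess.run() call happen inside estimator.train(), so there is never a point where you could pass a feed_dict:

import numpy as np
import tensorflow as tf

def train_input_fn():
    # Toy in-memory data, small enough that embedding it in the graph is harmless here.
    features = np.random.rand(256, 10).astype(np.float32)
    labels = np.random.rand(256, 1).astype(np.float32)
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    return dataset.shuffle(256).repeat().batch(32)

def model_fn(features, labels, mode):
    # Minimal linear model, just enough to make the sketch runnable.
    predictions = tf.layers.dense(features, 1)
    loss = tf.losses.mean_squared_error(labels, predictions)
    train_op = tf.train.AdamOptimizer().minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(model_fn=model_fn)

# Estimator calls input_fn and model_fn itself, wraps everything in a
# MonitoredTrainingSession, and runs the training loop internally, so the
# session (and therefore any feed_dict) is never exposed to the caller.
estimator.train(input_fn=train_input_fn, steps=10)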