I think my entire training data is being stored inside the graph which is hitting the 2gb limit. How can i use feed_dict in estimator API? FYI, I am using the tensorflow estimator API down the line for training my model.
Input Function:
def input_fn(X_train,epochs,batch_size):
''' input X_train is the scipy sparse matrix of large input dimensions(200000) and number of rows=600000'''
X_train_tf = tf.data.Dataset.from_tensor_slices((convert_sparse_matrix_to_sparse_tensor(X_train, tf.float32)))
X_train_tf = X_train_tf.apply(tf.data.experimental.shuffle_and_repeat(shuffle_to_batch*batch_size, epochs))
X_train_tf = X_train_tf.batch(batch_size).prefetch(2)
return X_train_tf
Error:
Traceback (most recent call last): File "/tmp/apprunner/.working/runtime/app/ae_python_tf.py", line 259, in AE_Regressor.train(lambda: input_fn(X_train,epochs,batch_size), hooks=[time_hist, logging_hook]) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train loss = self._train_model(input_fn, hooks, saving_listeners) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1205, in _train_model return self._train_model_distributed(input_fn, hooks, saving_listeners) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1352, in _train_model_distributed saving_listeners) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec log_step_count_steps=log_step_count_steps) as mon_sess: File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession stop_grace_period_secs=stop_grace_period_secs) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 921, in init stop_grace_period_secs=stop_grace_period_secs) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 631, in init h.begin() File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/basic_session_run_hooks.py", line 543, in begin self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer_cache.py", line 63, in get logdir, graph=ops.get_default_graph()) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 367, in init super(FileWriter, self).init(event_writer, graph, graph_def) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 83, in init self.add_graph(graph=graph, graph_def=graph_def) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 193, in add_graph true_graph_def = graph.as_graph_def(add_shapes=True) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3124, in as_graph_def result, _ = self._as_graph_def(from_version, add_shapes) File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3082, in _as_graph_def c_api.TF_GraphToGraphDef(self._c_graph, buf) tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot serialize protocol buffer of type tensorflow.GraphDef as the serialized size (2838040852bytes) would be larger than the limit (2147483647 bytes)