Tensorflow Estimator Graph Size Limitation for lar

Posted 2019-09-20 07:11

I think my entire training data is being stored inside the graph, which is hitting the 2 GB limit. How can I use feed_dict with the Estimator API? FYI, I am using the TensorFlow Estimator API down the line to train my model.

Input Function:

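import tensorflow as tf

def input_fn(X_train, epochs, batch_size):
    '''X_train is a scipy sparse matrix with 600000 rows and 200000 input dimensions (columns).'''
    # convert_sparse_matrix_to_sparse_tensor is the asker's own helper (not shown);
    # presumably it converts X_train into a tf.SparseTensor.
    X_train_tf = tf.data.Dataset.from_tensor_slices(
        convert_sparse_matrix_to_sparse_tensor(X_train, tf.float32))
    # shuffle_to_batch is assumed to be defined elsewhere in the script.
    X_train_tf = X_train_tf.apply(
        tf.data.experimental.shuffle_and_repeat(shuffle_to_batch * batch_size, epochs))
    X_train_tf = X_train_tf.batch(batch_size).prefetch(2)
    return X_train_tf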

Error:

Traceback (most recent call last):
  File "/tmp/apprunner/.working/runtime/app/ae_python_tf.py", line 259, in
    AE_Regressor.train(lambda: input_fn(X_train,epochs,batch_size), hooks=[time_hist, logging_hook])
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1205, in _train_model
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1352, in _train_model_distributed
    saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 631, in __init__
    h.begin()
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/basic_session_run_hooks.py", line 543, in begin
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer_cache.py", line 63, in get
    logdir, graph=ops.get_default_graph())
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 367, in __init__
    super(FileWriter, self).__init__(event_writer, graph, graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 83, in __init__
    self.add_graph(graph=graph, graph_def=graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 193, in add_graph
    true_graph_def = graph.as_graph_def(add_shapes=True)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3124, in as_graph_def
    result, _ = self._as_graph_def(from_version, add_shapes)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3082, in _as_graph_def
    c_api.TF_GraphToGraphDef(self._c_graph, buf)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot serialize protocol buffer of type tensorflow.GraphDef as the serialized size (2838040852bytes) would be larger than the limit (2147483647 bytes)

1 answer
唯我独甜
Reply #2 · 2019-09-20 07:29

I'm normally against quoting documentation verbatim, but the TF documentation explains this almost word for word, and I can't find a way to put it better than they already do:

Note that [using Dataset.from_tensor_slices() on features and labels numpy arrays] will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer.

As an alternative, you can define the Dataset in terms of tf.placeholder() tensors, and feed the NumPy arrays when you initialize an Iterator over the dataset.

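import numpy as np
import tensorflow as tf  # TF 1.x API: tf.placeholder, tf.Session

# Load the training data into two NumPy arrays, for example using `np.load()`.
# Only a .npz archive gives a key-indexable NpzFile usable with `with`;
# a plain .npy file loads as a single ndarray.
with np.load("/var/data/training_data.npz") as data:
  features = data["features"]
  labels = data["labels"]

# Placeholders keep the arrays out of the GraphDef; their values are fed
# only when the iterator is initialized.
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
iterator = dataset.make_initializable_iterator()

with tf.Session() as sess:
  sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                            labels_placeholder: labels})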

(Code and text both taken from the TF documentation quoted above; I removed one assert from the code that wasn't relevant to the issue.)
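As a rough illustration of why a matrix this size overflows the limit in the traceback, the sketch below estimates a lower bound on the bytes needed to embed a COO sparse matrix as graph constants. It assumes the asker's `convert_sparse_matrix_to_sparse_tensor` helper (not shown) yields a `tf.SparseTensor` with int64 indices and float32 values, and it ignores protobuf framing and the duplicate copies the docs mention. The function and constant names are hypothetical.

```python
# Protobuf messages are capped at 2**31 - 1 bytes (the int32 maximum),
# which is the limit reported in the traceback.
GRAPHDEF_LIMIT_BYTES = 2**31 - 1


def sparse_constant_lower_bound(nnz, ndims=2, index_bytes=8, value_bytes=4):
    """Rough lower bound (bytes) for embedding a COO sparse matrix as constants.

    Assumption: int64 indices (ndims of them per non-zero) plus one float32
    value per non-zero; protobuf overhead is ignored, so the real GraphDef
    is larger than this.
    """
    return nnz * (ndims * index_bytes + value_bytes)


def max_nnz_under_limit(limit=GRAPHDEF_LIMIT_BYTES, **kwargs):
    """Largest non-zero count whose lower-bound size still fits under `limit`."""
    return limit // sparse_constant_lower_bound(1, **kwargs)


print(max_nnz_under_limit())
print(2838040852 > GRAPHDEF_LIMIT_BYTES)
```

Under those assumptions, no more than roughly 107 million stored non-zeros can fit in the graph, and the reported 2838040852-byte GraphDef is about 1.3 times the 2147483647-byte limit.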


Update

If you're trying to use this with the Estimator API, you're out of luck. From the same documentation page, a few sections above the one quoted earlier:

Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator.

This, as you noted in the comment, is because the Estimator API hides away the sess.run() calls where you need to pass the feed_dict for your iterator.
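One alternative, outside what the quoted docs describe, is to skip both the constant embedding and feed_dict by streaming batches from Python through `tf.data.Dataset.from_generator`, which the Estimator can then consume like any other dataset. The sketch below shows only the pure-Python half, assuming `X_train` is (or is converted with `.tocsr()` to) a scipy CSR matrix whose `indptr`, `indices` and `data` arrays are passed in; `csr_batches` is a hypothetical name, not a TF or scipy API.

```python
def csr_batches(indptr, indices, data, n_cols, batch_size):
    """Yield (coo_indices, values, dense_shape) for consecutive row batches.

    Hypothetical helper. indptr/indices/data follow the CSR layout used by
    scipy.sparse.csr_matrix: row i owns data[indptr[i]:indptr[i + 1]].
    Row numbers are rebased to 0 so each batch is its own small sparse matrix.
    """
    n_rows = len(indptr) - 1
    for start in range(0, n_rows, batch_size):
        stop = min(start + batch_size, n_rows)
        coo_indices, values = [], []
        for row in range(start, stop):
            # Walk the stored non-zeros of this row only.
            for k in range(indptr[row], indptr[row + 1]):
                coo_indices.append([row - start, indices[k]])
                values.append(data[k])
        yield coo_indices, values, [stop - start, n_cols]


# Tiny 3x3 example: row 0 has two non-zeros, row 1 is empty, row 2 has one.
for batch in csr_batches([0, 2, 2, 3], [0, 2, 1], [1.0, 2.0, 3.0], n_cols=3, batch_size=2):
    print(batch)
```

Wiring this into TF 1.x would presumably look like `tf.data.Dataset.from_generator(lambda: csr_batches(...), (tf.int64, tf.float32, tf.int64))` followed by a `map` that builds a `tf.SparseTensor`; treat that part as untested, and note that generator-backed datasets run Python code and may not work under every distribution strategy.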
