TensorFlow 2.0 Keras is training 4x slower than the 2.0 Estimator

Posted 2020-06-09 07:34

We recently switched to Keras for TensorFlow 2.0, but when we compared it against the DNNClassifier Estimator on 2.0, Keras trained around 4x slower. I cannot for the life of me figure out why. The rest of the code is identical for both: the same input_fn() returning the same tf.data.Dataset, and identical feature_columns. I've been struggling with this for days now. Any help would be greatly appreciated. Thank you.
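
For context, the shared input pipeline looks roughly like this (a minimal sketch only; the column name, vocabulary, and data below are placeholders, not our real pipeline):

import tensorflow as tf

# Hypothetical vocabulary and feature columns; the real ones are omitted here.
vocab = ['code_a', 'code_b', 'code_c']
feature_columns = [
    tf.feature_column.indicator_column(
        tf.feature_column.categorical_column_with_vocabulary_list('code', vocab))]

def train_input_fn():
  # Both the Estimator and Keras consume the same tf.data.Dataset.
  features = {'code': tf.constant(['code_a', 'code_b', 'code_c', 'code_a'])}
  labels = tf.constant([0, 1, 2, 0])
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  return dataset.shuffle(4).repeat().batch(2)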

Estimator code:

estimator = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[64, 64],
    activation_fn=tf.nn.relu,
    optimizer='Adagrad',
    dropout=0.4,
    n_classes=len(vocab),
    model_dir=model_dir,
    batch_norm=False)

estimator.train(input_fn=train_input_fn, steps=400)

Keras code:

from tensorflow.keras import layers

feature_layer = tf.keras.layers.DenseFeatures(feature_columns)

model = tf.keras.Sequential([
    feature_layer,
    layers.Dense(64, input_shape=(len(vocab),), activation=tf.nn.relu),
    layers.Dropout(0.4),
    layers.Dense(64, activation=tf.nn.relu),
    layers.Dropout(0.4),
    layers.Dense(len(vocab), activation='softmax')])

model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='Adagrad',
    distribute=None)

model.fit(x=train_input_fn(),
          epochs=1,
          steps_per_epoch=400,
          shuffle=True)

UPDATE: To test further, I wrote a custom subclassed Model (see: Get Started For Experts), which runs faster than Keras but slower than the Estimator. If the Estimator trains in 100 secs, the custom model takes roughly 180 secs and Keras roughly 350 secs. Interestingly, the Estimator runs slower with Adam() than with Adagrad(), while Keras seems to run faster; with Adam(), Keras takes less than twice as long as DNNClassifier. Assuming I didn't mess up the custom code, I'm beginning to think DNNClassifier just has a lot of backend optimizations that make it run faster than Keras.
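
(The timings above are rough wall-clock numbers; a minimal harness like this, reusing the two snippets above, reproduces the comparison:)

import time

start = time.perf_counter()
estimator.train(input_fn=train_input_fn, steps=400)
print('Estimator: %.1f secs' % (time.perf_counter() - start))

start = time.perf_counter()
model.fit(x=train_input_fn(), epochs=1, steps_per_epoch=400)
print('Keras: %.1f secs' % (time.perf_counter() - start))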

Custom code:

from tensorflow.keras import Model, layers

class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.features = layers.DenseFeatures(feature_columns, trainable=False)
    self.dense = layers.Dense(64, activation='relu')
    self.dropout = layers.Dropout(0.4)
    self.dense2 = layers.Dense(64, activation='relu')
    self.dropout2 = layers.Dropout(0.4)
    self.softmax = layers.Dense(len(vocab), activation='softmax')

  def call(self, x, training=False):
    x = self.features(x)
    x = self.dense(x)
    # Dropout is only active when training=True is passed through the call.
    x = self.dropout(x, training=training)
    x = self.dense2(x)
    x = self.dropout2(x, training=training)
    return self.softmax(x)

model = MyModel()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adagrad()

@tf.function
def train_step(features, label):
  with tf.GradientTape() as tape:
    predictions = model(features, training=True)
    loss = loss_object(label, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

itera = iter(train_input_fn())
for i in range(400):
  features, labels = next(itera)
  train_step(features, labels)

UPDATE: It may be the dataset. When I print a row of the dataset inside train_input_fn(), the Estimator prints the non-eager Tensor definition, while Keras prints eager values. Going through the Keras backend code, when fit() receives a tf.data.Dataset as input it handles it eagerly (and ONLY eagerly), which is why it crashed whenever I applied tf.function to train_input_fn(). Basically, my guess is that DNNClassifier trains faster than Keras because it runs more of the dataset code in graph mode. Will post any updates/finds.
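
A quick way to see the two behaviors side by side (assuming the train_input_fn() above):

import tensorflow as tf

dataset = train_input_fn()

# Iterating directly is eager: each element is a concrete EagerTensor.
# This is how Keras fit() consumes a tf.data.Dataset here.
features, labels = next(iter(dataset))
print(type(labels))  # EagerTensor

# Inside a tf.function the same loop is traced into a graph; at trace
# time the element is a symbolic Tensor, closer to what the Estimator
# does internally.
@tf.function
def check():
  for features, labels in dataset.take(1):
    print(type(labels))  # tensorflow.python.framework.ops.Tensor

check()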

1 Answer

贼婆χ · 2020-06-09 07:52

I believe it is slower because it is not being executed in graph mode. To execute in graph mode in TF2 you need a function decorated with the tf.function decorator. Check out this section for ideas on how to restructure your code.
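
For example, instead of decorating only the per-step function and driving the dataset from Python (as in the question's custom loop), you can move the whole loop inside the traced function so the dataset iteration itself runs in graph mode. A rough sketch, reusing the model, loss_object, and optimizer from the question:

@tf.function
def train(dataset, steps):
  # Dataset iteration, forward pass, and gradient update are all traced
  # into one graph, avoiding a round trip to Python between steps.
  for features, labels in dataset.take(steps):
    with tf.GradientTape() as tape:
      predictions = model(features, training=True)
      loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

train(train_input_fn(), 400)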
