I recently read this paper, which introduces a procedure called "Warm-Up" (WU). It consists of multiplying the KL-divergence term of the loss by a variable whose value depends on the epoch number (it increases linearly from 0 to 1).
I was wondering whether this is the right way to do it:
beta = K.variable(value=0.0)

def vae_loss(x, x_decoded_mean):
    # cross entropy
    xent_loss = K.mean(objectives.categorical_crossentropy(x, x_decoded_mean))

    # kl divergence
    for k in range(n_sample):
        epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.,
                                  std=1.0)  # used for every z_i sampling
        # Sample several layers of latent variables
        for mean, var in zip(means, variances):
            z_ = mean + K.exp(K.log(var) / 2) * epsilon

            # build z
            try:
                z = tf.concat([z, z_], -1)
            except NameError:
                z = z_
            except TypeError:
                z = z_

            # sum loss (using a MC approximation)
            try:
                loss += K.sum(log_normal2(z_, mean, K.log(var)), -1)
            except NameError:
                loss = K.sum(log_normal2(z_, mean, K.log(var)), -1)
        print("z", z)
        loss -= K.sum(log_stdnormal(z), -1)
        z = None
    kl_loss = loss / n_sample
    print('kl loss:', kl_loss)

    # result
    result = beta*kl_loss + xent_loss
    return result

# define callback to change the value of beta at each epoch
def warmup(epoch):
    value = (epoch/10.0) * (epoch <= 10.0) + 1.0 * (epoch > 10.0)
    print("beta:", value)
    beta = K.variable(value=value)

from keras.callbacks import LambdaCallback
wu_cb = LambdaCallback(on_epoch_end=lambda epoch, log: warmup(epoch))

# train model
vae.fit(
    padded_X_train[:last_train,:,:],
    padded_X_train[:last_train,:,:],
    batch_size=batch_size,
    nb_epoch=nb_epoch,
    verbose=0,
    callbacks=[tb, wu_cb],
    validation_data=(padded_X_test[:last_test,:,:], padded_X_test[:last_test,:,:])
)

This will not work. I tested it to figure out exactly why. The key thing to remember is that Keras creates a static graph once, at the beginning of training.

Therefore, the vae_loss function is called only once to create the loss tensor, which means that the reference to the beta variable stays the same every time the loss is calculated. However, your warmup function assigns a new K.variable to the name beta. Thus, the beta used to calculate the loss is a different beta from the one that gets updated, and its value will always be 0.

It is an easy fix. Just change this line in your warmup callback:

beta = K.variable(value=value)

to:

K.set_value(beta, value)

This way the actual value in beta gets updated "in place" rather than a new variable being created, and the loss will be computed with the updated value.
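
To make the difference concrete without needing Keras installed, here is a minimal pure-Python sketch of the same mechanism. Everything in it is an illustrative stand-in and not Keras API: the `Variable` class plays the role of a `K.variable`, `set_value` plays the role of `K.set_value`, and `make_loss` mimics the loss being built once and holding on to the object it was given. The `warmup_value` helper and its `n_warmup` parameter are hypothetical names for the linear 0-to-1 schedule from the question.

```python
class Variable:
    """Hypothetical stand-in for a backend variable: a mutable box for one float."""
    def __init__(self, value):
        self.value = value


def set_value(var, value):
    """Stand-in for K.set_value: mutate the existing object in place."""
    var.value = value


def warmup_value(epoch, n_warmup=10):
    """Linear ramp from 0 to 1 over n_warmup epochs, then constant at 1."""
    return min(epoch / float(n_warmup), 1.0)


beta = Variable(0.0)


def make_loss(beta_ref):
    # Called once, like vae_loss when the graph is built:
    # the returned function keeps a reference to this exact object.
    def loss(kl, xent):
        return beta_ref.value * kl + xent
    return loss


loss_fn = make_loss(beta)


def warmup_rebind(epoch):
    # Mirrors the question: a brand-new object is bound to a local name,
    # so the object captured by loss_fn is never touched.
    beta = Variable(warmup_value(epoch))


def warmup_in_place(epoch):
    # Mirrors the fix: the object captured by loss_fn is updated.
    set_value(beta, warmup_value(epoch))


warmup_rebind(5)
after_rebind = loss_fn(2.0, 1.0)    # beta seen by the loss is still 0.0

warmup_in_place(5)
after_in_place = loss_fn(2.0, 1.0)  # beta seen by the loss is now 0.5
```

With the rebinding version the KL term stays switched off, while the in-place version lets the weight ramp up as intended. In real Keras code the same idea applies with `K.set_value(beta, value)` inside the callback.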