I am solving a text classification problem. I defined my classifier using the Estimator class with my own model_fn. I would like to use Google's pre-trained word2vec embeddings as initial values and then further optimise them for the task at hand.
I saw this post: Using a pre-trained word embedding (word2vec or Glove) in TensorFlow, which explains how to go about it in 'raw' TensorFlow code. However, I would really like to use the Estimator class.
As an extension, I would then like to train this model on Cloud ML Engine. Is there a good way of passing in the fairly large file with the initial values?
Let's say we have something like:
def build_model_fn():
    def _model_fn(features, labels, mode, params):
        input_layer = features['feat']  # shape=[-1, params["sequence_length"]]
        # ... what goes here to initialize W?
        embedded = tf.nn.embedding_lookup(W, input_layer)
        ...
        return predictions
    return _model_fn

estimator = tf.contrib.learn.Estimator(
    model_fn=build_model_fn(),
    model_dir=MODEL_DIR,
    params=params)
estimator.fit(input_fn=read_data, max_steps=2500)
Embeddings are typically large enough that the only viable approach is using them to initialize a tf.Variable in your graph. This lets you take advantage of parameter servers in distributed training, etc.

For this (and anything else), I would recommend you use the new "core" estimator, tf.estimator.Estimator, as it will make things much easier.

From the answer in the link you provided, and knowing that we want a variable rather than a constant, we can take either of these approaches: (2) initialize the variable using a feed dict, or (3) load the variable from a checkpoint.
I'll cover option (3) first since it's much easier, and better:
In your model_fn, simply initialize a variable using the value returned by a tf.contrib.framework.load_variable call. This requires that you have a valid TF checkpoint containing your embeddings and that you know the fully qualified name of the embedding variable within it. The code is pretty simple:
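Something along these lines, where the checkpoint path and variable name are placeholders for your own:

def _model_fn(features, labels, mode, params):
    # load_variable reads the saved value out of the checkpoint
    # (as a numpy array) without needing the original graph.
    pretrained = tf.contrib.framework.load_variable(
        'gs://my-bucket/word2vec_checkpoints/',  # placeholder checkpoint dir
        'some/scope/embeddings')                 # placeholder variable name
    # Use it as the initial value of a new trainable variable in this graph.
    W = tf.Variable(pretrained, name='embeddings')
    embedded = tf.nn.embedding_lookup(W, features['feat'])
    # ... rest of the model ...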
However, this approach won't work if your embeddings weren't produced by another TF model, hence option (2).

For (2), we need to use tf.train.Scaffold, which is essentially a configuration object that holds all the options for starting a tf.Session (which Estimator intentionally hides for lots of reasons). You may specify a Scaffold in the tf.estimator.EstimatorSpec you return from your model_fn.
We create a placeholder in our model_fn and use it as the initial value of our embedding variable (so the variable's initializer is an assignment from the placeholder), then pass an init_feed_dict via the Scaffold, e.g.:
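Here is a sketch, reusing your build_model_fn pattern so the pre-trained matrix (a numpy array) is available via a closure; the classification head is just a toy to make the example complete, and per-mode handling is elided for brevity:

def build_model_fn(pretrained_embeddings):
    def _model_fn(features, labels, mode, params):
        embed_ph = tf.placeholder(
            tf.float32, shape=pretrained_embeddings.shape)
        # Using the placeholder as the initial value makes the variable's
        # initializer op an assignment from the placeholder.
        W = tf.Variable(embed_ph, name='embeddings')
        embedded = tf.nn.embedding_lookup(W, features['feat'])
        # Toy head: average the word vectors and classify.
        logits = tf.layers.dense(
            tf.reduce_mean(embedded, axis=1), params['num_classes'])
        loss = tf.losses.sparse_softmax_cross_entropy(
            labels=labels, logits=logits)
        train_op = tf.train.AdamOptimizer().minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={'class': tf.argmax(logits, axis=1)},
            loss=loss,
            train_op=train_op,
            scaffold=tf.train.Scaffold(
                init_feed_dict={embed_ph: pretrained_embeddings}))
    return _model_fn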
What's happening here is that the init_feed_dict will populate the values of the embed_ph placeholder when the session is created, which in turn allows the embedding variable's initializer (the assignment from the placeholder) to run.
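Wiring it up then looks much like your original snippet, just with the core estimator. load_word2vec below is a hypothetical helper, not a TF API; it stands for whatever you use to parse the pre-trained binary file into a [vocab_size, embedding_size] numpy array:

pretrained = load_word2vec('GoogleNews-vectors-negative300.bin')  # hypothetical helper

estimator = tf.estimator.Estimator(
    model_fn=build_model_fn(pretrained),
    model_dir=MODEL_DIR,
    params=params)
estimator.train(input_fn=read_data, max_steps=2500)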