I am loading a model and want to test each node in isolation from the rest of the graph, so I'm using the clone(CloneMethod.clone) method, but I'm finding that this recreates the entire model. For example, when I clone the BatchNormalization layer I get this graph. How do I clone BatchNormalization but disconnect it from the parent Minus node?
Answer 1:
You could name every node and then find them by their name. In the model below:
    import cntk as C

    # emb_dim, hidden_dim and num_labels are assumed to be defined elsewhere.
    def create_model():
        with C.layers.default_options(initial_state=0.1):
            return C.layers.Sequential([
                C.layers.Embedding(emb_dim, name='embed'),
                C.layers.Recurrence(C.layers.LSTM(hidden_dim), go_backwards=False),
                C.layers.Dense(num_labels, name='classify')
            ])
You can call:
    z = create_model()
    print(z.embed.E.shape)
    print(z.classify.b.value)
You can then try cloning a particular node with its clone method. Many of the CNTK tutorials show how to look up graph nodes by name, and the CNTK 206 tutorials show how to work selectively on a sub-graph.
Some sample code:
    import cntk as C
    x = C.input_variable(5)
    m = C.layers.Dense(4, name='foo')(x)
    n = C.layers.Dense(3, name='baz')(m)
    z = C.layers.Dense(2, name='bar')(n)
    n_clone = z.baz.clone(method='share')
This will clone all the layers connected from n to the input x. To get just the layer named baz, declare a new variable, say y:
    y = C.input_variable(4)
    n_clone_baz = n_clone(y)
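A more general clone method is available here. The clone_model function is what lets you clone a sub-graph: each node named in from_node_names is replaced by a placeholder, so the clone stops there, while clone_method sets how the parameters are handled (for example 'share', as above).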
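    # Import paths as in CNTK 2.x.
    from cntk import combine, placeholder
    from cntk.logging.graph import find_by_name

    def clone_model(base_model, from_node_names, to_node_names, clone_method):
        # Look up the nodes where the cloned sub-graph should start.
        from_nodes = [find_by_name(base_model, node_name) for node_name in from_node_names]
        if None in from_nodes:
            print("Error: could not find all specified 'from_nodes' in clone.")
        # Look up the nodes where the cloned sub-graph should end.
        to_nodes = [find_by_name(base_model, node_name) for node_name in to_node_names]
        if None in to_nodes:
            print("Error: could not find all specified 'to_nodes' ...... ")
        # Replace each start node with a placeholder so the clone is cut off there.
        input_placeholders = dict(zip(from_nodes, [placeholder() for x in from_nodes]))
        cloned_net = combine(to_nodes).clone(clone_method, input_placeholders)
        return cloned_net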
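To see why the input_placeholders substitution is the part that disconnects a node from its parents, here is a toy sketch in plain Python. It is not CNTK code: the Node class and the clone and names helpers are hypothetical stand-ins made up for illustration, and the node names only mirror the ones in the question. The idea is that a recursive clone walks back through every ancestor unless a substitution tells it to stop.

```python
# Toy illustration only: plain Python, not the CNTK API.
class Node:
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)


def clone(node, substitutions=None):
    """Copy `node` and its ancestors, stopping wherever a substitution applies."""
    substitutions = substitutions or {}
    if node in substitutions:
        # Cut point: use the replacement instead of walking further up the graph.
        return substitutions[node]
    return Node(node.name, [clone(i, substitutions) for i in node.inputs])


def names(node):
    """List the names of every node reachable from `node`, depth first."""
    out = [node.name]
    for i in node.inputs:
        out.extend(names(i))
    return out


# x -> Minus -> BatchNormalization, as in the question.
x = Node("x")
minus = Node("Minus", [x])
bn = Node("BatchNormalization", [minus])

full = clone(bn)                                    # drags in every ancestor
isolated = clone(bn, {minus: Node("placeholder")})  # cut off at Minus

print(names(full))      # ['BatchNormalization', 'Minus', 'x']
print(names(isolated))  # ['BatchNormalization', 'placeholder']
```

In clone_model above, the dictionary passed as the second argument to clone plays the role of substitutions here: listing the parent of the node you want as a from-node is what keeps it out of the cloned graph.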