Cloning a CNTK node to test it in isolation

Posted 2019-08-21 15:21

I am loading a model and now want to test each node in isolation from the rest of the graph, so I'm using the clone(CloneMethod.clone) method, but I'm finding that this recreates the entire model. For example, when I clone the BatchNormalization layer, I get back a graph that still includes its parent Minus node and everything above it. How do I clone BatchNormalization but disconnect it from its parent Minus node?


Tags: cntk
1 answer
太酷不给撩
Reply #2 · 2019-08-21 15:32

You can name every node and then look each one up by its name. For example, in the model below:

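import cntk as C

# Example sizes (assumed; choose values that match your data)
emb_dim, hidden_dim, num_labels = 150, 300, 129

# Layers created with name=... can later be reached as attributes of the model
def create_model():
    with C.layers.default_options(initial_state=0.1):
        return C.layers.Sequential([
            C.layers.Embedding(emb_dim, name='embed'),
            C.layers.Recurrence(C.layers.LSTM(hidden_dim), go_backwards=False),
            C.layers.Dense(num_labels, name='classify')
        ])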

you can then reach each named layer and its parameters as attributes:

z = create_model()
print(z.embed.E.shape)
print(z.classify.b.value)
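Looking a node up by name amounts to a depth-first search over the graph that returns the matching node, or None if nothing matches; CNTK provides this as `cntk.logging.graph.find_by_name`. The sketch below is a toy, pure-Python model of that idea, not CNTK's implementation: the `Node` class and the `find_by_name` helper here are hypothetical stand-ins. It shows why a layer you want to reach later needs a name: an unnamed node can still be traversed, but it cannot be found.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for a graph node: a name plus the nodes it consumes."""
    name: str
    inputs: List["Node"] = field(default_factory=list)


def find_by_name(root: Node, name: str) -> Optional[Node]:
    """Depth-first search from the output node back toward the inputs."""
    stack, seen = [root], set()
    while stack:
        node = stack.pop()
        if id(node) in seen:          # skip nodes reached along another path
            continue
        seen.add(id(node))
        if node.name == name:
            return node
        stack.extend(node.inputs)
    return None                       # no node carries that name


# Toy version of the Embedding -> Recurrence -> Dense model above.
x = Node("x")
embed = Node("embed", [x])
rec = Node("", [embed])               # unnamed, like the Recurrence layer
classify = Node("classify", [rec])

print(find_by_name(classify, "embed") is embed)   # True
print(find_by_name(classify, "missing"))          # None
```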

To clone a particular node, call the clone method on it. Many of the CNTK tutorials show how to query a graph by node name, and the CNTK 206 tutorials show how to work selectively on a subgraph.

Here is some sample code:

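import cntk as C

# A small chain of named layers: x -> foo -> baz -> bar
x = C.input_variable(5)
m = C.layers.Dense(4, name='foo')(x)
n = C.layers.Dense(3, name='baz')(m)
z = C.layers.Dense(2, name='bar')(n)

# Look up the layer named 'baz' and clone it, sharing its parameters with z
n_clone = z.baz.clone(method='share')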

This clones every layer that baz depends on, all the way back to the input x. To get just the layer named baz, declare a new input variable, say y, and apply the clone to it:

y = C.input_variable(4)
n_clone_baz = n_clone(y)
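The `method` argument (`'share'` above) only decides what happens to the learnable parameters of the cloned nodes, so switching from `CloneMethod.clone` to `share` does not by itself disconnect a node from its parents. Per the CNTK documentation for `CloneMethod`: `clone` creates new parameters initialized with the current values, `share` reuses the same parameters, and `freeze` copies them as constants. The toy, pure-Python sketch below mimics those three behaviours for one parameter; `Parameter`, `Constant`, and `clone_param` are hypothetical names, not CNTK APIs.

```python
import copy


class Parameter:
    """Hypothetical learnable parameter: a mutable list of weights."""
    def __init__(self, value):
        self.value = list(value)


class Constant(Parameter):
    """Hypothetical frozen parameter: same data, but not meant to be trained."""


def clone_param(p, method):
    if method == "share":
        return p                                   # same object: updates affect both
    if method == "clone":
        return Parameter(copy.deepcopy(p.value))   # independent learnable copy
    if method == "freeze":
        return Constant(copy.deepcopy(p.value))    # independent, non-learnable copy
    raise ValueError(f"unknown method: {method}")


w = Parameter([0.5, -1.0])
shared, cloned, frozen = (clone_param(w, m) for m in ("share", "clone", "freeze"))

w.value[0] = 9.0                      # simulate a training update on the original
print(shared.value)                   # [9.0, -1.0]  (follows the original)
print(cloned.value)                   # [0.5, -1.0]  (kept the old values)
print(isinstance(frozen, Constant))   # True
```

For probing one node of a loaded model, `clone` or `freeze` keeps experiments on the copy from altering the original weights, while `share` keeps the copy tied to them.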

A more general clone helper follows.

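This clone_model helper is what lets you clone just a subgraph: each from_node is replaced by a fresh placeholder, so the clone stops there instead of reaching back to the model's inputs, while clone_method controls whether the parameters are copied, shared, or frozen.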

import cntk as C
from cntk.logging.graph import find_by_name

def clone_model(base_model, from_node_names, to_node_names, clone_method):
    # Nodes where the cloned subgraph starts (inputs) and ends (outputs)
    from_nodes = [find_by_name(base_model, node_name) for node_name in from_node_names]
    if any(node is None for node in from_nodes):
        raise ValueError("Error: could not find all specified 'from_nodes' in clone.")
    to_nodes = [find_by_name(base_model, node_name) for node_name in to_node_names]
    if any(node is None for node in to_nodes):
        raise ValueError("Error: could not find all specified 'to_nodes' ...... ")

    # Substitute each from_node with a new placeholder so the clone stops there
    input_placeholders = {node: C.placeholder() for node in from_nodes}
    cloned_net = C.combine(to_nodes).clone(clone_method, input_placeholders)
    return cloned_net
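Applied to the question, the from_node would be the Minus node feeding BatchNormalization and the to_node the BatchNormalization node itself (the exact names depend on the loaded model). Because the real model is not available here, the sketch below is a toy, pure-Python model of the cutting behaviour only; `Node`, `clone_subgraph`, and `count` are hypothetical helpers, not CNTK APIs, and the graph is assumed to be acyclic. Cloning without substitutions copies everything upstream, while substituting the Minus node with a placeholder leaves only BatchNormalization.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass(eq=False)                 # eq=False: nodes hash and compare by identity
class Node:
    """Hypothetical graph node: an op name plus the nodes it reads from."""
    name: str
    inputs: List["Node"] = field(default_factory=list)


def clone_subgraph(to_node: Node, substitutions: Dict[Node, Node]) -> Node:
    """Copy to_node and everything upstream of it, stopping at substituted nodes."""
    memo: Dict[Node, Node] = {}

    def visit(node: Node) -> Node:
        if node in substitutions:    # cut here: use the replacement instead
            return substitutions[node]
        if node not in memo:
            memo[node] = Node(node.name, [visit(i) for i in node.inputs])
        return memo[node]

    return visit(to_node)


def count(node: Node) -> int:
    """Number of distinct nodes reachable from node, including itself."""
    seen, stack = set(), [node]
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(n.inputs)
    return len(seen)


# Toy version of the question's graph: features, mean -> Minus -> BatchNormalization
features = Node("features")
mean = Node("mean")
minus = Node("Minus", [features, mean])
bn = Node("BatchNormalization", [minus])

full_copy = clone_subgraph(bn, {})                           # whole upstream graph
isolated = clone_subgraph(bn, {minus: Node("placeholder")})  # cut at Minus

print(count(full_copy))          # 4
print(count(isolated))           # 2
print(isolated.inputs[0].name)   # placeholder
```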