How to save a model when using MXNet

Posted 2019-06-27 13:58

I am using MXNet to train a CNN (in R), and I can train the model without any errors using the following code:

model <- mx.model.FeedForward.create(symbol=network,
                                     X=train.iter,
                                     ctx=mx.gpu(0),
                                     num.round=20,
                                     array.batch.size=batch.size,
                                     learning.rate=0.1,
                                     momentum=0.1,  
                                     eval.metric=mx.metric.accuracy,
                                     wd=0.001,
                                     batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
    )

Because this process is time-consuming, I run it overnight on a server, and I want to save the model so I can use it after training finishes.

I used:

save(list = ls(), file="mymodel.RData")

and

mx.model.save("mymodel", 10)

But neither of them saves the model. For example, when I load "mymodel.RData", I cannot predict the labels for the test set.

Another example is when I load the "mymodel.RData" and try to plot it with the following code:

graph.viz(model$symbol$as.json())

I get the following error:

Error in model$symbol$as.json() : external pointer is not valid

Can anybody give me a solution for saving and then loading this model for future use?

Thanks

3 answers
ら.Afraid
#2 · 2019-06-27 14:26

The best practice for saving snapshots of your training progress is to call save_checkpoint (http://mxnet.io/api/python/module.html#mxnet.module.Module.save_checkpoint) from a callback after every training epoch. In R the equivalent is probably mx.callback.save.checkpoint, but I don't use R and am not sure about its usage.

Using these snapshots also lets you take advantage of low-cost AWS Spot instances (https://aws.amazon.com/ec2/spot/pricing/). At the time of writing, for example, an instance with 16 K80 GPUs costs $3.8/hour on the spot market, compared to the on-demand price of $14.4/hour. Discounts of 80%-90% are common on the spot market and can improve both the speed and the cost of your training, as long as you use these snapshots correctly.
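When a Spot instance is interrupted, you need to find the newest snapshot before resuming. The sketch below is a hypothetical helper, not part of mxnet; it assumes checkpoints follow the "<prefix>-NNNN.params" naming that mxnet's R checkpoint callback is assumed to use, and it only inspects file names, so it runs without mxnet installed.

```r
# Hypothetical helper: return the highest checkpoint iteration saved under
# `prefix`, or NA if none exist. Assumes files named "<prefix>-NNNN.params".
latest_checkpoint <- function(prefix) {
  stem <- paste0(basename(prefix), "-")
  files <- list.files(dirname(prefix))
  files <- files[startsWith(files, stem) & endsWith(files, ".params")]
  # Keep only the part between "<prefix>-" and ".params".
  digits <- substr(files, nchar(stem) + 1L, nchar(files) - nchar(".params"))
  digits <- digits[grepl("^[0-9]+$", digits)]
  if (length(digits) == 0L) return(NA_integer_)
  max(as.integer(digits))
}
```

With the iteration in hand, the model can be reloaded with mxnet::mx.model.load(prefix, iteration = iter). To continue training rather than just predict, passing the loaded arg.params and aux.params plus begin.round = iter + 1 to mx.model.FeedForward.create should work, assuming your mxnet version exposes those arguments.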

我只想做你的唯一
#3 · 2019-06-27 14:28

An MXNet model is an R list, but its first component is not an R object: it is a pointer to a C++ object, which can't be saved and reloaded like an ordinary R object (hence the "external pointer is not valid" error). Therefore, the model needs to be serialized before it behaves like a regular R object. The serialized object is also a list, but its first element is text describing the model.

To save a model:

modelR <- mx.serialize(model)
save(modelR, file="~/model1.RData")

To retrieve it and use it again:

load("~/model1.RData", verbose=TRUE)
model <- mx.unserialize(modelR)
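Since the original goal was to predict test-set labels, here is a sketch of the full reload-and-predict step. The helper names are hypothetical, and load_and_predict() assumes the mxnet R package is installed when it is called. The probability-to-label conversion assumes predict() returns one row per class and one column per sample, as in the MXNet R MNIST tutorial, and that labels are 0-based.

```r
# Hypothetical helper: turn a (classes x samples) probability matrix into
# 0-based class labels by taking the most probable class for each sample.
probs_to_labels <- function(probs) {
  max.col(t(probs), ties.method = "first") - 1L
}

# Hypothetical helper: reload a model saved with mx.serialize() + save()
# and return predicted labels for X (an array or data iterator).
load_and_predict <- function(path, X) {
  env <- new.env()
  saved <- load(path, envir = env)  # names of restored objects, e.g. "modelR"
  model <- mxnet::mx.unserialize(get(saved[[1]], envir = env))
  probs <- predict(model, X)        # one column per sample in X
  probs_to_labels(probs)
}
```

For example, load_and_predict("~/model1.RData", test.x) would return one label per test sample, where test.x stands for your own test data.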
【Aperson】
#4 · 2019-06-27 14:39

You can save the model during training by adding a checkpoint callback that runs at the end of every epoch:

model <- mx.model.FeedForward.create(symbol=network,
                                 X=train.iter,
                                 ctx=mx.gpu(0),
                                 num.round=20,
                                 array.batch.size=batch.size,
                                 learning.rate=0.1,
                                 momentum=0.1,  
                                 eval.metric=mx.metric.accuracy,
                                 wd=0.001,
                                 # save a checkpoint (symbol + parameters) after every epoch
                                 epoch.end.callback=mx.callback.save.checkpoint("model_prefix"),
                                 batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)
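To use the saved model later, the checkpoint can be reloaded with mx.model.load. The sketch below uses hypothetical helper names and assumes the files are named "<prefix>-symbol.json" and "<prefix>-NNNN.params" (iteration zero-padded to four digits); it checks that both files exist before calling into mxnet, which must be installed for the actual load.

```r
# Hypothetical helper: the two files a checkpoint is assumed to consist of.
checkpoint_files <- function(prefix, iteration) {
  c(symbol = paste0(prefix, "-symbol.json"),
    params = sprintf("%s-%04d.params", prefix, iteration))
}

# Hypothetical helper: reload the checkpoint saved after `iteration` epochs
# and return predictions for X.
predict_from_checkpoint <- function(prefix, iteration, X) {
  files <- checkpoint_files(prefix, iteration)
  absent <- files[!file.exists(files)]
  if (length(absent) > 0) {
    stop("checkpoint file(s) not found: ", paste(absent, collapse = ", "))
  }
  model <- mxnet::mx.model.load(prefix, iteration = iteration)
  predict(model, X)
}
```

With num.round=20 as above, the last checkpoint should be iteration 20, so predict_from_checkpoint("model_prefix", 20, test.iter) would score the fully trained model, where test.iter stands for your own test iterator.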