How do I visualize a net in Pytorch?

Posted 2020-02-26 04:52

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate = 0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print(resnet)
make_dot(resnet)  # this line raises the error below

I want to visualize the ResNet model from torchvision's pretrained models. How can I do it? I tried torchviz, but it raises an error:

'ResNet' object has no attribute 'grad_fn'

4 answers
混吃等死
#2 · 2020-02-26 05:03

You can have a look at PyTorchViz (https://github.com/szagoruyko/pytorchviz), "A small package to create visualizations of PyTorch execution graphs and traces."

[Image: example PyTorchViz visualization]

迷人小祖宗
#3 · 2020-02-26 05:03

Here is how you do it with torchviz if you want to save the image:

# http://www.bnikolic.co.uk/blog/pytorch-detach.html

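import torch
from torchviz import make_dot

x = torch.ones(10, requires_grad=True)
weights = {'x': x}  # name -> tensor mapping, used to label the leaf node

y = x**2
z = x**3
r = (y + z).sum()

# render() writes the DOT source to "attached" and the image to "attached.png";
# it needs the Graphviz "dot" executable on the PATH
make_dot(r, params=weights).render("attached", format="png")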

A screenshot of the resulting image was included here (image not preserved).

source: http://www.bnikolic.co.uk/blog/pytorch-detach.html

Emotional °昔
#4 · 2020-02-26 05:13

You can use TensorBoard for visualization. As of PyTorch 1.2.0, TensorBoard is supported natively through torch.utils.tensorboard. More info: https://pytorch.org/docs/stable/tensorboard.html
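A minimal sketch of logging a model graph with SummaryWriter.add_graph, assuming the tensorboard package is installed alongside PyTorch. The toy model and the temporary log directory are hypothetical; for the question, pass resnet and a dummy input of shape (1, 3, 224, 224) instead.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Hypothetical toy model; swap in models.resnet50() for the real case
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
dummy_input = torch.randn(1, 3)

log_dir = tempfile.mkdtemp()
writer = SummaryWriter(log_dir=log_dir)
# add_graph runs the model on the example input and logs the traced graph
writer.add_graph(model, dummy_input)
writer.close()

print(os.listdir(log_dir))
```

Then run tensorboard --logdir pointing at that directory and open the Graphs tab in the browser.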

孤傲高冷的网名
#5 · 2020-02-26 05:19

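make_dot expects a tensor that has a grad_fn (what older PyTorch versions called a Variable), not the model itself. Run a forward pass first and pass the output instead:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)  # dummy input: (N, C, H, W)
out = resnet(x)
make_dot(out)  # plot the graph of the output tensor, not of the nn.Module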
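To see why the output tensor is needed, here is a torch-only sketch (no torchviz) of the kind of traversal make_dot performs: start at out.grad_fn and follow next_functions back through the autograd graph. An nn.Module has no grad_fn, so there is nothing to start from. The helper collect_nodes and the toy model are hypothetical and only gather node class names.

```python
import torch
import torch.nn as nn


def collect_nodes(tensor):
    """Return class names of all autograd nodes reachable from tensor.grad_fn."""
    seen, names = set(), []
    stack = [tensor.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        # next_functions is a tuple of (node, input_index) pairs
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return names


# Hypothetical toy model standing in for ResNet
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
out = model(torch.zeros(1, 3))

print(hasattr(model, "grad_fn"))  # a module is not a node in the graph
print(collect_nodes(out))
```

The exact node names (for example AddmmBackward0 versus AddmmBackward) depend on the PyTorch version.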