How to load and use a pretrained PyTorch InceptionV

Posted 2019-07-15 02:35

I have the same problem as "How can I load and use a PyTorch (.pth.tar) model", which has no accepted answer, and I can't figure out how to follow the advice given there.

I'm new to PyTorch. I am trying to load the pretrained PyTorch model referenced here: https://github.com/macaodha/inat_comp_2018

I'm pretty sure I am missing some glue.

# load the model
import torch
from PIL import Image
from torchvision import transforms

model = torch.load("iNat_2018_InceptionV3.pth.tar", map_location='cpu')

# try to get it to classify an image
imsize = 256
# transforms.Resize replaces the deprecated transforms.Scale
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])

def image_loader(image_name):
    """Load an image and return it as a one-image batch tensor on the CPU."""
    image = Image.open(image_name)
    image = loader(image).float()
    image = image.unsqueeze(0)  # add a batch dimension
    return image.cpu()  # assumes that you're using the CPU

image = image_loader("test-image.jpg")

Calling model.predict(image) then produces the error:

in <module>()
----> 1 model.predict(image)

AttributeError: 'dict' object has no attribute 'predict'

Answer by 来,给爷笑一个 · 2019-07-15 03:00

Problem

What you loaded isn't actually a model. The file is a checkpoint: when it was saved, it stored not only the model's parameters (under the 'state_dict' key) but also other information about the model, all packed into a plain dict.

Therefore, torch.load("iNat_2018_InceptionV3.pth.tar") simply returns a dict, which of course has no attribute called predict.

model = torch.load("iNat_2018_InceptionV3.pth.tar", map_location='cpu')
type(model)
# <class 'dict'>
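To see what such a file actually holds, you can load it and list its keys. Since the iNat checkpoint isn't available here, the sketch below saves a tiny stand-in checkpoint with the same general layout (a dict with a 'state_dict' entry next to other information). The 'epoch' key, the describe_checkpoint helper and the demo_checkpoint.pth.tar file name are hypothetical, chosen only for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


def describe_checkpoint(path):
    """Load a checkpoint onto the CPU and print what is stored in it."""
    checkpoint = torch.load(path, map_location="cpu")
    print(type(checkpoint))  # a plain dict, not an nn.Module
    for key, value in checkpoint.items():
        print(key, type(value).__name__)
    return checkpoint


# Hypothetical stand-in for iNat_2018_InceptionV3.pth.tar: the parameters
# live under 'state_dict', alongside other (illustrative) information.
tiny_model = nn.Linear(4, 3)
path = os.path.join(tempfile.mkdtemp(), "demo_checkpoint.pth.tar")
torch.save({"epoch": 1, "state_dict": tiny_model.state_dict()}, path)

checkpoint = describe_checkpoint(path)
print(list(checkpoint["state_dict"].keys()))  # ['weight', 'bias']
```

On the real file, describe_checkpoint("iNat_2018_InceptionV3.pth.tar") should likewise show a 'state_dict' entry; that entry is the part you pass to load_state_dict.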

Solution

In this case, as in general, you first need to instantiate the model class you want and then load the saved parameters into it, as described in the official guide "Load models".

# First try
from torchvision.models import Inception3
v3 = Inception3()
v3.load_state_dict(model['state_dict']) # model that was imported in your code.

However, passing model['state_dict'] in directly raises errors about mismatched shapes for some of Inception3's parameters.
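Before calling load_state_dict, you can check exactly which entries disagree by comparing the checkpoint's tensors with those of the freshly built model. The compare_state_dicts helper below is a hypothetical name, not a PyTorch API; it reports missing keys, unexpected keys and shape mismatches. It is demonstrated on two small nn.Linear layers that mimic this situation, where only the size of the final classification layer differs.

```python
import torch
import torch.nn as nn


def compare_state_dicts(model, state_dict):
    """Return (missing, unexpected, mismatched) between a model and a state_dict."""
    model_state = model.state_dict()
    missing = [k for k in model_state if k not in state_dict]
    unexpected = [k for k in state_dict if k not in model_state]
    # (key, shape in checkpoint, shape in model) for every shared key that differs
    mismatched = [
        (k, tuple(state_dict[k].shape), tuple(model_state[k].shape))
        for k in state_dict
        if k in model_state and state_dict[k].shape != model_state[k].shape
    ]
    return missing, unexpected, mismatched


# Stand-ins: the "saved" head has 8 outputs, the freshly built head has 5.
saved_head = nn.Linear(16, 8)
fresh_head = nn.Linear(16, 5)
missing, unexpected, mismatched = compare_state_dicts(fresh_head, saved_head.state_dict())
print(mismatched)  # [('weight', (8, 16), (5, 16)), ('bias', (8,), (5,))]

# load_state_dict rejects the mismatched tensors with a RuntimeError.
try:
    fresh_head.load_state_dict(saved_head.state_dict())
    load_failed = False
except RuntimeError as err:
    load_failed = True
    print("load_state_dict failed:", str(err).splitlines()[0])
```

On the real files you would call compare_state_dicts(Inception3(), model['state_dict']); I'd expect the fc.weight and fc.bias entries (8142 iNat classes versus 1000 ImageNet classes) to be the ones reported, though that is an expectation rather than something verified here.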

You need to know what the author changed in Inception3 after instantiating it. Luckily, you can find that in the original author's train_inat.py:

# What the author has done
model = inception_v3(pretrained=True)
model.fc = nn.Linear(2048, args.num_classes) #where args.num_classes = 8142
model.aux_logits = False

Now that we know what to change, let's modify our first try accordingly.

# Second try
import torch.nn as nn
from torchvision.models import Inception3

v3 = Inception3()
v3.fc = nn.Linear(2048, 8142)  # 8142 iNat 2018 classes instead of 1000
v3.aux_logits = False
v3.load_state_dict(model['state_dict'])  # model that was imported in your code

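And there you go: the model now loads successfully! Note that PyTorch models have no predict method; put the model in evaluation mode with v3.eval() and then call it directly, e.g. v3(image).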
