I have the same problem as "How can I load and use a PyTorch (.pth.tar) model", which has no accepted answer, and I can't figure out how to follow the advice given there.
I'm new to PyTorch. I am trying to load the pretrained PyTorch model referenced here: https://github.com/macaodha/inat_comp_2018
I'm pretty sure I am missing some glue.
# load the model
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

model = torch.load("iNat_2018_InceptionV3.pth.tar", map_location='cpu')

# try to get it to classify an image
imsize = 256
# transforms.Resize replaces the deprecated transforms.Scale
loader = transforms.Compose([transforms.Resize(imsize), transforms.ToTensor()])

def image_loader(image_name):
    """load image, returns a CPU tensor"""
    image = Image.open(image_name)
    image = loader(image).float()
    image = Variable(image, requires_grad=True)
    image = image.unsqueeze(0)
    return image.cpu()  # assumes that you're using CPU

image = image_loader("test-image.jpg")
Produces the error:
in ()
----> 1 model.predict(image)
AttributeError: 'dict' object has no attribute 'predict'
Problem
Your model isn't actually a model. When it was saved, the file stored not only the parameters but also other information about the model, in a form somewhat similar to a dict.
Therefore, torch.load("iNat_2018_InceptionV3.pth.tar") simply returns a dict, which of course does not have an attribute called predict.
model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')
type(model)
# dict
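To see what such a checkpoint contains, it can help to print its top-level keys and the shapes stored under the weights entry. The sketch below is only an illustration: it builds a tiny stand-in checkpoint in memory rather than reading iNat_2018_InceptionV3.pth.tar, and the "epoch" key and the toy nn.Linear model are assumptions, not the contents of the real file.

```python
import io

import torch
from torch import nn

# A tiny stand-in network; the real checkpoint holds Inception v3 weights.
toy_model = nn.Linear(4, 2)

# Hypothetical checkpoint layout: some metadata plus the weights under "state_dict".
checkpoint = {"epoch": 3, "state_dict": toy_model.state_dict()}

# Save to an in-memory buffer instead of a file, then load it back.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)
loaded = torch.load(buffer, map_location="cpu")

print(type(loaded))         # a plain dict, not an nn.Module
print(list(loaded.keys()))  # the top-level entries of the checkpoint

# The weights themselves are a mapping from parameter name to tensor.
for name, tensor in loaded["state_dict"].items():
    print(name, tuple(tensor.shape))
```

Running the same three print steps on the object returned by torch.load for the real file should show which key holds the weights.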
Solution
What you need to do first in this case, and in general cases, is to instantiate your desired model class, as per the official guide "Load models".
# First try
from torchvision.models import Inception3
v3 = Inception3()
v3.load_state_dict(model['state_dict']) # model that was imported in your code.
However, passing model['state_dict'] in directly will raise errors about mismatched shapes of Inception3's parameters.
It is important to know what was changed in Inception3 after its instantiation. Luckily, you can find that in the original author's train_inat.py.
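If the author's training script were not available, one way to narrow down what changed is to compare parameter shapes between the freshly built model and the checkpoint. The following is a minimal sketch with toy models; find_shape_mismatches is a hypothetical helper written for this answer, not part of PyTorch.

```python
import torch
from torch import nn


def find_shape_mismatches(model, state_dict):
    """Return (name, model_shape, checkpoint_shape) for each parameter whose shape differs."""
    mismatches = []
    own = model.state_dict()
    for name, tensor in state_dict.items():
        if name in own and own[name].shape != tensor.shape:
            mismatches.append((name, tuple(own[name].shape), tuple(tensor.shape)))
    return mismatches


# Toy stand-ins: "saved" was trained with a 5-class head, "fresh" has a 3-class head.
saved = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 5))
fresh = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))

mismatches = find_shape_mismatches(fresh, saved.state_dict())
for name, model_shape, ckpt_shape in mismatches:
    print(f"{name}: model {model_shape} vs checkpoint {ckpt_shape}")
```

In the toy case only the final layer is reported, which points at the classification head as the part that needs to be replaced before loading.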
# What the author has done
model = inception_v3(pretrained=True)
model.fc = nn.Linear(2048, args.num_classes) #where args.num_classes = 8142
model.aux_logits = False
Now that we know what to change, let's modify our first try.
# Second try
from torch import nn
from torchvision.models import Inception3
v3 = Inception3()
v3.fc = nn.Linear(2048, 8142)
v3.aux_logits = False
v3.load_state_dict(model['state_dict']) # model that was imported in your code.
And there you go, with a successfully loaded model!
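Note that PyTorch modules have no predict method; you call the model itself on a batch of images. Below is a rough sketch of that step, using a toy network so it runs on its own. predict_topk is a hypothetical helper, and for the real model you would pass v3 and a batch of shape (1, 3, 299, 299), since torchvision's Inception v3 expects 299x299 inputs. Any resizing and normalization should match what the author used in training, which I have not checked.

```python
import torch
from torch import nn


def predict_topk(model, batch, k=5):
    """Run a forward pass and return the top-k probabilities and class indices."""
    model.eval()  # inference mode for layers such as dropout and batch norm
    with torch.no_grad():  # no gradients are needed for classification
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)
    return torch.topk(probs, k, dim=1)


# Toy stand-in for the loaded network: 10 classes, 8x8 RGB inputs.
toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
batch = torch.rand(1, 3, 8, 8)

top_probs, top_classes = predict_topk(toy, batch, k=5)
print(top_probs)
print(top_classes)
```

The returned indices are positions in the model's output layer; mapping them to species names requires the category list from the dataset.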