How to use pre-trained models in PyTorch?
In PyTorch, pre-trained models are available through the models module of the torchvision library (torchvision.models). It provides popular image classification architectures such as ResNet, VGG, and AlexNet, with weights pre-trained on ImageNet. Below is an example of running inference with a pre-trained ResNet model.
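```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Load a pre-trained ResNet-18. torchvision >= 0.13 uses the `weights` argument;
# older versions use models.resnet18(pretrained=True) instead (now deprecated).
weights = models.ResNet18_Weights.IMAGENET1K_V1
model = models.resnet18(weights=weights)
model.eval()  # switch to evaluation mode

# Load an image for inference and preprocess it the way the ImageNet weights expect
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open('image.jpg').convert('RGB')  # ensure 3 channels (handles grayscale/RGBA files)
img = transform(img)
img = img.unsqueeze(0)  # add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)

# Run inference without tracking gradients
with torch.no_grad():
    output = model(img)
```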
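The model returns raw scores (logits), not class names. A minimal sketch for turning them into readable predictions applies a softmax and keeps the k highest probabilities. `top_k_predictions` is a hypothetical helper. With torchvision 0.13 or newer, the ImageNet label list can be obtained from `models.ResNet18_Weights.IMAGENET1K_V1.meta["categories"]`; any list of class names in index order works.

```python
from typing import List, Tuple

import torch


def top_k_predictions(output: torch.Tensor, categories: List[str], k: int = 5) -> List[Tuple[str, float]]:
    """Hypothetical helper: convert logits of shape (1, num_classes) into (label, probability) pairs."""
    probs = torch.softmax(output[0], dim=0)  # logits of the single image -> probabilities summing to 1
    top_probs, top_idx = torch.topk(probs, k)  # k largest probabilities, sorted in descending order
    return [(categories[i], p) for i, p in zip(top_idx.tolist(), top_probs.tolist())]
```

For the example above, `top_k_predictions(output, weights.meta["categories"])` would list the five most likely ImageNet classes for `image.jpg`.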
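In the example above, we first load the pre-trained ResNet model and set it to evaluation mode. Then we load an image, preprocess it (resize, center-crop, convert to a tensor, and normalize with the ImageNet mean and standard deviation), add a batch dimension, and pass it through the model. The output is a tensor of shape (1, 1000) that holds one raw score (logit) for each ImageNet class. Calling model.eval() before inference is essential: it switches layers such as dropout and batch normalization to their inference behavior, so predictions are deterministic and use the statistics learned during training. Note that model.eval() does not disable gradient tracking. To avoid that extra memory and computation, wrap the forward pass in torch.no_grad().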
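To see what model.eval() and torch.no_grad() change, the sketch below uses a tiny stand-in network instead of ResNet. The network `net` is hypothetical, chosen only because it contains the two layer types whose behavior eval() changes: BatchNorm and Dropout. In evaluation mode, repeated forward passes on the same input produce identical outputs. Inside torch.no_grad(), the result does not track gradients.

```python
import torch
import torch.nn as nn

# Hypothetical toy model with the layers affected by train/eval mode
net = nn.Sequential(
    nn.Linear(8, 16),
    nn.BatchNorm1d(16),  # train: batch statistics; eval: running statistics
    nn.ReLU(),
    nn.Dropout(p=0.5),   # train: randomly zeroes activations; eval: passes inputs through
    nn.Linear(16, 4),
)
x = torch.randn(4, 8)

net.eval()  # sets net.training = False on every submodule
with torch.no_grad():  # skip building the autograd graph during inference
    out1 = net(x)
    out2 = net(x)

# Without net.eval(), dropout would make out1 and out2 differ from call to call.
print(torch.equal(out1, out2), out1.requires_grad)
```

The same two calls apply unchanged to the ResNet example, because model.eval() and torch.no_grad() work on any nn.Module.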