How do you implement a convolutional neural network in PyTorch?

In PyTorch, convolutional layers are provided by the Conv2d class in the torch.nn module. A convolutional neural network is built by combining these layers with activations, pooling, and fully connected layers inside a subclass of nn.Module. Here is a simple example of a basic convolutional neural network in PyTorch.

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
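        super().__init__()
        # 1 input channel (grayscale) -> 16 feature maps; padding=1 keeps height and width unchanged
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        # 2x2 max pooling with stride 2 halves height and width
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        # 32 * 7 * 7 assumes a 1x28x28 input (e.g. MNIST): 28 -> 14 -> 7 after the two poolings
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)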

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
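        # Flatten each sample to 32 * 7 * 7 features while keeping the batch dimension
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)
        # Return raw logits; nn.CrossEntropyLoss applies log-softmax internally
        x = self.fc2(x)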
        return x

# Create a SimpleCNN instance
model = SimpleCNN()

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

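# Train the model
# num_epochs and train_loader are placeholders: train_loader should yield
# (images, labels) batches with images shaped (batch, 1, 28, 28).
# The random tensors below are hypothetical stand-in data so the loop runs as written.
num_epochs = 2
dummy_images = torch.randn(256, 1, 28, 28)
dummy_labels = torch.randint(0, 10, (256,))
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(dummy_images, dummy_labels), batch_size=64, shuffle=True
)

model.train()
for epoch in range(num_epochs):
    for images, labels in train_loader:
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(images)            # forward pass: (batch, 10) logits
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
    print(f"epoch {epoch + 1}: last batch loss = {loss.item():.4f}")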

The example above defines a simple convolutional neural network, SimpleCNN, made of two convolutional layers (each followed by a ReLU activation and 2×2 max pooling) and two fully connected layers. The forward method defines the model's forward pass. Because each pooling step halves the height and width, a 1×28×28 input (such as an MNIST digit) reaches the first fully connected layer as 32 feature maps of size 7×7, which is why fc1 expects 32 * 7 * 7 input features. The model is trained with the cross-entropy loss function and the Adam optimizer at a learning rate of 0.001.
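To see where the 32 * 7 * 7 figure comes from, the sketch below applies the standard output-size formula for convolution and pooling, floor((n + 2p - k) / s) + 1, to a 28×28 input, then cross-checks the result against the shape PyTorch actually produces. The helper `conv_output_size` is hypothetical, introduced only for this illustration, and the `nn.Sequential` stack simply mirrors the convolutional part of SimpleCNN rather than reusing the class.

```python
import torch
import torch.nn as nn


def conv_output_size(n: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Hypothetical helper: spatial output size of Conv2d/MaxPool2d (dilation=1, floor mode)."""
    return (n + 2 * padding - kernel_size) // stride + 1


# Trace a 28x28 input through the same layer settings used in SimpleCNN
size = 28
size = conv_output_size(size, kernel_size=3, stride=1, padding=1)  # conv1 -> 28
size = conv_output_size(size, kernel_size=2, stride=2)             # pool  -> 14
size = conv_output_size(size, kernel_size=3, stride=1, padding=1)  # conv2 -> 14
size = conv_output_size(size, kernel_size=2, stride=2)             # pool  -> 7
print("computed flattened features:", 32 * size * size)

# Cross-check against PyTorch with a dummy batch of one grayscale 28x28 image
features = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
)
with torch.no_grad():
    out = features(torch.randn(1, 1, 28, 28))
print("PyTorch output shape:", tuple(out.shape))
```

If you change the input resolution, rerunning this kind of trace tells you what in_features the first nn.Linear layer needs.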

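Following the same pattern (declare layers in __init__, connect them in forward, then optimize a loss function over batches of data), you can implement, train, and fine-tune larger convolutional neural network models in PyTorch.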
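As a rough illustration of fine-tuning, the sketch below freezes the convolutional layers of a SimpleCNN-style model, swaps the 10-class output layer for a new one, and gives the optimizer only the parameters that still require gradients. The compact `SmallCNN` class, the 5-class target task, and the random batch are hypothetical stand-ins; in practice you would start from a model whose weights were already trained (for example, restored with `load_state_dict`).

```python
import torch
import torch.nn as nn


class SmallCNN(nn.Module):
    """Hypothetical compact equivalent of SimpleCNN (same layer shapes)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(), nn.Linear(32 * 7 * 7, 128), nn.ReLU(), nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))


model = SmallCNN()  # assumption: pretrained weights would be loaded here

# Freeze the convolutional feature extractor so its weights are not updated
for param in model.features.parameters():
    param.requires_grad = False

# Replace the final layer for a hypothetical 5-class task; the new layer is trainable
model.classifier[-1] = nn.Linear(128, 5)

# Only pass trainable parameters to the optimizer
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)
criterion = nn.CrossEntropyLoss()

# One fine-tuning step on a random stand-in batch
images = torch.randn(8, 1, 28, 28)
labels = torch.randint(0, 5, (8,))
frozen_before = model.features[0].weight.clone()

model.train()
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()
print(f"fine-tuning loss: {loss.item():.4f}")
```

Freezing the feature extractor keeps what the convolutional layers have already learned and trains only the classifier, which needs less data and compute; unfreezing some or all layers later, usually with a small learning rate, is a common next step.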
