How can a loss function be defined in PyTorch?
In PyTorch, a custom loss function is commonly defined by subclassing the nn.Module class and implementing its forward method. Here is an example:
```python
import torch
import torch.nn as nn


class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        # Mean squared error, used here as an example
        loss = torch.mean((output - target) ** 2)
        return loss
```
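The example above defines a custom loss class named CustomLoss. Its forward method takes the model's output and the target values as inputs and returns the loss as a scalar tensor. This example computes the mean squared error, but forward can implement any other loss as needed. Because the computation uses differentiable torch operations, autograd computes the gradients when loss.backward() is called, so no separate backward method has to be written.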
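The following is a minimal sketch of how such a loss class might be used in a single training step. The toy nn.Linear model, the random data, and the learning rate are purely illustrative assumptions. The block repeats CustomLoss so it runs on its own, checks it against the built-in nn.MSELoss, and adds a hypothetical WeightedMSELoss to show where configurable settings such as weights would go.

```python
import torch
import torch.nn as nn


class CustomLoss(nn.Module):
    """Mean squared error, same as the example above."""

    def __init__(self):
        super().__init__()

    def forward(self, output, target):
        return torch.mean((output - target) ** 2)


class WeightedMSELoss(nn.Module):
    """Hypothetical variant: MSE scaled by a weight tensor.

    Settings are stored in __init__ and used in forward.
    """

    def __init__(self, weight):
        super().__init__()
        # A buffer moves with .to(device) but is not a trainable parameter
        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float32))

    def forward(self, output, target):
        # weight broadcasts against the squared error
        return torch.mean(self.weight * (output - target) ** 2)


torch.manual_seed(0)
model = nn.Linear(3, 1)            # toy model, illustrative only
criterion = CustomLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 3)              # random inputs, illustrative only
y = torch.randn(8, 1)              # random targets, same shape as model output

# One training step: forward pass, loss, backward pass, parameter update
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()                    # autograd differentiates through mean and ** 2
optimizer.step()

# Compare the custom loss with the built-in nn.MSELoss on the updated model
with torch.no_grad():
    pred = model(x)
    print(criterion(pred, y).item(), nn.MSELoss()(pred, y).item())
```

The two printed values should match, since both compute the mean of the squared differences. Registering the weight as a buffer rather than a plain attribute is one design choice among several; it keeps the tensor on the same device as the module after calling .to(device).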