Torchにおいて損失関数を定義する方法は何ですか?

Torchで損失関数を定義する場合、一般的にはnn.Moduleクラスを継承して実装します。以下は例です:

import torch
import torch.nn as nn

class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, output, target):
        loss = torch.mean((output - target) ** 2)  # 以均方误差为例
        return loss

上記の例では、CustomLossというカスタム損失関数クラスが定義されており、そのforwardメソッドはモデルの出力outputと目標値targetを入力として受け取り、損失値を計算します。ここで使用されているのは平均二乗誤差を損失関数として計算する方法であり、必要に応じて異なる損失関数をカスタマイズすることができます。

bannerAds