PyTorchで独自データセットをロードする方法
独自のデータセットを読み込む場合は、PyTorch でカスタムデータセットクラスを作成できます。
最初に、以下で必要なライブラリとモジュールをインポートする必要があります:
import torch
from torch.utils.data import Dataset, DataLoader
次に、`torch.utils.data.Dataset` クラスを継承したカスタムデータセットクラスを作成します。このクラスでは、`__init__`、`__len__`、`__getitem__` メソッドを実装する必要があります。`__init__` メソッドはデータセットの初期化に使用され、`__len__` メソッドはデータセットのサイズを返します。`__getitem__` メソッドは、指定されたインデックスに対応するデータを取得するために使用されます。
class CustomDataset(Dataset):
def __init__(self, ...):
# 初始化数据集
...
def __len__(self):
# 返回数据集大小
...
def __getitem__(self, index):
# 获取指定索引的数据
...
__getitem__メソッドでは、インデックスに応じて対応するデータを読み込み、データとラベルを返します。torchvision.transformsモジュールを使用して、データの前処理を行うことができます。
from torchvision import transforms
class CustomDataset(Dataset):
def __init__(self, ...):
# 初始化数据集
...
# 定义数据预处理
self.transform = transforms.Compose([
transforms.ToTensor(), # 将数据转为Tensor
transforms.Normalize((0.5,), (0.5,)) # 数据标准化
])
def __len__(self):
# 返回数据集大小
...
def __getitem__(self, index):
# 获取指定索引的数据
...
# 加载数据和标签
data, label = ...
# 对数据进行预处理
data = self.transform(data)
return data, label
最後に、DataLoaderクラスを用いてデータセットを読み込みます。DataLoaderはバッチごとにデータを読み込み、データのイテレータを提供します。
dataset = CustomDataset(...)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
上記の手順を実行すれば、自分のデータセットを読み込んでDataLoaderを使用してデータとラベルを取得できます。