PyTorchのviewメソッドはどのように使用するのですか?

PyTorchでは、view()関数を使用してテンソルの形状を調整することができます。その使い方は以下の通りです:

output = input.view(*shape)

こちらのinputは入力テンソルであり、shapeは指定された新しい形状を示すタプルです。詳しくは、

  1. 各shape内の要素は、具体的な次元のサイズまたは-1(他の次元のサイズに基づいて自動的に計算)であることができます。
  2. 調整されたテンソルと元のテンソルは、同じデータを指しているため、メモリスペースを共有しています。

以下はいくつかの例です。

import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 将x的形状调整为(3, 2)
output = x.view(3, 2)
print(output)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

# 将x的形状调整为(6, -1),其中-1表示自动计算
output = x.view(6, -1)
print(output)
# 输出:
# tensor([[1],
#         [2],
#         [3],
#         [4],
#         [5],
#         [6]])

# 将x的形状调整为(1, 6)
output = x.view(1, 6)
print(output)
# 输出:
# tensor([[1, 2, 3, 4, 5, 6]])

修正された形状は元のテンソルの要素の総数と一致する必要があります。そうでないとエラーが発生します。

bannerAds