How to concatenate PyTorch tensors?

In PyTorch, you can concatenate tensors using the torch.cat() function.

The syntax of the torch.cat() function is as follows:

torch.cat(tensors, dim=0, *, out=None)

Here, “tensors” is a sequence (such as a tuple or list) of the tensors to be concatenated. They must have the same shape in every dimension except the one being concatenated. “dim” specifies the dimension along which to concatenate and defaults to 0; for 2-D tensors this places the rows of each input one after another. “out” is an optional, keyword-only output tensor that receives the concatenated result.
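To illustrate the “dim” and “out” parameters, here is a small sketch. The tensor names a, b, cols and buf are only for illustration. It concatenates two 2 × 3 tensors along dim=1, shows that a negative dim counts from the last dimension, and writes the result into a preallocated out tensor whose shape and dtype are chosen to match:

```python
import torch

# Two illustrative 2x3 integer tensors (names are hypothetical)
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])

# dim=1 joins along the columns: each result row is a row of a followed by the same row of b
cols = torch.cat((a, b), dim=1)
print(cols.shape)  # torch.Size([2, 6])
# rows of cols: [1, 2, 3, 7, 8, 9] and [4, 5, 6, 10, 11, 12]

# A negative dim counts from the end, so for 2-D tensors dim=-1 is the same as dim=1
assert torch.equal(cols, torch.cat([a, b], dim=-1))

# out= (keyword-only) writes into a preallocated tensor of matching shape and dtype
buf = torch.empty(4, 3, dtype=a.dtype)
torch.cat((a, b), dim=0, out=buf)
print(buf)
```

Preallocating the out tensor with the final shape avoids relying on PyTorch to resize it; for most code, simply using the returned tensor is the more common style.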

Here is an example that uses torch.cat() to concatenate two tensors:

import torch

# Create two tensors
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])

# Concatenate the tensors along dim 0 (row-wise)
result = torch.cat((tensor1, tensor2), dim=0)

print(result)

The output is:

tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

In the example above, we first create two tensors of shape (2, 3), tensor1 and tensor2. We then use torch.cat() to concatenate them along dimension 0, producing a new tensor, result, of shape (4, 3) whose rows are those of tensor1 followed by those of tensor2. Finally, we print the concatenated result.
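torch.cat() requires the input tensors to have the same size in every dimension except the one being concatenated. The following sketch checks this directly, using illustrative names x, y and chunks: a (2, 3) and a (1, 3) tensor can be joined along dim 0, but not along dim 1, because their sizes in dim 0 differ. It ends with the common pattern of collecting pieces in a list and calling torch.cat() once:

```python
import torch

# Sizes may differ only along the concatenation dimension
x = torch.zeros(2, 3)
y = torch.ones(1, 3)  # 1 row instead of 2, but still 3 columns

rows = torch.cat([x, y], dim=0)  # OK: shapes agree in every dim except dim 0
print(rows.shape)  # torch.Size([3, 3])

try:
    torch.cat([x, y], dim=1)  # fails: sizes in dim 0 (2 vs 1) do not match
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

# Collect pieces in a list, then concatenate them with a single call
chunks = [torch.full((1, 3), float(i)) for i in range(4)]
joined = torch.cat(chunks, dim=0)
print(joined.shape)  # torch.Size([4, 3])
```

Calling torch.cat() once on a list is generally preferable to concatenating inside a loop, since each call allocates a new tensor and copies all of its inputs.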