torch.cat and torch.stack
술임
2023. 11. 27. 14:56
torch.cat
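torch.cat concatenates the given tensors along an existing dimension. All tensors must have the same shape in every dimension except the one being concatenated.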
import torch
M, N, K = 2, 3, 4  # example sizes; M is the batch size
x = torch.rand(M, N, K)  # [M, N, K]
y = torch.rand(M, N, K)  # [M, N, K]
output1 = torch.cat([x, y], dim=1)  # [M, N+N, K]
output2 = torch.cat([x, y], dim=2)  # [M, N, K+K]
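To see what "concatenating along dim=1" means in practice, the sketch below slices the result back apart. The sizes M, N, K = 2, 3, 4 are arbitrary example values, because the original leaves them unspecified. The first N positions along dim 1 should hold x unchanged, and the next N should hold y.

```python
import torch

# Arbitrary example sizes (assumption: the original does not fix M, N, K)
M, N, K = 2, 3, 4
x = torch.rand(M, N, K)
y = torch.rand(M, N, K)

out = torch.cat([x, y], dim=1)  # [M, 2N, K]

# The first N slices along dim 1 come from x, the next N from y
first_half = out[:, :N, :]
second_half = out[:, N:, :]
print(out.shape)                    # torch.Size([2, 6, 4])
print(torch.equal(first_half, x))   # True
print(torch.equal(second_half, y))  # True
```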
import torch
t1 = torch.rand((6, 32))
t2 = torch.rand((4, 32))
torch.cat((t1, t2), dim=0).shape
>>>> torch.Size([10, 32])
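# dim=1 does not work with t1 and t2: their sizes in dim 0 differ (6 vs 4),
# so torch.cat((t1, t2), dim=1) raises a RuntimeError.
# Concatenating along dim=1 needs a tensor with the same dim-0 size:
t3 = torch.rand((6, 32))
torch.cat((t1, t3), dim=1).shape
>>>> torch.Size([6, 64])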
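The shape rule for torch.cat can be expressed as a small check: every dimension except `dim` must match, and sizes along `dim` are summed. `expected_cat_shape` below is a hypothetical helper written only for illustration. It is not part of PyTorch, and it ignores PyTorch's legacy special case for empty 1-D tensors.

```python
import torch

def expected_cat_shape(shapes, dim):
    """Hypothetical helper: predict torch.cat's output shape, or raise if incompatible."""
    ndim = len(shapes[0])
    dim = dim % ndim  # accept negative dims, as torch.cat does
    result = list(shapes[0])
    for shape in shapes[1:]:
        if len(shape) != ndim:
            raise ValueError("all tensors must have the same number of dimensions")
        for d in range(ndim):
            # Every dimension other than the concatenation dim must match
            if d != dim and shape[d] != result[d]:
                raise ValueError(f"size mismatch in dimension {d}: {result[d]} vs {shape[d]}")
        result[dim] += shape[dim]  # sizes along the concatenation dim add up
    return torch.Size(result)

t1 = torch.rand((6, 32))
t2 = torch.rand((4, 32))

print(expected_cat_shape([t1.shape, t2.shape], dim=0))  # torch.Size([10, 32])
print(torch.cat((t1, t2), dim=0).shape)                 # torch.Size([10, 32])

try:
    expected_cat_shape([t1.shape, t2.shape], dim=1)
except ValueError as e:
    print("helper:", e)  # helper: size mismatch in dimension 0: 6 vs 4

try:
    torch.cat((t1, t2), dim=1)
except RuntimeError:
    print("torch.cat also raises RuntimeError for dim=1")
```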
torch.stack
torch.stack joins the given tensors along a new dimension. All input tensors must have exactly the same shape.
import torch
M, N, K = 2, 3, 4  # example sizes; M is the batch size
x = torch.rand(M, N, K)  # [M, N, K]
y = torch.rand(M, N, K)  # [M, N, K]
output = torch.stack([x, y], dim=1)  # [M, 2, N, K]
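One way to understand "a new dimension" is that torch.stack behaves like first inserting a size-1 axis into each input with `unsqueeze(dim)` and then calling torch.cat along that axis. The sketch checks this equivalence on random tensors. The sizes are again arbitrary example values.

```python
import torch

M, N, K = 2, 3, 4  # arbitrary example sizes (assumption)
x = torch.rand(M, N, K)
y = torch.rand(M, N, K)

stacked = torch.stack([x, y], dim=1)  # [M, 2, N, K]

# Add a size-1 axis at dim 1 to each input, then concatenate along it
via_cat = torch.cat([x.unsqueeze(1), y.unsqueeze(1)], dim=1)  # [M, 2, N, K]

print(stacked.shape)                  # torch.Size([2, 2, 3, 4])
print(torch.equal(stacked, via_cat))  # True
```

This also explains why stack demands identical shapes. After unsqueezing, every dimension other than the new one must match for cat to succeed.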
t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t2 = torch.tensor([[-1, -2, -3], [-4, -5, -6]])
t1
>>>> tensor([[1, 2, 3],
        [4, 5, 6]])
t2
>>>> tensor([[-1, -2, -3],
        [-4, -5, -6]])
torch.stack([t1, t2], dim=0)  # shape: [2, 2, 3]
>>>> tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[-1, -2, -3],
         [-4, -5, -6]]])
torch.stack([t1, t2], dim=1)  # shape: [2, 2, 3]; result[i] stacks t1[i] and t2[i]
>>>> tensor([[[ 1,  2,  3],
         [-1, -2, -3]],

        [[ 4,  5,  6],
         [-4, -5, -6]]])
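The sketch below puts the two functions side by side on the same t1 and t2 and shows one more stack position. cat along dim 0 appends rows without adding an axis. stack adds a new axis. For 2-D inputs, dim=2 (equivalently dim=-1) is the last valid position for that new axis, and it pairs up corresponding elements.

```python
import torch

t1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
t2 = torch.tensor([[-1, -2, -3], [-4, -5, -6]])

print(torch.cat([t1, t2], dim=0).shape)    # torch.Size([4, 3]): rows appended, no new axis
print(torch.stack([t1, t2], dim=0).shape)  # torch.Size([2, 2, 3]): new leading axis

# New axis in the last position: each entry pairs t1[i, j] with t2[i, j]
last = torch.stack([t1, t2], dim=2)
print(last.shape)  # torch.Size([2, 3, 2])
print(last[0, 0])  # the pair (t1[0, 0], t2[0, 0]) = (1, -1)
```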