torch.cat
torch.cat은 주어진 텐서들을 주어진 차원에 맞춰 합쳐주는 함수
import torch
x = torch.rand(batch_size, N, K) # [M, N, K]
y = torch.rand(batch_size, N, K) # [M, N, K]
output1 = torch.cat([x,y], dim=1) #[M, N+N, K]
output2 = torch.cat([x,y], dim=2) #[M, N, K+K]
import torch
t1 = torch.rand((6, 32))
t2 = torch.rand((4, 32))
torch.cat((t1, t2), dim=0).shape
>>>> torch.Size([10, 32])
torch.cat((t1, t2), dim=1).shape
>>>> torch.Size([4, 64])
torch.stack
주어진 텐서들을 새로운 차원으로 합침
import torch
x = torch.rand(batch_size, N, K) # [M, N, K]
y = torch.rand(batch_size, N, K) # [M, N, K]
output = torch.stack([x,y], dim=1) #[M, 2, N, K]
t1 = torch.tensor([[1,2,3],[4,5,6]])
t2 = torch.tensor([[-1,-2,-3],[-4,-5, -6]])
>>>> tensor([[1, 2, 3],
[4, 5, 6]])
tensor([[-1, -2, -3],
[-4, -5, -6]])
torch.stack([t1,t2], dim=0) # shape: [2, 2, 3]
>>>> tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[-1, -2, -3],
[-4, -5, -6]]])
torch.stack([t1,t2], dim=1) # shape: [2, 2, 3]
>>>> tensor([[[ 1, 2, 3],
[-1, -2, -3]],
[[ 4, 5, 6],
[-4, -5, -6]]])
'Pytorch' 카테고리의 다른 글
Tensor (0) | 2025.04.08 |
---|---|
딥러닝 파이프라인 (0) | 2025.04.07 |
torch.nn과 torch.nn.functional (0) | 2023.09.11 |
ResNet Image Feature Extraction (0) | 2023.09.04 |
파이토치 모델 정의 (0) | 2022.04.30 |