Pytorch

torch.cat과 torch.stack

술임 2023. 11. 27. 14:56

torch.cat

torch.cat은 주어진 텐서들을 주어진 차원에 맞춰 합쳐주는 함수

 

 

import torch

x = torch.rand(batch_size, N, K) # [M, N, K]
y = torch.rand(batch_size, N, K) # [M, N, K]

output1 = torch.cat([x,y], dim=1) #[M, N+N, K]
output2 = torch.cat([x,y], dim=2) #[M, N, K+K]
import torch

t1 = torch.rand((6, 32))
t2 = torch.rand((4, 32))

torch.cat((t1, t2), dim=0).shape
>>>> torch.Size([10, 32])

torch.cat((t1, t2), dim=1).shape
>>>> torch.Size([4, 64])

torch.stack

주어진 텐서들을 새로운 차원으로 합침

import torch

x = torch.rand(batch_size, N, K) # [M, N, K]
y = torch.rand(batch_size, N, K) # [M, N, K]

output = torch.stack([x,y], dim=1) #[M, 2, N, K]
t1 = torch.tensor([[1,2,3],[4,5,6]])
t2 = torch.tensor([[-1,-2,-3],[-4,-5, -6]])

>>>> tensor([[1, 2, 3],
             [4, 5, 6]]) 
     tensor([[-1, -2, -3],
             [-4, -5, -6]])
            
torch.stack([t1,t2], dim=0)  # shape: [2, 2, 3]
>>>> tensor([[[ 1,  2,  3],
              [ 4,  5,  6]],

             [[-1, -2, -3],
              [-4, -5, -6]]])
             
 torch.stack([t1,t2], dim=1)   # shape: [2, 2, 3]
 >>>> tensor([[[ 1,  2,  3],
               [-1, -2, -3]],

              [[ 4,  5,  6],
               [-4, -5, -6]]])

'Pytorch' 카테고리의 다른 글

Tensor  (0) 2025.04.08
딥러닝 파이프라인  (0) 2025.04.07
torch.nn과 torch.nn.functional  (0) 2023.09.11
ResNet Image Feature Extraction  (0) 2023.09.04
파이토치 모델 정의  (0) 2022.04.30