torch.cat torch.cat은 주어진 텐서들을 주어진 차원에 맞춰 합쳐주는 함수 import torch x = torch.rand(batch_size, N, K) # [M, N, K] y = torch.rand(batch_size, N, K) # [M, N, K] output1 = torch.cat([x,y], dim=1) #[M, N+N, K] output2 = torch.cat([x,y], dim=2) #[M, N, K+K] import torch t1 = torch.rand((6, 32)) t2 = torch.rand((4, 32)) torch.cat((t1, t2), dim=0).shape >>>> torch.Size([10, 32]) torch.cat((t1, t2), dim=1).sha..