[Pytorch] squeeze, unsqueeze 란?
Pytorch를 이용해서 DL 프레임워크를 공부하다보면 나오는 squeeze, unsqueeze 함수가 있습니다. 과연 이것이 무엇인지 확인해 보겠습니다.
수학적인 내용도 필요하고 이러한 함수가 왜 필요한지 알아봅시다.
📌 squeeze와 unsqueeze 함수의 역할
1. squeeze 함수
'squeeze()' 함수는 tensor의 차원 중에서 크기가 1인 차원을 제거합니다. 다시 말하면? 불필요한 1차원의 축을 없애주는 것입니다.
tensor_a = torch.zeros(1, 3, 1, 5)
print(tensor_a.shape)
>> (1, 3, 1, 5)
tensor_b = tensor._a.squeeze()
print(tensor_b.shape)
>> (3, 5)
이렇게 tensor_a에 1차원이 사라져 결과값이 변경됩니다.
2. unsqueeze 함수
'unsqueeze()' 함수는 원하는 위치에 새로운 크기가 1인 차원을 추가합니다.
tensor_c = tensor.zeros(3, 5)
print(tensor_c.shape)
>> (3,5)
tensor_d = tensor_c.unsqueeze(0)
print(tensor_d.shape)
>> (1, 3, 5)
이렇게 0번째 축에 1차원이 추가되어 (1, 3, 5)가 됩니다.
중요한 점은 그럼 이 함수가 왜 필요하냐겠죠?
📌 squeeze와 unsqueeze 함수가 필요한 이유
1. Broadcasting 연산 지원
Pytorch를 비롯한 딥러닝 프레임워크는 연산 시 Broadcating이라는 기법을 사용하여 자동으로 차원은 맞춰줍니다. 이 때 연산이 정상적으로 수행되도록 의도적으로 차원을 추가하거나 제거하는 작업이 필요합니다.
Broadcasting은 서로 다른 크기의 tensor를 연산할 때 크기를 자동으로 확장하여 연산이 가능하게 만드는 기법입니다. 당연하게도 tensor 간의 차원이 서로 맞아야하기 때문에 이를 위해 의도적으로 차원을 추가하거나 제거하는 과정이 필요합니다.
예를 들어, 크기가 '(3, 4)'인 tensor와 '(4, )'인 tensor를 더할 때 '(4, )' 텐서를 (1, 4)로 차원을 추가하면 Broadcasting을 통해 두 tensor를 연산할 수 있습니다.
예를 통해 확인해봅시다.
a = torch.ones(3, 4)
b = torch.rand(4, 1)
print(a, b)
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
tensor([[0.9412],
[0.3585],
[0.8413],
[0.7319]])
이렇게 a는 '(3, 4)'인 1로 이루어진 tensor를, b는 랜덤으로 '(4, 1)'인 tensor를 만들었습니다. 이 들을 그냥 'a + b'와 같이 연산한다면 아래와 같은 에러가 반환됩니다.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[12], line 1
----> 1 a + b
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 0
차원이 다르기 때문에 발생하는 오류입니다. 해결하기 위해선 '(4, 1)'를 연산이 가능한 단일 차원인 '(4)'로 만드는 것도 방법입니다. 이때 필요없는 1차원을 제거하기 위해 squeeze()를 사용하면 아래와 같이 연산이 가능합니다
a + b.squeeze()
tensor([[1.9412, 1.3585, 1.8413, 1.7319],
[1.9412, 1.3585, 1.8413, 1.7319],
[1.9412, 1.3585, 1.8413, 1.7319]])
2. 모델 입력 데이터 차원 맞추기
Pytorch에서 대부분의 딥러닝 모델은 입력 데이터로 batch(배치) 자원을 요구합니다. 배치 자원은 여러 데이터를 묶어서 한 번에 처리하기 위한 차원이고 즉, 입력 데이터가 단 하나(single sample)일지라도 모델의 구조상 일관된 입력 형태를 유지하기 위해 배치 차원을 명시적으로 추가해야 합니다.
예를 들어, 이미지 분류 모델의 경우 입력이 보통 '(배치 크기, 채널, 높이, 너비)'의 형태로 정의 됩니다. 이 때, 단일 이미지를 모델에 넣으려면 '(채널, 높이, 너비)'형태를 '(1, 채널, 높이, 너비)' 형태로 변환해야 합니다. 아래와 같이 말이죠
single_image = torch.rand(3, 224, 224) # 단일 이미지 데이터
input_batch = single_image.unsqueeze(0) # 배치 차원을 추가하여 (1, 3, 224, 224)
이러한 차원의 변형이 없으면 모델에서 차원 불일치 오류가 발생하며 정삭적으로 동작하지 않으므로, 명확한 차원 관리가 필요합니다.(이는 배치 뿐 아니라 1차원이 필요한 경우 적용이 될 것입니다.)
3. 모델 결과 차원 간소화
모델의 출력 결과가 필요 이상으로 많은 차원을 가지고 있는 경우, 결과를 직관적으로 해석하고 시각화하기 위해 squeeze 함수를 활용하여 불필요한 차원을 제거할 수 있습니다.