ML

Batch Normalization

술임 2023. 1. 18. 21:32

배치 정규화 batch normalization

- 샘플의 분포가 바뀌는 현상이 공변량 시프트 covariate shift

- 공변량 시프트 현상은 층이 깊어질수록 심각해짐 --> 학습에 부정적인 영향

- 배치 정규화는 모든 층에 정규화를 독립적으로 적용하는 기법

--> 학습 과정에서 각 배치별로 데이터가 다양한 분포를 가져도 평균과 분산을 이용해서 정규화 시행

- 훈련 집합 전체에 적용하는 것보다 미니배치 단위로 적용하는 것이 나아서 배치 정규화 사용하게됨

- 배치 정규화를 적용하지 않으면 활성함수를 적용한걸 다음 층으로 넘기면 됨

--> 배치 정규화에서는 미니 배치 단위로 중간 결과를 모은 후 평균과 분산을 사용해서 정규화 변환값을 구함

--> 정규화 변환값을 선형 변환(노드마다 고유한 매개변수  로 학습을 통해 알아냄)을 통해 얻은 값을 활성함수에 입력함

- 컨볼루션 신경망에서는 노드 단위가 아니라 특징 맵 단위로 코드 1, 2를 적용함

 

배치 정규화의 효과

1) 매개변수의 초깃값에 덜 민감함

--> 초기화 전략에 크게 구애받지 않음

2) 학습률을 크게 설정해서 수렴 속도 향상 가능

--> 배치 정규화 적용 시 학습 시간이 빨라짐

--> feature가 동일한 스케일이 되어 학습률 결정에 유리해짐

e.g. feature의 스케일이 다르면 gradient가 다르게 되어 같은 학습률에 대해 weight마다 반응 정도가 달라짐

3) 깊은 신경망의 학습이 가능함

4) 규제 효과 제공

- ResNet은 배치 정규화를 사용하여 드롭아웃 규제 기법 없이 높은 성능을 얻어냄

 

배치 정규화의 한계

- batch의 크기에 영향을 받음

--> batch가 너무 작거나 크면 잘 동작하지 않음

- 개선을 위해 weight normalization이나 layer normalization이 사용되기도 함

 

pytorch에서의 적용

# With Learnable Parameters
m = nn.BatchNorm1d(100)
input = torch.randn(20, 100)
output = m(input)

# With Learnable Parameters
m = nn.BatchNorm2d(100)
input = torch.randn(20, 100, 35, 45)
output = m(input)

- torch.nn.BatchNorm1d, torch.nn.BatchNorm2d 사용

--> 을 초기값으로 학습을 시작

- channel 수를 맞추면 batch normalization 연산이 가능해짐

 

참고 

https://gaussian37.github.io/dl-concept-batchnorm/