GAN 생성적 적대 신경망
2016년 12월. 생성적 적대 신경망 논문 발표 https://arxiv.org/abs/1701.00160v4
NIPS 2016 Tutorial: Generative Adversarial Networks
This report summarizes the tutorial presented by the author at NIPS 2016 on generative adversarial networks (GANs). The tutorial describes: (1) Why generative modeling is a topic worth studying, (2) how generative models work, and how GANs compare to other
arxiv.org
GAN
- GAN은 생성자(generator)와 판별자(discriminator) 네트워크 2개가 경쟁
- 생성자는 랜덤한 잡음을 원본 데이터셋에서 샘플링한 것처럼 보이는 샘플로 변환함
- 판별자는 원본 데이터셋에서 추출한 샘플인지 생성자가 만든 가짜인지를 구별함
- GAN은 두 네트워크를 어떻게 교대로 훈련하는지가 핵심
gan = GAN(input_dim = (28,28,1)
, discriminator_conv_filters = [64,64,128,128]
, discriminator_conv_kernel_size = [5,5,5,5]
, discriminator_conv_strides = [2,2,2,1]
, discriminator_batch_norm_momentum = None
, discriminator_activation = 'relu'
, discriminator_dropout_rate = 0.4
, discriminator_learning_rate = 0.0008
, generator_initial_dense_layer_size = (7, 7, 64)
, generator_upsample = [2,2, 1, 1]
, generator_conv_filters = [128,64, 64,1]
, generator_conv_kernel_size = [5,5,5,5]
, generator_conv_strides = [1,1, 1, 1]
, generator_batch_norm_momentum = 0.9
, generator_activation = 'relu'
, generator_dropout_rate = None
, generator_learning_rate = 0.0004
, optimiser = 'rmsprop'
, z_dim = 100
)
if mode == 'build':
gan.save(RUN_FOLDER)
else:
gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))
판별자
- 지도 학습 분야의 이미지 분류 문제
- 합성곱 층을 쌓고 완전 연결 층을 출력층으로 놓은 네트워크 구조 사용 가능
- GAN 원본 논문에서는 합성곱 대신 완전 연결 층을 사용했으나 합성곱 층이 판별자의 예측 성능을 크게 높여준다고 밝혀짐
DCGAN(Digital Convolutional Generative Adbersarial Network)
def _build_discriminator(self):
### THE discriminator
discriminator_input = Input(shape=self.input_dim, name='discriminator_input')
# 판별자의 입력을 정의함, 이미지가 인풋
x = discriminator_input
for i in range(self.n_layers_discriminator):
# 합성곱 층을 차례대로 쌓음
x = Conv2D(
filters = self.discriminator_conv_filters[i]
, kernel_size = self.discriminator_conv_kernel_size[i]
, strides = self.discriminator_conv_strides[i]
, padding = 'same'
, name = 'discriminator_conv_' + str(i)
, kernel_initializer = self.weight_init
)(x)
if self.discriminator_batch_norm_momentum and i > 0:
x = BatchNormalization(momentum = self.discriminator_batch_norm_momentum)(x)
x = self.get_activation(self.discriminator_activation)(x)
if self.discriminator_dropout_rate:
x = Dropout(rate = self.discriminator_dropout_rate)(x)
x = Flatten()(x)
# 마지막 합성곱 층을 펼쳐서 벡터로 바꿈
discriminator_output = Dense(1, activation='sigmoid', kernel_initializer = self.weight_init)(x)
# 하나의 유닛을 가진 Dense층, 시그모이드 활성화 함수는 완전 연결 층의 출력을 [0,1] 사이 범위로 변환
self.discriminator = Model(discriminator_input, discriminator_output)
# 케라스 모델로 판별자를 정의함, 이미지를 입력 받아서 0과 1 사이의 숫자 하나를 출력함
생성자
- 생성자의 입력은 일반적으로 다변수 표준 정규 분포에서 추출한 벡터
- 출력은 원본 훈련 데이터의 이미지와 동일한 크기의 이미지
- GAN의 생성자는 VAE의 디코더와 정확히 동일한 목적을 수행함, 잠재 공간의 벡터를 이미지로 변환함
- 잠재 공간의 벡터를 조작해서 원본 차원에 있는 이미지의 고수준 특성을 바꿀 수 있음
업샘플링 층
- GAN에서는 업생플링 층을 사용해서 텐서의 너비와 높이를 2배로 늘림. 입력의 각 행과 열을 반복해서 크기를 2배로 만듦
- 스트라이드 1인 보통 합성곱 층을 사용해서 합성곱 연산을 수행함
- 픽셀 사이의 공간을 0이 아니라 기존 픽셀값을 사용해서 업샘플링함
- 원본 이미지 차원으로 되돌리는데 사용할 수 있는 변환 방법
- Conv2DTranspose, Upsampling2D 중 더 잘 맞는 방식 사용
def _build_generator(self):
### THE generator
generator_input = Input(shape=(self.z_dim,), name='generator_input')
x = generator_input
x = Dense(np.prod(self.generator_initial_dense_layer_size), kernel_initializer = self.weight_init)(x)
if self.generator_batch_norm_momentum:
x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)
x = self.get_activation(self.generator_activation)(x)
x = Reshape(self.generator_initial_dense_layer_size)(x)
# 배치 정규화와 렐루 활성화 함수 적용 후 텐서를 7x7x64 텐서로 바꿈
if self.generator_dropout_rate:
x = Dropout(rate = self.generator_dropout_rate)(x)
for i in range(self.n_layers_generator):
# 4개의 Conv2D cmddmf xhdrhkgka
if self.generator_upsample[i] == 2:
x = UpSampling2D()(x)
x = Conv2D(
filters = self.generator_conv_filters[i]
, kernel_size = self.generator_conv_kernel_size[i]
, padding = 'same'
, name = 'generator_conv_' + str(i)
, kernel_initializer = self.weight_init
)(x)
else:
x = Conv2DTranspose(
filters = self.generator_conv_filters[i]
, kernel_size = self.generator_conv_kernel_size[i]
, padding = 'same'
, strides = self.generator_conv_strides[i]
, name = 'generator_conv_' + str(i)
, kernel_initializer = self.weight_init
)(x)
if i < self.n_layers_generator - 1:
if self.generator_batch_norm_momentum:
x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)
x = self.get_activation(self.generator_activation)(x)
else:
x = Activation('tanh')(x)
# tanh 활성화 함수를 사용해서 출력을 원본 이미지와 같은 [-1, 1] 범위로 변환
generator_output = x
self.generator = Model(generator_input, generator_output)
GAN 훈련
- 훈련 세트에서 진짜 샘플을 랜덤하게 선택하고 생성자의 출력을 합쳐서 훈련 세트를 만들어 판별자를 훈련함
--> 진짜 이미지의 타깃은 1이고, 생성된 이미지의 타깃은 0
--> 지도 학습 문제로 생각하면 원본 이미지와 생성된 이미지 사이의 차이점을 구분할 수 있도록 판별자를 훈련시킬 수 있을 것
- 진짜 이미지가 잠재 공간의 어떤 포인트에 매핑되는지 알려주는 훈련세트가 없어서 생성자 훈련은 판별자를 속이는 이미지를 생성함
--> 이미지가 판별자의 입력으로 주입될때 생성된 이미지가 1에 가까운 값이 출력되어야함
- 생성자를 훈련하기 위해 판별자를 연결한 케라스 모델을 만들어서 생성자의 출력 이미지를 판별자에 주입하면 판별자는 생성자의 이미지가 진짜일 확률을 출력함
- 입력은 랜덤하게 생성한 잠재 공간 벡터이고, 출력은 1인 훈련 배치를 만들어 전체 모델을 훈련함
--> 출력을 1로 지정하는 이유는 판별자가 진짜라고 생각할 수 있는 이미지를 생성자가 만드는 것이 목표이기 때문임
- 손실함수는 판별자의 출력과 타깃 1 사이의 이진 크로스엔트로피 손실임
- 전체 모델을 훈련할 때 생성자의 가중치만 업데이트되도록 판별자의 가중치를 동결하는 것이 중요함
- 판별자의 가중치를 동결하지 않으면 생성된 이미지를 진짜라고 여기도록 조정되기 때문임. 판별자가 약하기 때문이 아니라 생성자가 강하기 때문에 생성된 이미지가 진짜 이미지 1에 가까운 값으로 예측되어야함
GAN 컴파일
def _build_adversarial(self):
### 판별자 컴파일
self.discriminator.compile(
optimizer=self.get_opti(self.discriminator_learning_rate)
, loss = 'binary_crossentropy'
'''
타깃이 이진 값이고 시그모이드 활성화 함수를 가진 하나의 출력 유닛을 사용하기 때문에
판별자를 이진 크로스엔트로피로 컴파일함
'''
, metrics = ['accuracy']
)
### 생성자를 훈련하기 위해 모델 컴파일
self.set_trainable(self.discriminator, False)
# 판별자의 가중치를 동결함. 컴파일한 판별자 모델이 영향 받지 않음
model_input = Input(shape=(self.z_dim,), name='model_input')
model_output = self.discriminator(self.generator(model_input))
self.model = Model(model_input, model_output)
'''
100차원 잠재 공간 벡터를 입력으로 받는 새로운 모델을 정의함
이 벡터가 생성자와 동결한 판별자를 통과하여 확률이 출력됨
'''
self.model.compile(optimizer=self.get_opti(self.generator_learning_rate) , loss='binary_crossentropy', metrics=['accuracy'])
'''
이진 크로스 엔트로피 손실을 사용해서 전체 모델을 컴파일함
일반적으로 판별자가 생성자가 강해야하므로
학습률이 판별자보다 느림
학습률은 주의 깊게 튜닝해야하는 파라미터임
'''
self.set_trainable(self.discriminator, True)
GAN 훈련
def train_discriminator(self, x_train, batch_size, using_generator):
valid = np.ones((batch_size,1))
fake = np.zeros((batch_size,1))
# 진짜 이미지로 훈련
if using_generator:
true_imgs = next(x_train)[0]
if true_imgs.shape[0] != batch_size:
true_imgs = next(x_train)[0]
else:
idx = np.random.randint(0, x_train.shape[0], batch_size)
true_imgs = x_train[idx]
# 생성된 이미지로 훈련
noise = np.random.normal(0, 1, (batch_size, self.z_dim))
gen_imgs = self.generator.predict(noise)
d_loss_real, d_acc_real = self.discriminator.train_on_batch(true_imgs, valid)
# 판별자의 훈련은 타깃 1과 진짜 이미지로 배치를 만들어 수행
d_loss_fake, d_acc_fake = self.discriminator.train_on_batch(gen_imgs, fake)
# 타깃 0과 생성된 이미지로 배치 훈련
d_loss = 0.5 * (d_loss_real + d_loss_fake)
d_acc = 0.5 * (d_acc_real + d_acc_fake)
return [d_loss, d_loss_real, d_loss_fake, d_acc, d_acc_real, d_acc_fake]
def train_generator(self, batch_size):
valid = np.ones((batch_size,1))
noise = np.random.normal(0, 1, (batch_size, self.z_dim))
return self.model.train_on_batch(noise, valid)
# 생성자의 훈련은 타깃 1과 생성된 이미지로 배치를 만들어 수행
# 판별자는 동결되었기 때문에 가중치가 변하지 않음
# 조금 더 판별자를 속일 수 있는 이미지를 생성할 수 있도록 생성자의 가중치가 이동함
def train(self, x_train, batch_size, epochs, run_folder
, print_every_n_batches = 50
, using_generator = False):
for epoch in range(self.epoch, self.epoch + epochs):
d = self.train_discriminator(x_train, batch_size, using_generator)
g = self.train_generator(batch_size)
print ("%d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f] [G acc: %.3f]" % (epoch, d[0], d[1], d[2], d[3], d[4], d[5], g[0], g[1]))
self.d_losses.append(d)
self.g_losses.append(g)
if epoch % print_every_n_batches == 0:
self.sample_images(run_folder)
self.model.save_weights(os.path.join(run_folder, 'weights/weights-%d.h5' % (epoch)))
self.model.save_weights(os.path.join(run_folder, 'weights/weights.h5'))
self.save_model(run_folder)
self.epoch += 1
- 적절한 횟수의 에폭 반복 후 판별자와 생성자가 평형을 이루어 생성자가 판별자로부터 유용한 정보를 학습하고 이미지 품질이 향상됨
- 훈련 세트에서 이미지를 단순히 재생성해서는 안되고 생성된 특정 샘플에 가장 비슷한 훈련 세트의 이미지를 찾아야함
-> 두 이미지 사이의 거리를 재는 L1 norm
def l1_compare_images(img1, img2):
return np.mean(np.abs(img1-img2))
GAN 문제점
- GAN은 훈련이 어렵기로 유명
진동 손실
- 판별자와 생성자의 손실이 장기간 안정된 모습을 보여주지 못하고 큰 폭으로 진동하기 시작함
모드 붕괴
- 생성자가 판별자를 속이는 적은 수의 샘플을 찾을 때 발생
- 한정된 이 샘플 이외에는 다른 샘플을 생성하지 못함
- 손실 함수의 그래디언트가 0에 가까운 값으로 무너짐
유용하지 않은 손실
- 생성자는 현재 판별자에 의해서만 평가되고 판별자는 계속 향상되므로 훈련 과정의 다른 지점에서 평가된 손실 비교는 어려움
- 이미지 품질이 향상됨에도 불구하고 생성자의 손실 함수는 증가할 수 있음
하이퍼파라미터
- 간단한 GAN이라도 튜닝해야할 하이퍼파라미터의 개수가 상당히 많음
- 간은 파라미터의 작은 변화에도 매우 민감함
- 계획적인 시행착오를 거쳐 잘 맞는 파라미터 조합을 찾는 경우가 많음
WGAN, WGAN-GP는 GAN 훈련하는 데 최선의 방법으로 간주됨