WGAN
2017. 01 Wasserstein GAN
https://arxiv.org/abs/1701.07875
Wasserstein GAN
We introduce a new algorithm named WGAN, an alternative to traditional GAN training. In this new model, we show that we can improve the stability of learning, get rid of problems like mode collapse, and provide meaningful learning curves useful for debuggi
arxiv.org
참고한 블로그 : https://jonathan-hui.medium.com/gan-wasserstein-gan-wgan-gp-6a1a2aa1b490
GAN — Wasserstein GAN & WGAN-GP
Training GAN is hard. Models may never converge and mode collapses are common. To move forward, we can make incremental improvements or…
jonathan-hui.medium.com
1) WGAN에서는 새로운 손실 함수를 소개함 -> 이진 크로스엔트로피 대신 더 안정적으로 수렴 가능
2) 최적화 과정의 안정성 향상
와서스테인 손실
- 와서스테인 손실은 1과 0 대신에 1과 -1을 사용함.
- 판별자(비평자)의 마지막 층에서 시그모이드 활성화 함수를 제거해서 예측이 [0,1]에 국한되지 않고 어떤 숫자도 가능하도록 만듦
def wasserstein(y_true, y_pred):
return -K.mean(_true * y_pred)
critic.compile(
optimizer=RMSporp(lr=0.00005)
, loss= wasserstein)
model.compile(
optimizer=RMSporp(lr=0.00005)
, loss= wasserstein)
립시츠 제약
- 와서스테인 손실은 제한이 없어 너무 큰 값일 수도 있음. 그러나 신경망에서 큰 숫자는 일반적으로 피해야함
- 비평자는 1-립시츠 연속 함수 여야 함
가중치 클리핑
- 비평자의 가중치를 작은 범위 [-0.01, 0.01] 안에 놓이도록 훈련 배치 끝난 후 가중치 클리핑을 통해 립시츠 제약을 부과함
- 매 업데이트 후에 비평자의 가중치를 클리핑함
def train_critic(self, x_train, batch_size, clip_threshold, using_generator):
valid = np.ones((batch_size,1))
fake = -np.ones((batch_size,1))
if using_generator:
true_imgs = next(x_train)[0]
if true_imgs.shape[0] != batch_size:
true_imgs = next(x_train)[0]
else:
idx = np.random.randint(0, x_train.shape[0], batch_size)
true_imgs = x_train[idx]
noise = np.random.normal(0, 1, (batch_size, self.z_dim))
gen_imgs = self.generator.predict(noise)
d_loss_real = self.critic.train_on_batch(true_imgs, valid)
d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * (d_loss_real + d_loss_fake)
for l in self.critic.layers:
weights = l.get_weights()
weights = [np.clip(w, -clip_threshold, clip_threshold) for w in weights]
l.set_weights(weights)
WGAN 훈련
- 와서스테인 손실 함수 사용 시 생성자가 정확히 업데이트 되도록 비평자 훈련해서 수렴시켜야함
cf. 기존 GAN은 그래디언트 소실을 피하기 위해 판별자가 너무 강해지지 않도록 해야함
- 와서스테인 손실을 사용하면 판별자와 생성자의 훈련 균형을 맞출 필요가 없음.
- 생성자를 업데이트 하는 중간에 비평자를 여러번 훈련해서 수렴에 가깝게 만들 수 있음
- 일반적으로 생성자 한 번 업데이트할 때 비평자를 다섯 번 업데이트함
def train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches = 10
, n_critic = 5
, clip_threshold = 0.01
, using_generator = False):
for epoch in range(self.epoch, self.epoch + epochs):
for _ in range(n_critic):
d_loss = self.train_critic(x_train, batch_size, clip_threshold, using_generator)
g_loss = self.train_generator(batch_size)
# Plot the progress
print ("%d [D loss: (%.3f)(R %.3f, F %.3f)] [G loss: %.3f] " % (epoch, d_loss[0], d_loss[1], d_loss[2], g_loss))
self.d_losses.append(d_loss)
self.g_losses.append(g_loss)
# If at save interval => save generated image samples
if epoch % print_every_n_batches == 0:
self.sample_images(run_folder)
self.model.save_weights(os.path.join(run_folder, 'weights/weights-%d.h5' % (epoch)))
self.model.save_weights(os.path.join(run_folder, 'weights/weights.h5'))
self.save_model(run_folder)
self.epoch+=1
-> 비평자에서 가중치 클리핑해서 학습 속도가 크게 감소함
WGAN-GP
https://arxiv.org/abs/1704.00028
Improved Training of Wasserstein GANs
Generative Adversarial Networks (GANs) are powerful generative models, but suffer from training instability. The recently proposed Wasserstein GAN (WGAN) makes progress toward stable training of GANs, but sometimes can still generate only low-quality sampl
arxiv.org
WGAN에서 비평자의 정의와 컴파일 단계를 수정함
1) 비평자 손실 함수에 그래디언트 패널티 항을 포함함
2) 비평자 가중치를 클리핑하지 않음
3) 비평자에 배치 정규화 층을 사용하지 않음
그래디언트 패널티 손실
- 와서스테인 손실과 더불어서 그래디언트 패널티 손실이 추가되어 전체 손실 함수를 구성함
- 그래디언트 패널티 손실 : 입력 이미지에 대한 예측의 그래디언트 노름과 1 사이의 차이를 제곱한 것
- 모델이 자연스럽게 그래디언트 패널티 항을 최소화하는 가중치를 찾으려고 해서 모델이 립시츠 제약을 따르도록 만듦
- 일부 지점에서만 그래디언트를 계산하도록 함
def gradient_penalty_loss(self, y_true, y_pred, interpolated_samples):
"""
Computes gradient penalty based on prediction and weighted real / fake samples
"""
gradients = K.gradients(y_pred, interpolated_samples)[0]
# 보간된 이미지 입력에 대한 예측의 그래디언트를 계산함
# compute the euclidean norm by squaring ...
gradients_sqr = K.square(gradients)
# ... summing over the rows ...
gradients_sqr_sum = K.sum(gradients_sqr,
axis=np.arange(1, len(gradients_sqr.shape)))
# 벡터의 l2 노름(유클리드 거리)를 계산
gradient_l2_norm = K.sqrt(gradients_sqr_sum)
# 이 함수는 이 L2 노름과 1 사이 거리의 제곱을 반환
gradient_penalty = K.square(1 - gradient_l2_norm)
# return the mean as loss over all the batch samples
return K.mean(gradient_penalty)
def _build_adversarial(self):
#-------------------------------
# Construct Computational Graph
# for the Critic
#-------------------------------
# 생성자의 가중치 동결
self.set_trainable(self.generator, False)
# 모델의 입력은 2개. 하나는 진짜 이미지 배치, 하나는 가짜 이미지 배치를 생성하는데 사용할 랜덤하게 생성된 숫자 배열
real_img = Input(shape=self.input_dim)
# Fake image
z_disc = Input(shape=(self.z_dim,))
fake_img = self.generator(z_disc)
# 와서스테인 손실을 계산하기 위해 진짜 이미지와 가짜 이미지를 비평자에 통과시킴
fake = self.critic(fake_img)
valid = self.critic(real_img)
# 본간된 이미지를 만들고 다시 비평자에 통과시킴
interpolated_img = RandomWeightedAverage(self.batch_size)([real_img, fake_img])
# Determine validity of weighted sample
validity_interpolated = self.critic(interpolated_img)
# 케라스의 손실 함수는 예측과 진짜 레이블 두 개 입력만 기대함.
# 파이썬의 partial 함수를 사용해서 보간된 이미지를 gradient_penalty_loss 함수에 적용한
# 사용자 정의 함수 partial_gp_loss를 정의함
partial_gp_loss = partial(self.gradient_penalty_loss,
interpolated_samples=interpolated_img)
partial_gp_loss.__name__ = 'gradient_penalty' # 케라스는 함수 이름이 필요함
# 비평자를 훈련하기 위한 모델에 2개의 입력이 정의됨.
# 하나는 진짜 이미지 배치, 하나는 가짜 이미지를 생성하는 랜덤한 입력
# 출력이 진짜 이미지는 1, 가짜 이미지는 -1, 더미 0 벡터 3개로 나옴.
# 0 벡터는 케라스의 모든 손실 함수가 반드시 출력에 매핑되어야해서 필요하지만 실제로 사용되지는 않음
# partial_gp_loss 함수에 매핑되는 더미 0 벡터를 만듦
self.critic_model = Model(inputs=[real_img, z_disc],
outputs=[valid, fake, validity_interpolated])
'''
진짜 이미지와 가짜 이미지에 대한
2개의 와서스테인 손실과 그래디언트 패널티 손실 3개의 손실 함수로 비평자를 컴파일함
그래디언트 손실에 10배 가중치 부여함
WGAN-GP 모델에 가장 잘 맞는다고 알려진 Adam 옵티마이저 사용
'''
self.critic_model.compile(
loss=[self.wasserstein,self.wasserstein, partial_gp_loss]
,optimizer=self.get_opti(self.critic_learning_rate)
,loss_weights=[1, 1, self.grad_weight]
)
#-------------------------------
# Construct Computational Graph
# for Generator
#-------------------------------
# For the generator we freeze the critic's layers
self.set_trainable(self.critic, False)
self.set_trainable(self.generator, True)
# Sampled noise for input to generator
model_input = Input(shape=(self.z_dim,))
# Generate images based of noise
img = self.generator(model_input)
# Discriminator determines validity
model_output = self.critic(img)
# Defines generator model
self.model = Model(model_input, model_output)
self.model.compile(optimizer=self.get_opti(self.generator_learning_rate)
, loss=self.wasserstein
)
self.set_trainable(self.critic, True)
'GAN' 카테고리의 다른 글
Neural Style Transfer (0) | 2022.05.29 |
---|---|
CycleGAN (0) | 2022.05.26 |
GAN 생성적 적대 신경망 (1) | 2022.05.20 |
딥러닝 (0) | 2022.05.17 |
생성 모델링 (0) | 2022.05.14 |