GAN

GAN 생성적 적대 신경망

술임 2022. 5. 20. 00:24

2016년 12월. 생성적 적대 신경망 논문 발표 https://arxiv.org/abs/1701.00160v4

 

NIPS 2016 Tutorial: Generative Adversarial Networks

This report summarizes the tutorial presented by the author at NIPS 2016 on generative adversarial networks (GANs). The tutorial describes: (1) Why generative modeling is a topic worth studying, (2) how generative models work, and how GANs compare to other

arxiv.org

 

GAN

- GAN은 생성자(generator)와 판별자(discriminator) 네트워크 2개가 경쟁

- 생성자는 랜덤한 잡음을 원본 데이터셋에서 샘플링한 것처럼 보이는 샘플로 변환함

- 판별자는 원본 데이터셋에서 추출한 샘플인지 생성자가 만든 가짜인지를 구별함

- GAN은 두 네트워크를 어떻게 교대로 훈련하는지가 핵심

 

이미지 출처 : https://sites.google.com/site/aidysft/generativeadversialnetwork

 

gan = GAN(input_dim = (28,28,1)
        , discriminator_conv_filters = [64,64,128,128]
        , discriminator_conv_kernel_size = [5,5,5,5]
        , discriminator_conv_strides = [2,2,2,1]
        , discriminator_batch_norm_momentum = None
        , discriminator_activation = 'relu'
        , discriminator_dropout_rate = 0.4
        , discriminator_learning_rate = 0.0008
        , generator_initial_dense_layer_size = (7, 7, 64)
        , generator_upsample = [2,2, 1, 1]
        , generator_conv_filters = [128,64, 64,1]
        , generator_conv_kernel_size = [5,5,5,5]
        , generator_conv_strides = [1,1, 1, 1]
        , generator_batch_norm_momentum = 0.9
        , generator_activation = 'relu'
        , generator_dropout_rate = None
        , generator_learning_rate = 0.0004
        , optimiser = 'rmsprop'
        , z_dim = 100
        )

if mode == 'build':
    gan.save(RUN_FOLDER)
else:
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))

 

판별자

- 지도 학습 분야의 이미지 분류 문제

- 합성곱 층을 쌓고 완전 연결 층을 출력층으로 놓은 네트워크 구조 사용 가능

- GAN 원본 논문에서는 합성곱 대신 완전 연결 층을 사용했으나 합성곱 층이 판별자의 예측 성능을 크게 높여준다고 밝혀짐

DCGAN(Digital Convolutional Generative Adbersarial Network)

def _build_discriminator(self):

        ### THE discriminator
        discriminator_input = Input(shape=self.input_dim, name='discriminator_input')
		# 판별자의 입력을 정의함, 이미지가 인풋

        x = discriminator_input

        for i in range(self.n_layers_discriminator):
		# 합성곱 층을 차례대로 쌓음
        
            x = Conv2D(
                filters = self.discriminator_conv_filters[i]
                , kernel_size = self.discriminator_conv_kernel_size[i]
                , strides = self.discriminator_conv_strides[i]
                , padding = 'same'
                , name = 'discriminator_conv_' + str(i)
                , kernel_initializer = self.weight_init
                )(x)

            if self.discriminator_batch_norm_momentum and i > 0:
                x = BatchNormalization(momentum = self.discriminator_batch_norm_momentum)(x)

            x = self.get_activation(self.discriminator_activation)(x)

            if self.discriminator_dropout_rate:
                x = Dropout(rate = self.discriminator_dropout_rate)(x)

        x = Flatten()(x)
        # 마지막 합성곱 층을 펼쳐서 벡터로 바꿈
        
        discriminator_output = Dense(1, activation='sigmoid', kernel_initializer = self.weight_init)(x)
		# 하나의 유닛을 가진 Dense층, 시그모이드 활성화 함수는 완전 연결 층의 출력을 [0,1] 사이 범위로 변환
        
        self.discriminator = Model(discriminator_input, discriminator_output)
        # 케라스 모델로 판별자를 정의함, 이미지를 입력 받아서 0과 1 사이의 숫자 하나를 출력함

 

생성자

- 생성자의 입력은 일반적으로 다변수 표준 정규 분포에서 추출한 벡터

- 출력은 원본 훈련 데이터의 이미지와 동일한 크기의 이미지

- GAN의 생성자는 VAE의 디코더와 정확히 동일한 목적을 수행함, 잠재 공간의 벡터를 이미지로 변환함

- 잠재 공간의 벡터를 조작해서 원본 차원에 있는 이미지의 고수준 특성을 바꿀 수 있음

 

업샘플링 층

- GAN에서는 업생플링 층을 사용해서 텐서의 너비와 높이를 2배로 늘림. 입력의 각 행과 열을 반복해서 크기를 2배로 만듦

- 스트라이드 1인 보통 합성곱 층을 사용해서 합성곱 연산을 수행함

- 픽셀 사이의 공간을 0이 아니라 기존 픽셀값을 사용해서 업샘플링함

- 원본 이미지 차원으로 되돌리는데 사용할 수 있는 변환 방법

- Conv2DTranspose, Upsampling2D 중 더 잘 맞는 방식 사용

 

def _build_generator(self):

        ### THE generator

        generator_input = Input(shape=(self.z_dim,), name='generator_input')

        x = generator_input

        x = Dense(np.prod(self.generator_initial_dense_layer_size), kernel_initializer = self.weight_init)(x)

        if self.generator_batch_norm_momentum:
            x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)

        x = self.get_activation(self.generator_activation)(x)

        x = Reshape(self.generator_initial_dense_layer_size)(x)
		# 배치 정규화와 렐루 활성화 함수 적용 후 텐서를 7x7x64 텐서로 바꿈

        if self.generator_dropout_rate:
            x = Dropout(rate = self.generator_dropout_rate)(x)

        for i in range(self.n_layers_generator):
		# 4개의 Conv2D cmddmf xhdrhkgka
        
            if self.generator_upsample[i] == 2:
                x = UpSampling2D()(x)
                x = Conv2D(
                    filters = self.generator_conv_filters[i]
                    , kernel_size = self.generator_conv_kernel_size[i]
                    , padding = 'same'
                    , name = 'generator_conv_' + str(i)
                    , kernel_initializer = self.weight_init
                )(x)
            else:

                x = Conv2DTranspose(
                    filters = self.generator_conv_filters[i]
                    , kernel_size = self.generator_conv_kernel_size[i]
                    , padding = 'same'
                    , strides = self.generator_conv_strides[i]
                    , name = 'generator_conv_' + str(i)
                    , kernel_initializer = self.weight_init
                    )(x)

            if i < self.n_layers_generator - 1:

                if self.generator_batch_norm_momentum:
                    x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)

                x = self.get_activation(self.generator_activation)(x)
                    
                
            else:

                x = Activation('tanh')(x)
                # tanh 활성화 함수를 사용해서 출력을 원본 이미지와 같은 [-1, 1] 범위로 변환


        generator_output = x

        self.generator = Model(generator_input, generator_output)

 

GAN 훈련

- 훈련 세트에서 진짜 샘플을 랜덤하게 선택하고 생성자의 출력을 합쳐서 훈련 세트를 만들어 판별자를 훈련함

--> 진짜 이미지의 타깃은 1이고, 생성된 이미지의 타깃은 0

--> 지도 학습 문제로 생각하면 원본 이미지와 생성된 이미지 사이의 차이점을 구분할 수 있도록 판별자를 훈련시킬 수 있을 것

- 진짜 이미지가 잠재 공간의 어떤 포인트에 매핑되는지 알려주는 훈련세트가 없어서 생성자 훈련은 판별자를 속이는 이미지를 생성함

--> 이미지가 판별자의 입력으로 주입될때 생성된 이미지가 1에 가까운 값이 출력되어야함

- 생성자를 훈련하기 위해 판별자를 연결한 케라스 모델을 만들어서 생성자의 출력 이미지를 판별자에 주입하면 판별자는 생성자의 이미지가 진짜일 확률을 출력함

- 입력은 랜덤하게 생성한 잠재 공간 벡터이고, 출력은 1인 훈련 배치를 만들어 전체 모델을 훈련함

--> 출력을 1로 지정하는 이유는 판별자가 진짜라고 생각할 수 있는 이미지를 생성자가 만드는 것이 목표이기 때문임

- 손실함수는 판별자의 출력과 타깃 1 사이의 이진 크로스엔트로피 손실임

- 전체 모델을 훈련할 때 생성자의 가중치만 업데이트되도록 판별자의 가중치를 동결하는 것이 중요함

- 판별자의 가중치를 동결하지 않으면 생성된 이미지를 진짜라고 여기도록 조정되기 때문임. 판별자가 약하기 때문이 아니라 생성자가 강하기 때문에 생성된 이미지가 진짜 이미지 1에 가까운 값으로 예측되어야함

 

GAN 컴파일

def _build_adversarial(self):
        
        ### 판별자 컴파일

        self.discriminator.compile(
        optimizer=self.get_opti(self.discriminator_learning_rate)  
        , loss = 'binary_crossentropy' 
        '''
        타깃이 이진 값이고 시그모이드 활성화 함수를 가진 하나의 출력 유닛을 사용하기 때문에 
        판별자를 이진 크로스엔트로피로 컴파일함
        '''
        ,  metrics = ['accuracy']
        )
        
        ### 생성자를 훈련하기 위해 모델 컴파일

        self.set_trainable(self.discriminator, False)
        # 판별자의 가중치를 동결함. 컴파일한 판별자 모델이 영향 받지 않음

        model_input = Input(shape=(self.z_dim,), name='model_input')
        model_output = self.discriminator(self.generator(model_input))
        self.model = Model(model_input, model_output)
        '''
        100차원 잠재 공간 벡터를 입력으로 받는 새로운 모델을 정의함
        이 벡터가 생성자와 동결한 판별자를 통과하여 확률이 출력됨
        '''

        self.model.compile(optimizer=self.get_opti(self.generator_learning_rate) , loss='binary_crossentropy', metrics=['accuracy'])
		'''
        이진 크로스 엔트로피 손실을 사용해서 전체 모델을 컴파일함
        일반적으로 판별자가 생성자가 강해야하므로
        학습률이 판별자보다 느림
        학습률은 주의 깊게 튜닝해야하는 파라미터임
        '''
        
        self.set_trainable(self.discriminator, True)

 

GAN 훈련

    def train_discriminator(self, x_train, batch_size, using_generator):

        valid = np.ones((batch_size,1))
        fake = np.zeros((batch_size,1))
		
        # 진짜 이미지로 훈련
        if using_generator:
            true_imgs = next(x_train)[0]
            if true_imgs.shape[0] != batch_size:
                true_imgs = next(x_train)[0]
        else:
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            true_imgs = x_train[idx]
        
        # 생성된 이미지로 훈련
        noise = np.random.normal(0, 1, (batch_size, self.z_dim))
        gen_imgs = self.generator.predict(noise)

        d_loss_real, d_acc_real =   self.discriminator.train_on_batch(true_imgs, valid)
        # 판별자의 훈련은 타깃 1과 진짜 이미지로 배치를 만들어 수행
        
        d_loss_fake, d_acc_fake =   self.discriminator.train_on_batch(gen_imgs, fake)
        # 타깃 0과 생성된 이미지로 배치 훈련
        
        d_loss =  0.5 * (d_loss_real + d_loss_fake)
        d_acc = 0.5 * (d_acc_real + d_acc_fake)

        return [d_loss, d_loss_real, d_loss_fake, d_acc, d_acc_real, d_acc_fake]

    def train_generator(self, batch_size):
        valid = np.ones((batch_size,1))
        noise = np.random.normal(0, 1, (batch_size, self.z_dim))
        return self.model.train_on_batch(noise, valid)
        # 생성자의 훈련은 타깃 1과 생성된 이미지로 배치를 만들어 수행
        # 판별자는 동결되었기 때문에 가중치가 변하지 않음
        # 조금 더 판별자를 속일 수 있는 이미지를 생성할 수 있도록 생성자의 가중치가 이동함


    def train(self, x_train, batch_size, epochs, run_folder
    , print_every_n_batches = 50
    , using_generator = False):

        for epoch in range(self.epoch, self.epoch + epochs):

            d = self.train_discriminator(x_train, batch_size, using_generator)
            g = self.train_generator(batch_size)

            print ("%d [D loss: (%.3f)(R %.3f, F %.3f)] [D acc: (%.3f)(%.3f, %.3f)] [G loss: %.3f] [G acc: %.3f]" % (epoch, d[0], d[1], d[2], d[3], d[4], d[5], g[0], g[1]))

            self.d_losses.append(d)
            self.g_losses.append(g)

            if epoch % print_every_n_batches == 0:
                self.sample_images(run_folder)
                self.model.save_weights(os.path.join(run_folder, 'weights/weights-%d.h5' % (epoch)))
                self.model.save_weights(os.path.join(run_folder, 'weights/weights.h5'))
                self.save_model(run_folder)

            self.epoch += 1

- 적절한 횟수의 에폭 반복 후 판별자와 생성자가 평형을 이루어 생성자가 판별자로부터 유용한 정보를 학습하고 이미지 품질이 향상됨

- 훈련 세트에서 이미지를 단순히 재생성해서는 안되고 생성된 특정 샘플에 가장 비슷한 훈련 세트의 이미지를 찾아야함

-> 두 이미지 사이의 거리를 재는 L1 norm  

def l1_compare_images(img1, img2):
	return np.mean(np.abs(img1-img2))

 

GAN 문제점

- GAN은 훈련이 어렵기로 유명

 

진동 손실

- 판별자와 생성자의 손실이 장기간 안정된 모습을 보여주지 못하고 큰 폭으로 진동하기 시작함

 

모드 붕괴

- 생성자가 판별자를 속이는 적은 수의 샘플을 찾을 때 발생

- 한정된 이 샘플 이외에는 다른 샘플을 생성하지 못함

- 손실 함수의 그래디언트가 0에 가까운 값으로 무너짐

 

유용하지 않은 손실

- 생성자는 현재 판별자에 의해서만 평가되고 판별자는 계속 향상되므로 훈련 과정의 다른 지점에서 평가된 손실 비교는 어려움

- 이미지 품질이 향상됨에도 불구하고 생성자의 손실 함수는 증가할 수 있음

 

하이퍼파라미터

- 간단한 GAN이라도 튜닝해야할 하이퍼파라미터의 개수가 상당히 많음

- 간은 파라미터의 작은 변화에도 매우 민감함

- 계획적인 시행착오를 거쳐 잘 맞는 파라미터 조합을 찾는 경우가 많음

 

WGAN, WGAN-GP는 GAN 훈련하는 데 최선의 방법으로 간주됨