GAN

CycleGAN

술임 2022. 5. 26. 01:05

https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

 

GitHub - junyanz/pytorch-CycleGAN-and-pix2pix: Image-to-Image Translation in PyTorch

Image-to-Image Translation in PyTorch. Contribute to junyanz/pytorch-CycleGAN-and-pix2pix development by creating an account on GitHub.

github.com

https://arxiv.org/abs/1703.10593

 

Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

Image-to-image translation is a class of vision and graphics problems where the goal is to learn the mapping between an input image and an output image using a training set of aligned image pairs. However, for many tasks, paired training data will not be a

arxiv.org

 

Style Transfer

- 스타일 이미지가 주어졌을 때 같은 부류에 속한 것 같은 느낌을 주도록 베이스 이미지를 변환

- 스타일 이미지에 내재된 분포를 모델링하는 것이 아니라 이미지에서 스타일을 결정하는 요소만 추출해서 베이스 이미지에 주입함

- 스타일 이미지와 베이스 이미지를 보간 방법으로 합쳐서는 안됨

- 하나의 이미지를 사용하는 것이 아니라 스타일 이미지 세트 전체에서 아티스트의 스타일을 잡아냄.

- 원본 작품을 유지하면서 다른 작품의 스타일 기교로 완성하는 듯한 효과를 주어야함

 

CycleGAN

- 생성 모델링 중 스타일 트랜스퍼 분야에서 핵심적인 발전

- 샘플 쌍으로 구성된 훈련 세트 없이도 참조 이미지 세트의 스타일을 다른 이미지로 복사하는 모델을 훈련할 수 있는 방법을 보임

- pix2pix는 훈련 세트의 각 이미지가 소스와 타깃 도메인에 모두 존재해야함

- pix2pix는 한 방향으로(소스->타깃), CycleGAN은 양 방향으로 동시에 모델을 훈련함

- 실제 4개의 모델로 구성됨 : 2개의 생성자, 2개의 판별자

- 첫번째 생성자 g_AB는 도메인 A의 이미지를 도메인 B로 바꿈, 두번째 생성자 g_BA는 도메인 B의 이미지를 도메인 A로 바꿈

- 첫번째 판별자는 d_A는 도메인 A의 진짜 이미지와 생성자 g_BA가 만든 가짜 이미지를 구별할 수 있도록 훈련, 판별자 d_B는 도메인 B의 진짜 이미지와 생성자 g_AB가 만든 가짜 이미지를 구별할 수 있도록 훈련됨

- 생성자 구조는 pix2pix 논문에서는 U-Net 구조를 사용했으나, CycleGAN에서는 ResNet 구조를 사용함

gan = CycleGAN(
    input_dim = (IMAGE_SIZE,IMAGE_SIZE,3)
    ,learning_rate = 0.0002
    , buffer_max_length = 50
    , lambda_validation = 1
    , lambda_reconstr = 10
    , lambda_id = 2
    , generator_type = 'unet'
    , gen_n_filters = 32
    , disc_n_filters = 32
    )

 

생성자(U-Net)

- 다운샘플링과 업샘플링으로 구성됨

- 다운샘플링은 입력 이미지를 공간 방향으로 압축하지만, 채널 방향으로는 확장함

- 업샘플링은 공간 방향으로 표현을 확장시키지만 채널의 수는 감소시킴

- 다운샘플링과 업샘플링 사이에는 크기가 동일한 층끼리 연결된 스킵 연결이 존재함

- 네트워크의 다운샘플링에 각 층 모델은 이미지가 무엇인지를 감지하지만 어디에 있는지 위치 정보는 잃어버림

- U-Net의 꼭짓점에 있는 특성 맵은 이미지가 무엇인지 이해할 수 있지만 어디에 있는지는 알 수 없음

- 특성 맵을 마지막 Dense 층에 연결해서 이미지에 등장하는 특정 클래스의 확률을 출력함

- 원래 U-Net 애플리케이션이나 스타일 트랜스퍼에서는 원본 이미지 크기로 다시 업샘플링 되는 것이 중요함

- 업샘플링의 각 층에서 다운샘플링 되는 동안 잃었던 공간 정보를 되돌림

- 스킵 연결은 다운샘플링 과정에서 감지된 고수준 추상 정보(이미지 스타일)을 네트워크 앞쪽 층으로부터 전달된 구체적인 공간 정보(이미지 콘텐츠)와 섞음

 

Concatenate 층

- 스킵 연결을 만들기 위한 새로운 층

- 특정 축을 따라서 여러 층을 합침

- U-Net에서는 업샘플링 층과 동일한 크기의 출력을 내는 다운샘플링 쪽의 층을 연결함

- 층을 접합하는 역할만 할 뿐 학습되는 가중치는 없음

- 채널의 수가 k개에서 2k개로 늘어남

 

InstanceNormalization 층

- 생성자는 BatchNormalization 대신 이 층을 사용함 -> 스타일 트랜스퍼에서 더 만족스러운 결과

- 배치 단위가 아니라 개별 샘플을 정규화함

- 무빙 에버리지를 위해 훈련 과정에서 계산하는 mu, sigma 파라미터가 필요하지 않음

- 각 층을 정규화하기 위해 사용되는 평균과 표준 편차는 채널별로 나누어 샘플별로 계산됨

    def build_generator_unet(self):

        def downsample(layer_input, filters, f_size=4):
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = InstanceNormalization(axis = -1, center = False, scale = False)(d)
            d = Activation('relu')(d)
            
            return d

        def upsample(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same')(u)
            u = InstanceNormalization(axis = -1, center = False, scale = False)(u)
            u = Activation('relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)

            u = Concatenate()([u, skip_input])
            return u

        # Image input
        img = Input(shape=self.img_shape)

        # Downsampling 스트라이드 2인 Conv2D 층으로 이미지를 다운샘플링함
        d1 = downsample(img, self.gen_n_filters) 
        d2 = downsample(d1, self.gen_n_filters*2)
        d3 = downsample(d2, self.gen_n_filters*4)
        d4 = downsample(d3, self.gen_n_filters*8)

        # Upsampling 텐서를 업샘플링해서 원본 이미지와 같은 크기로 복원함, Concatenate 층을 포함함
        u1 = upsample(d4, d3, self.gen_n_filters*4)
        u2 = upsample(u1, d2, self.gen_n_filters*2)
        u3 = upsample(u2, d1, self.gen_n_filters)

        u4 = UpSampling2D(size=2)(u3)
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

        return Model(img, output_img)

 

 

생성자(ResNet)

- 잔차 블록을 차례대로 쌓아 구성됨

- 각 블록은 다음 층으로 출력을 전달하기 전에 입력과 출력을 합하는 스킵 연결을 가짐

def build_generator_resnet(self):

        def conv7s1(layer_input, filters, final):
            y = ReflectionPadding2D(padding =(3,3))(layer_input)
            y = Conv2D(filters, kernel_size=(7,7), strides=1, padding='valid', kernel_initializer = self.weight_init)(y)
            if final:
                y = Activation('tanh')(y)
            else:
                y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
                y = Activation('relu')(y)
            return y

        def downsample(layer_input,filters):
            y = Conv2D(filters, kernel_size=(3,3), strides=2, padding='same', kernel_initializer = self.weight_init)(layer_input)
            y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
            y = Activation('relu')(y)
            return y

        def residual(layer_input, filters):
            shortcut = layer_input
            y = ReflectionPadding2D(padding =(1,1))(layer_input)
            y = Conv2D(filters, kernel_size=(3, 3), strides=1, padding='valid', kernel_initializer = self.weight_init)(y)
            y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
            y = Activation('relu')(y)
            
            y = ReflectionPadding2D(padding =(1,1))(y)
            y = Conv2D(filters, kernel_size=(3, 3), strides=1, padding='valid', kernel_initializer = self.weight_init)(y)
            y = InstanceNormalization(axis = -1, center = False, scale = False)(y)

            return add([shortcut, y])

        def upsample(layer_input,filters):
            y = Conv2DTranspose(filters, kernel_size=(3, 3), strides=2, padding='same', kernel_initializer = self.weight_init)(layer_input)
            y = InstanceNormalization(axis = -1, center = False, scale = False)(y)
            y = Activation('relu')(y)
    
            return y


        # Image input
        img = Input(shape=self.img_shape)

        y = img

        y = conv7s1(y, self.gen_n_filters, False)
        y = downsample(y, self.gen_n_filters * 2)
        y = downsample(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = residual(y, self.gen_n_filters * 4)
        y = upsample(y, self.gen_n_filters * 2)
        y = upsample(y, self.gen_n_filters)
        y = conv7s1(y, 3, True)
        output = y

   
        return Model(img, output)

- ResNet은 그래디언트 소실 문제가 없음

- 층을 추가해도 모델의 정확도를 떨어뜨리지 않음

- 추가 특성이 추출되지 않으면 스킵 연결로 이전 층의 특성이 항등 사상을 통과함

--> 기본 ResNet은 스킵 연결과 더해진 후 렐루 함수를 통과시키지만 여기서는 스킵 연결 합친 후에 적용하는 활성화 함수가 없어서 이전 층의 특성 맵이 그대로 다음 층으로 전달됨

 

판별자

- 숫자 하나가 아니라 16x16 크기의 채널 하나를 가진 텐서를 출력함

PatchGAN에서 판별자 구조 승계함

- 이미지 전체에 대해 예측이 아니라 패치로 나누어서 각 패치가 진짜인지 가짜인지를 추측함 -> 판별자의 출력은 하나의 숫자가 아닌 각 패치에 대한 예측 확률을 담은 텐서가 됨

- 네트워크에 이미지를 전달하면 패치들을 한꺼번에 예측함

- 판별자의 합성곱 구조로 인해 자동으로 이미지가 패치로 나뉨

- 내용이 아니라 스타일을 기반으로 판별자가 얼마나 잘 구별하는지 손실 함수가 측정할 수 있음 -> 내용이 아니라 스타일을 사용해서 결정

def build_discriminator(self):

        def conv4(layer_input,filters, stride = 2, norm=True):
            y = Conv2D(filters, kernel_size=(4,4), strides=stride, padding='same', kernel_initializer = self.weight_init)(layer_input)
            
            if norm:
                y = InstanceNormalization(axis = -1, center = False, scale = False)(y)

            y = LeakyReLU(0.2)(y)
           
            return y

        img = Input(shape=self.img_shape)

        y = conv4(img, self.disc_n_filters, stride = 2, norm = False)
        y = conv4(y, self.disc_n_filters*2, stride = 2)
        y = conv4(y, self.disc_n_filters*4, stride = 2)
        y = conv4(y, self.disc_n_filters*8, stride = 1)

        output = Conv2D(1, kernel_size=4, strides=1, padding='same',kernel_initializer = self.weight_init)(y)

        return Model(img, output)

    def train_discriminators(self, imgs_A, imgs_B, valid, fake):

 

모델 컴파일

- 세 가지 조건으로 생성자를 동시에 평가함

1) 유효성

각 생성자에서 만든 이미지가 대응되는 판별자를 속이는가?

ex. g_BA의 출력이 d_A를 속이고, g_AB의 출력이 d_B를 속이는지?

 

2) 재구성

두 생성자를 교대로 적용하면 원본 이미지를 얻는가?

 

3) 동일성

각 생성자를 자신의 타깃 도메인에 있는 이미지에 적용 시 이미지가 바뀌지 않고 그대로 남아있는가?

 

- 각 도메인의 이미지 배치를 입력으로 받고 각 도메인에 대해 3개의 출력을 제공함

- 판별가의 가중치는 동결 -> 판별자가 모델에 관여하지만 결합된 모델은 생성자의 가중치만 훈련함

 

- 전체 손실은 각 조건에 대한 손실의 가중치 합

- 평균 제곱 오차-> 유효성 조건에 사용됨. 진짜와 가짜 타깃에 대해 판별자의 출력을 확인

 

    def compile_models(self):

        # Build and compile the discriminators
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        
        self.d_A.compile(loss='mse',
            optimizer=Adam(self.learning_rate, 0.5),
            metrics=['accuracy'])
        self.d_B.compile(loss='mse',
            optimizer=Adam(self.learning_rate, 0.5),
            metrics=['accuracy'])


        # Build the generators
        if self.generator_type == 'unet':
            self.g_AB = self.build_generator_unet()
            self.g_BA = self.build_generator_unet()
        else:
            self.g_AB = self.build_generator_resnet()
            self.g_BA = self.build_generator_resnet()

        # For the combined model we will only train the generators
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # Identity mapping of images
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # Discriminators determines validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[ valid_A, valid_B,
                                        reconstr_A, reconstr_B,
                                        img_A_id, img_B_id ])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                            loss_weights=[  self.lambda_validation,                       self.lambda_validation,
                                            self.lambda_reconstr, self.lambda_reconstr,
                                            self.lambda_id, self.lambda_id ],
                            optimizer=Adam(0.0002, 0.5))

        self.d_A.trainable = True
        self.d_B.trainable = True

 

CycleGAN 훈련

def train(self, data_loader, run_folder, epochs, test_A_file, test_B_file, batch_size=1, sample_interval=50):

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        # 패치마다 하나의 타깃 설정
        valid = np.ones((batch_size,) + self.disc_patch) 
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(self.epoch, epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch()):

                d_loss = self.train_discriminators(imgs_A, imgs_B, valid, fake)
                g_loss = self.train_generators(imgs_A, imgs_B, valid)

                elapsed_time = datetime.datetime.now() - start_time

                # Plot the progress
                if batch_i % 100 == 0:
                    print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                        % ( self.epoch, epochs,
                            batch_i, data_loader.n_batches,
                            d_loss[0], 100*d_loss[7],
                            g_loss[0],
                            np.sum(g_loss[1:3]),
                            np.sum(g_loss[3:5]),
                            np.sum(g_loss[5:7]),
                            elapsed_time))

                self.d_losses.append(d_loss)
                self.g_losses.append(g_loss)

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(data_loader, batch_i, run_folder, test_A_file, test_B_file)
                    self.combined.save_weights(os.path.join(run_folder, 'weights/weights-%d.h5' % (self.epoch)))
                    self.combined.save_weights(os.path.join(run_folder, 'weights/weights.h5'))
                    self.save_model(run_folder)

                
            self.epoch += 1
   def train_discriminators(self, imgs_A, imgs_B, valid, fake):

        # Translate images to opposite domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)

        self.buffer_B.append(fake_B)
        self.buffer_A.append(fake_A)

        fake_A_rnd = random.sample(self.buffer_A, min(len(self.buffer_A), len(imgs_A)))
        fake_B_rnd = random.sample(self.buffer_B, min(len(self.buffer_B), len(imgs_B)))

        # Train the discriminators (original images = real / translated = Fake)
        dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
        dA_loss_fake = self.d_A.train_on_batch(fake_A_rnd, fake)
        dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

        dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
        dB_loss_fake = self.d_B.train_on_batch(fake_B_rnd, fake)
        dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

        # Total disciminator loss
        d_loss_total = 0.5 * np.add(dA_loss, dB_loss)

        return (
            d_loss_total[0]
            , dA_loss[0], dA_loss_real[0], dA_loss_fake[0]
            , dB_loss[0], dB_loss_real[0], dB_loss_fake[0]
            , d_loss_total[1]
            , dA_loss[1], dA_loss_real[1], dA_loss_fake[1]
            , dB_loss[1], dB_loss_real[1], dB_loss_fake[1]
        )

    def train_generators(self, imgs_A, imgs_B, valid):

        # Train the generators
        return self.combined.train_on_batch([imgs_A, imgs_B],
                                                [valid, valid,
                                                imgs_A, imgs_B,
                                                imgs_A, imgs_B])

* 일반적으로 CycleGAN의 배치 크기는 1

- 생성자를 사용해서 가짜 이미지의 배치를 만들고 가짜 이미지와 진짜 이미지 배치로 각 판별자를 훈련함

- 생성자는 컴파일된 결합 모델을 통해 동시에 훈련됨

- 원본 CycleGAN에서는 동일성 손실이 선택적이고 재구성 손실과 유효성 손실이 필수적. 세 개의 손실 함수 가중치의 균형을 잘 잡는 것이 중요함

-> 동일성 손실이 너무 작으면 색이 바뀌고, 동일성 손실이 너무 크면 CycleGAN이 입력을 다른 도메인 이미지처럼 보이도록 바꾸지 못함

 

ProGAN(Progressive GAN)

- 고해상도 이미지에 GAN을 직접 훈련하는 대신 저해상도 이미지에서 생성자와 판별자를 훈련한 뒤 훈련을 진행하면서 층을 추가하여 해상도를 높임

- 먼저 추가된 층이 훈련 과정에서 동결되지 않고 전체가 훈련됨

 

SAGAN(Self-Attention GAN)

- 트랜스포머같은 순차 모델에 사용되는 어텐션 매커니즘을 이미지 생성을 위한 GAN 기반 모델에 적용함

- 어텐션을 사용하지 않는 GAN 기반 모델은 합성곱 특성 맵이 지역적인 정보만 처리할 수 있는 문제점

- 어텐션 매커니즘을 적용해서 이 문제를 해결함. 

 

BigGAN

https://arxiv.org/abs/1809.11096

 

Large Scale GAN Training for High Fidelity Natural Image Synthesis

Despite recent progress in generative image modeling, successfully generating high-resolution, diverse samples from complex datasets such as ImageNet remains an elusive goal. To this end, we train Generative Adversarial Networks at the largest scale yet at

arxiv.org

- 딥마인드에서 개발, SAGAN을 확장함

- ImageNet 데이터셋을 사용한 이미지 생성에서 최고의 성능을 냄

- 절단 기법을 사용해서(truncation trick) 생성된 샘플의 신뢰도를 높임

 

StyleGAN

- 엔비디아 연구소에서 개발

- ProGAN과 뉴럴 스타일 트랜스퍼를 사용함

- GAN 훈련 시 잠재 공간 벡터를 고수준 속성으로 구분하기 어려움

- 적응적 인스턴스 정규화를 사용