TensorFlow implementation of GP-GAN: Towards Realistic High-Resolution Image Blending by Hui-Kai Wu et al. The original Chainer implementation is available here.
Currently, only the supervised learning approach is implemented.
The code was tested with Python 3.6.0 and TensorFlow 1.15.0.
Install the requirements with:
pip install -r requirements.txt
(To avoid specifying data paths in the following steps, place the downloaded Transient Attributes Dataset under 'DataBase/TransientAttributes/imageAlignedLD' and use the default paths.)
- Download the Transient Attributes Dataset here
- Crop the images using the bounding boxes at DataBase/TransientAttributes/bbox.txt
python crop_aligned_images.py --data_root "path to imageAlignedLD"
- From the cropped images, create copy-paste images and write TFRecords
python write_tf_records.py --dataset_dir "path to cropped_images"
- Train the blending GAN
python train_blending_gan.py --train_data_root "path to train.tfrecords" --val_data_root "path to val.tfrecords" --save_folder "output path" --experiment "experiment name"
To resume training from a checkpoint, add the flag:
--weights_path "path to .ckpt"
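The cropping and copy-paste steps above can be sketched in plain Python. This is a minimal illustration and not the code in `crop_aligned_images.py` or `write_tf_records.py`. It makes several assumptions: each line of `bbox.txt` reads `<name> <x_min> <y_min> <x_max> <y_max>`, images are nested lists of single-channel pixel values, and a copy-paste sample is built by pasting the central square of one image of a scene onto another image of the same scene. The helper names and the `ratio` parameter are hypothetical. The unmodified background is returned as the target, which matches the L2 term against the background (destination) image listed under Scalars.

```python
from typing import List, Tuple

# Stand-in for an image array: a list of rows of pixel values (single channel).
Image = List[List[int]]
Box = Tuple[int, int, int, int]


def parse_bbox_line(line: str) -> Tuple[str, Box]:
    """Parse one bbox line, ASSUMED to read '<name> <x_min> <y_min> <x_max> <y_max>'."""
    name, *coords = line.split()
    x_min, y_min, x_max, y_max = (int(c) for c in coords)
    return name, (x_min, y_min, x_max, y_max)


def crop(image: Image, box: Box) -> Image:
    """Keep columns x_min..x_max-1 of rows y_min..y_max-1 (half-open box)."""
    x_min, y_min, x_max, y_max = box
    return [row[x_min:x_max] for row in image[y_min:y_max]]


def make_copy_paste_pair(obj: Image, bg: Image, ratio: float = 0.5) -> Tuple[Image, Image]:
    """Paste the central square of `obj` onto a copy of `bg`.

    Returns (copy_paste, target). ASSUMPTIONS: both images are square and the
    same size, `ratio` (hypothetical) sets the side of the pasted square, and
    the target is the unmodified background.
    """
    size = len(bg)
    side = int(size * ratio)
    start = (size - side) // 2
    copy_paste = [row[:] for row in bg]  # copy rows so `bg` stays untouched
    for r in range(start, start + side):
        copy_paste[r][start:start + side] = obj[r][start:start + side]
    return copy_paste, bg


if __name__ == "__main__":
    image = [[10 * r + c for c in range(5)] for r in range(4)]
    name, box = parse_bbox_line("scene_01 1 1 4 3")
    print(name, crop(image, box))  # scene_01 [[11, 12, 13], [21, 22, 23]]

    obj = [[1] * 4 for _ in range(4)]
    bg = [[0] * 4 for _ in range(4)]
    copy_paste, target = make_copy_paste_pair(obj, bg)
    print(copy_paste)  # [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]
```

Copying the background row by row keeps the target intact, so the generator's L2 term can compare its output against the original background.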
Run TensorBoard:
tensorboard --logdir="path to tensorboard folder"
Under Scalars you will find the training graphs. The X axis represents the training cycle (each cycle consists of N discriminator steps and 1 generator step).
- Training_disc_loss: Discriminator loss. Calculated as train_disc_value_real - train_disc_value_fake.
- Train_disc_value_fake: Output value from the discriminator for fake (generated) images.
- Train_disc_value_real: Output value from the discriminator for real images.
- Train_gen_disc_component: The adversarial component of the generator loss, measuring how well the generator fools the discriminator.
- Train_gen_l2_component: The L2 component of the generator loss, measuring the difference between the generated image and the background (destination) image.
- Train_gen_loss: Total generator loss, the sum of the two components above.
- Val_disc_loss: Validation discriminator loss.
- Val_gen_loss: Validation generator loss.
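The relationship between these scalars can be spelled out in a short sketch. The plain-Python code below uses hypothetical callables (`disc_step`, `gen_step`) in place of the real TensorFlow training ops. It runs the cycle schedule described above, with N discriminator steps and one generator step per logged point, and combines the logged values as the list states. The mean-squared form of the L2 term is an assumption, and the actual script may scale or weight its components differently.

```python
from typing import Callable, List, Sequence


def run_cycles(num_cycles: int, n_disc: int,
               disc_step: Callable[[], None],
               gen_step: Callable[[], None]) -> List[int]:
    """Run the schedule and return the cycle index logged for each cycle (the X axis)."""
    logged = []
    for cycle in range(num_cycles):
        for _ in range(n_disc):  # N discriminator updates
            disc_step()
        gen_step()               # one generator update
        logged.append(cycle)     # one TensorBoard point per cycle
    return logged


def disc_loss(value_real: float, value_fake: float) -> float:
    """Discriminator loss as described: train_disc_value_real - train_disc_value_fake."""
    return value_real - value_fake


def l2_component(generated: Sequence[float], background: Sequence[float]) -> float:
    """ASSUMED form of the L2 term: mean squared difference over flattened pixels."""
    if len(generated) != len(background):
        raise ValueError("images must have the same number of pixels")
    return sum((g - b) ** 2 for g, b in zip(generated, background)) / len(generated)


def gen_loss(disc_component: float, l2: float) -> float:
    """Total generator loss as described: the sum of its two components."""
    return disc_component + l2


if __name__ == "__main__":
    counts = {"disc": 0, "gen": 0}

    def disc_step() -> None:
        counts["disc"] += 1

    def gen_step() -> None:
        counts["gen"] += 1

    print(run_cycles(3, 5, disc_step, gen_step), counts)  # [0, 1, 2] {'disc': 15, 'gen': 3}
    print(disc_loss(1.5, 0.25))                            # 1.25
    print(gen_loss(-0.5, l2_component([1.0, 3.0], [0.0, 1.0])))  # 2.0
```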
Note that resuming training from a checkpoint older than the last value logged in TensorBoard can produce graphs that jump back and then move forward again.
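Here is why the curve doubles back. After resuming from an older checkpoint, the cycle counter presumably restarts at the checkpoint's cycle, so cycles that were already logged are written again later in the event file. The hypothetical helpers below are not part of this repository. They take a list of `(cycle, value)` points in write order, find the backward jumps, and keep only the most recent value for each cycle.

```python
from typing import Dict, List, Tuple

Point = Tuple[int, float]  # (cycle, logged value), in the order written


def backward_jumps(points: List[Point]) -> List[int]:
    """Indices where a point's cycle is lower than the previous point's cycle."""
    return [i for i in range(1, len(points)) if points[i][0] < points[i - 1][0]]


def latest_per_cycle(points: List[Point]) -> List[Point]:
    """Keep the last-written value for each cycle, sorted by cycle."""
    latest: Dict[int, float] = {}
    for cycle, value in points:
        latest[cycle] = value  # a later write overwrites an earlier one
    return sorted(latest.items())


if __name__ == "__main__":
    # First run logged cycles 0-4; training then resumed from a checkpoint at cycle 2.
    points = [(0, 1.0), (1, 0.8), (2, 0.7), (3, 0.6), (4, 0.5), (2, 0.75), (3, 0.65)]
    print(backward_jumps(points))    # [5]
    print(latest_per_cycle(points))  # [(0, 1.0), (1, 0.8), (2, 0.75), (3, 0.65), (4, 0.5)]
```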
Under Images you will find image samples from the training process.
After training the GAN, we can load the weights and blend images. The supported input is a source image, a destination image and a binary mask.
python run_gp_gan.py --src_image images/test_images/src.jpg --dst_image images/test_images/dst.jpg --mask_image images/test_images/mask.png --blended_image images/test_images/result.png --generator_path "path to .ckpt"
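The Composited image in the example results is the naive copy-paste that blending starts from. The rough sketch below makes three assumptions: the image is single-channel and stored as nested lists, the mask has values 0-255, and the mask is thresholded at a hypothetical cutoff of 128. The composite takes source pixels where the mask is set and destination pixels elsewhere. It reproduces only the input to blending. The GAN-based blending that `run_gp_gan.py` performs is not shown.

```python
from typing import List

Image = List[List[int]]  # single channel, values 0-255


def binarize(mask: Image, threshold: int = 128) -> Image:
    """Map mask pixels to 1 (>= threshold) or 0; the threshold is an assumption."""
    return [[1 if p >= threshold else 0 for p in row] for row in mask]


def composite(src: Image, dst: Image, mask: Image) -> Image:
    """Copy-paste composite: src where the binary mask is 1, dst where it is 0."""
    return [[m * s + (1 - m) * d for s, d, m in zip(s_row, d_row, m_row)]
            for s_row, d_row, m_row in zip(src, dst, mask)]


if __name__ == "__main__":
    src = [[200, 200], [200, 200]]
    dst = [[10, 20], [30, 40]]
    mask = binarize([[0, 255], [255, 0]])
    print(composite(src, dst, mask))  # [[10, 200], [200, 40]]
```

The composite leaves a visible seam wherever source and destination colors differ, and the blending step is meant to remove that seam.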
Example results (columns: Source, Destination, Mask, Composited, Blended).