A PyTorch implementation of "Unsupervised and Semi-supervised Learning with Categorical Generative Adversarial Networks" (CatGAN), originally proposed by Jost Tobias Springenberg.
Note that only the unsupervised version is implemented in this repo for now. I replaced the original architecture with DCGAN, and the results are more colorful than those of the original.
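To illustrate what "categorical" means here, the sketch below computes the unsupervised CatGAN objectives from the discriminator's softmax outputs over K classes, following the entropy terms described in Springenberg's paper. It uses NumPy for readability and is not taken from catgan_cifar10.py. The function names and the EPS constant are illustrative, and the repo's PyTorch code may arrange or weight the terms differently.

```python
import numpy as np

EPS = 1e-8  # guards against log(0); illustrative value

def conditional_entropy(probs):
    """Batch mean of H[p(y|x)]; probs has shape (batch, K)."""
    return float(np.mean(-np.sum(probs * np.log(probs + EPS), axis=1)))

def marginal_entropy(probs):
    """H[p(y)], where p(y) is the class distribution averaged over the batch."""
    p_y = probs.mean(axis=0)
    return float(-np.sum(p_y * np.log(p_y + EPS)))

def catgan_losses(p_real, p_fake):
    """Unsupervised CatGAN losses, both to be minimized.

    p_real: discriminator softmax outputs on real images, shape (batch, K)
    p_fake: discriminator softmax outputs on generated images, shape (batch, K)
    """
    # D: balanced class usage on real data, confident on real, uncertain on fake
    loss_d = (-marginal_entropy(p_real)
              + conditional_entropy(p_real)
              - conditional_entropy(p_fake))
    # G: generated samples classified confidently and spread evenly across classes
    loss_g = -marginal_entropy(p_fake) + conditional_entropy(p_fake)
    return loss_d, loss_g

K = 10
loss_d, loss_g = catgan_losses(np.eye(K), np.full((K, K), 1.0 / K))
```

With one-hot, class-balanced predictions on real data and uniform predictions on fakes, the discriminator loss reaches its lower bound of -2 log K. The generator lowers its own loss by producing samples that the discriminator assigns confidently and evenly to the K classes.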
From 0 to 100 epochs:
Prerequisites:
- Python 2.7
- PyTorch v0.2.0
- NumPy
- SciPy
- Matplotlib
Getting started:
- Install PyTorch and the other dependencies
- Clone this repo:
git clone https://github.com/xinario/catgan_pytorch.git
cd catgan_pytorch
- Download the CIFAR-10 dataset (PNG format, from Kaggle)
- Create a dataset folder to hold the images
mkdir -p ./datasets/cifar10/images
- Move the extracted images into the newly created folder
- Train a model:
python catgan_cifar10.py --data_dir ./datasets/cifar10 --name cifar10
All generated plots and samples are saved in ./results/cifar10
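The dataset steps above (creating ./datasets/cifar10/images and moving the extracted PNGs into it) can also be scripted. This is a minimal Python 3 sketch, assuming the Kaggle archive has already been extracted to a flat folder of .png files. The function name prepare_cifar10 and the ./train path in the usage note are hypothetical and not part of the repo.

```python
import shutil
from pathlib import Path

def prepare_cifar10(extracted_dir, data_dir="./datasets/cifar10"):
    """Create <data_dir>/images and move every top-level .png from extracted_dir into it.

    Returns the number of .png files in the images folder afterwards.
    """
    images_dir = Path(data_dir) / "images"
    images_dir.mkdir(parents=True, exist_ok=True)  # equivalent to `mkdir -p`
    # Materialize the file list before moving so the directory isn't modified mid-scan.
    for png in sorted(Path(extracted_dir).glob("*.png")):
        shutil.move(str(png), str(images_dir / png.name))
    return len(list(images_dir.glob("*.png")))
```

For example, after extracting the archive to a hypothetical ./train folder, call prepare_cifar10("./train"). Only .png files are moved, so anything else in the extracted folder stays where it is.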
optional arguments:
--continue_train to continue training from the latest checkpoints if --netG and --netD are not specified
--netG NETG path to netG (to continue training)
--netD NETD path to netD (to continue training)
--workers WORKERS number of data loading workers
--num_epochs EPOCHS number of epochs to train for
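The options above can be pictured as an argparse setup plus a small rule for choosing which checkpoints to resume from. This Python 3 sketch is not the script's actual code: the default values, the required flags, and the <prefix>_epoch_<N>.pth checkpoint naming are hypothetical placeholders, so check catgan_cifar10.py for the real behavior.

```python
import argparse
import re
from pathlib import Path

def build_parser():
    p = argparse.ArgumentParser(description="CatGAN training options (sketch)")
    # Flag names come from this README; required/default settings are assumptions.
    p.add_argument("--data_dir", required=True, help="dataset root containing images/")
    p.add_argument("--name", required=True, help="experiment name, used under ./results")
    p.add_argument("--continue_train", action="store_true",
                   help="resume from the latest checkpoints if --netG/--netD are not given")
    p.add_argument("--netG", default="", help="path to netG (to continue training)")
    p.add_argument("--netD", default="", help="path to netD (to continue training)")
    p.add_argument("--workers", type=int, default=2, help="number of data loading workers")
    p.add_argument("--num_epochs", type=int, default=100, help="number of epochs to train for")
    return p

def latest_checkpoint(ckpt_dir, prefix):
    """Return the path of <prefix>_epoch_<N>.pth with the largest N, or None (hypothetical naming)."""
    pattern = re.compile(rf"{re.escape(prefix)}_epoch_(\d+)\.pth")
    best_path, best_epoch = None, -1
    for path in Path(ckpt_dir).glob(f"{prefix}_epoch_*.pth"):
        match = pattern.fullmatch(path.name)
        if match and int(match.group(1)) > best_epoch:
            best_path, best_epoch = str(path), int(match.group(1))
    return best_path

def resolve_checkpoints(args, ckpt_dir):
    """Explicit --netG/--netD paths win; with --continue_train, fill gaps with the latest files."""
    net_g, net_d = args.netG or None, args.netD or None
    if args.continue_train:
        net_g = net_g or latest_checkpoint(ckpt_dir, "netG")
        net_d = net_d or latest_checkpoint(ckpt_dir, "netD")
    return net_g, net_d
```

Epoch numbers are compared as integers, so netG_epoch_10.pth is preferred over netG_epoch_9.pth even though it sorts earlier as a string.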
More options can be found inside the training script.
Some of the code is inspired by or borrowed from the wgan-gp, DCGAN, and CatGAN Chainer repos.