pytorch-generative is a Python library which makes generative modeling in PyTorch easier by providing:
- high-quality reference implementations of SOTA generative models
- useful abstractions of common building blocks found in the literature
- utilities for training, debugging, and working with Google Colab
- integration with TensorBoard for easy metrics visualization
To install pytorch-generative, clone the repository and install the requirements:
git clone https://www.github.com/EugenHotaj/pytorch-generative
cd pytorch-generative
pip install -r requirements.txt
After installation, run the tests to sanity check that everything works:
python -m unittest discover
All our models implement a reproduce function with all the hyperparameters necessary to reproduce the results listed in the supported algorithms section. This makes it easy to reproduce any result with our training script, for example:
python train.py --model image_gpt --logdir /tmp/run --use-cuda
Training metrics will periodically be logged to TensorBoard for easy visualization. To view these metrics, launch a local TensorBoard server:
tensorboard --logdir /tmp/run
To run the model on a different dataset, with different hyperparameters, etc., modify its reproduce function and rerun the commands above.
To use pytorch-generative in Google Colab, clone the repository and move it into the top-level directory:
!git clone https://www.github.com/EugenHotaj/pytorch-generative
!mv pytorch-generative/pytorch_generative .
You can then import pytorch-generative like any other library:
import pytorch_generative as pg
from pytorch_generative import models
...
Supported models are implemented as PyTorch Modules and are easy to use:
from pytorch_generative import models
... # Data loading code.
model = models.ImageGPT(in_channels=1, out_channels=1, in_size=28)
model(batch)
Alternatively, lower-level building blocks in pytorch_generative.nn can be used to write models from scratch. We show how to implement a convolutional ImageGPT model below:
import torch
from torch import nn

from pytorch_generative import nn as pg_nn


class TransformerBlock(nn.Module):
    """An ImageGPT Transformer block."""

    def __init__(self,
                 n_channels,
                 n_attention_heads):
        """Initializes a new TransformerBlock instance.

        Args:
            n_channels: The number of input and output channels.
            n_attention_heads: The number of attention heads to use.
        """
        super().__init__()
        self._ln1 = pg_nn.NCHWLayerNorm(n_channels)
        self._ln2 = pg_nn.NCHWLayerNorm(n_channels)
        self._attn = pg_nn.CausalAttention(
            in_channels=n_channels,
            embed_channels=n_channels,
            out_channels=n_channels,
            n_heads=n_attention_heads,
            mask_center=False)
        # Position-wise feed-forward network implemented with 1x1 convolutions.
        self._out = nn.Sequential(
            nn.Conv2d(
                in_channels=n_channels,
                out_channels=4*n_channels,
                kernel_size=1),
            nn.GELU(),
            nn.Conv2d(
                in_channels=4*n_channels,
                out_channels=n_channels,
                kernel_size=1))

    def forward(self, x):
        # Residual connections around the attention and feed-forward sublayers.
        x = x + self._attn(self._ln1(x))
        return x + self._out(self._ln2(x))


class ImageGPT(nn.Module):
    """The ImageGPT Model."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 in_size,
                 n_transformer_blocks=8,
                 n_attention_heads=4,
                 n_embedding_channels=16):
        """Initializes a new ImageGPT instance.

        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            in_size: Size of the input images. Used to create positional encodings.
            n_transformer_blocks: Number of TransformerBlocks to use.
            n_attention_heads: Number of attention heads to use.
            n_embedding_channels: Number of attention embedding channels to use.
        """
        super().__init__()
        # Learned positional encodings, added to the input in forward().
        self._pos = nn.Parameter(torch.zeros(1, in_channels, in_size, in_size))
        self._input = pg_nn.CausalConv2d(
            mask_center=True,
            in_channels=in_channels,
            out_channels=n_embedding_channels,
            kernel_size=3,
            padding=1)
        self._transformer = nn.Sequential(
            *[TransformerBlock(n_channels=n_embedding_channels,
                               n_attention_heads=n_attention_heads)
              for _ in range(n_transformer_blocks)])
        self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
        self._out = nn.Conv2d(in_channels=n_embedding_channels,
                              out_channels=out_channels,
                              kernel_size=1)

    def forward(self, x):
        x = self._input(x + self._pos)
        x = self._transformer(x)
        x = self._ln(x)
        return self._out(x)
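The mask_center flag passed to CausalConv2d and CausalAttention above controls whether a position may see itself. As a rough illustration of the idea (not the library's implementation), the sketch below builds the spatial mask commonly used for raster-scan autoregressive convolutions: kernel taps above the center, or to its left on the same row, are kept, and the center tap is kept only when mask_center is False. The function name causal_kernel_mask is hypothetical and does not appear in pytorch-generative.

```python
def causal_kernel_mask(kernel_size, mask_center):
    """Returns a kernel_size x kernel_size 0/1 mask for raster-scan ordering.

    Hypothetical helper for illustration only; not part of pytorch-generative.
    """
    center = kernel_size // 2
    mask = [[0] * kernel_size for _ in range(kernel_size)]
    for row in range(kernel_size):
        for col in range(kernel_size):
            # Keep taps strictly before the center in raster-scan order.
            if row < center or (row == center and col < center):
                mask[row][col] = 1
    if not mask_center:
        # Allow the position to see its own value (used after the first layer).
        mask[center][center] = 1
    return mask


for row in causal_kernel_mask(kernel_size=3, mask_center=True):
    print(row)
```

Masking the center in the first layer keeps the model from reading the pixel it is asked to predict; later layers can leave the center unmasked because their inputs already depend only on earlier pixels.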
pytorch-generative supports the following algorithms.
We train likelihood-based models on dynamically binarized MNIST and report the negative log-likelihood, in nats (lower is better), in the tables below.
Algorithm | Binarized MNIST (nats) | Links |
---|---|---|
PixelSNAIL | 78.61 | Code, Paper |
ImageGPT | 79.17 | Code, Paper |
Gated PixelCNN | 81.50 | Code, Paper |
PixelCNN | 81.45 | Code, Paper |
MADE | 84.87 | Code, Paper |
NADE | 85.65 | Code, Paper |
FVSBN | 96.58 | Code, Paper |
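For readers unfamiliar with the term, dynamic binarization usually means resampling a fresh binary image every time an example is loaded, treating each grayscale intensity in [0, 1] as a Bernoulli probability. The sketch below shows the idea in plain Python under that assumption. The name dynamically_binarize is illustrative, and the library's own data pipeline may differ.

```python
import random


def dynamically_binarize(pixels, rng):
    """Samples each pixel from Bernoulli(p), where p is its intensity in [0, 1].

    Illustrative helper, not taken from pytorch-generative.
    """
    return [1 if rng.random() < p else 0 for p in pixels]


rng = random.Random(0)
grayscale = [0.0, 0.25, 0.5, 0.75, 1.0]
# Each call draws a new binary sample from the same grayscale image.
print(dynamically_binarize(grayscale, rng))
print(dynamically_binarize(grayscale, rng))
```

Because a new sample is drawn on every pass over the data, the model rarely sees exactly the same binary image twice, unlike with a fixed (static) binarization.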
NOTE: The results below are the (variational) upper bound on the negative log-likelihood (or equivalently, the lower bound on the log-likelihood).
Algorithm | Binarized MNIST (nats) | Links |
---|---|---|
VD-VAE | <= 80.72 | Code, Paper |
VAE | <= 86.77 | Code, Paper |
BetaVAE | N/A | Code, Paper |
VQ-VAE | N/A | Code, Paper |
VQ-VAE-2 | N/A | Code, Paper |
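The "<=" entries in the table above follow from the standard evidence lower bound (ELBO). As a reminder, and using notation that is an assumption here rather than taken from the library (encoder q_phi, decoder p_theta, prior p(z)), the bound can be written as:

```latex
\log p_\theta(x)
  \;\ge\; \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]
          - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
  \;=\; \mathrm{ELBO}(x)
\quad\Longrightarrow\quad
-\log p_\theta(x) \;\le\; -\mathrm{ELBO}(x)
```

The reported numbers correspond to the right-hand side, so the true negative log-likelihood is at most the value listed.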
NOTE: Bits per dimension (bits/dim) can be calculated as (nll / 784 + log(256)) / log(2), where 784 is the MNIST dimension, log(256) accounts for dequantizing pixel values, and log(2) converts from natural log to base 2.
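The conversion in the note above can be sketched as a small helper. The function name bits_per_dim and its default arguments are illustrative and not part of the library; nll_nats is assumed to be the per-image negative log-likelihood in nats.

```python
import math

MNIST_DIMS = 28 * 28  # 784 pixels per image.


def bits_per_dim(nll_nats, n_dims=MNIST_DIMS, n_bins=256):
    """Converts a per-image negative log-likelihood in nats to bits/dim.

    Illustrative helper mirroring (nll / 784 + log(256)) / log(2).
    """
    # Average over dimensions, add the dequantization offset, convert to bits.
    return (nll_nats / n_dims + math.log(n_bins)) / math.log(2)


# A per-image nll of 0 nats leaves only the log2(256) dequantization offset.
print(bits_per_dim(0.0))
```

Dividing by log(2) last converts both the averaged nll and the dequantization term from nats to bits in one step.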
Algorithm | MNIST (bits/dim) | Links |
---|---|---|
NICE | 4.34 | Code, Paper |
Algorithm | Links |
---|---|
Mixture Models | Code, Wiki |
Kernel Density Estimators | Code, Wiki |
Neural Style Transfer | Code, Blog, Paper |
Compositional Pattern Producing Networks | Code, Wiki |