This is an implementation of 3DiM ("Novel View Synthesis with Diffusion Models") in JAX, with distributed training across multiple GPUs.
3DiM is a diffusion model for 3D novel view synthesis: given a single input view, it produces sharp and consistent completions across many target views. Its core component is a pose-conditional image-to-image diffusion model, which takes a source view and its pose as input and generates a novel view for a target pose as output. This is a basic implementation of the method with k=1 conditioning (a single conditioning view).
More details about the work can be found here.
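As a rough sketch of what that interface looks like, the denoiser sees the noisy target view, the diffusion timestep, and the single (k=1) clean source view together with both camera poses. All names below (`predict_noise`, `model.apply`, the keyword arguments) are illustrative assumptions, not the repository's actual API:

```python
import jax.numpy as jnp

def predict_noise(model, params, x_t, t, source_view, source_pose, target_pose):
    # x_t: noisy target view (H, W, 3); t: diffusion timestep.
    # source_view is the single clean conditioning frame (k = 1); the two
    # camera poses tell the model which viewpoint it is looking from and
    # which viewpoint it must synthesize.
    return model.apply(params, x_t, t,
                       cond=source_view,
                       poses=jnp.stack([source_pose, target_pose]))
```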
Training is done using JAX and Flax, and is automatically distributed across multiple devices when they are available. Since JAX does not have a built-in dataloader, we use a `torch.utils.data.Dataset`
for data operations. The dataloader has been adapted from Scene Representation Networks.
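A common pattern for pairing a torch `Dataset` with JAX is a collate function that yields NumPy arrays rather than torch tensors, roughly as sketched below; `dataset` stands in for the SRN-style dataset, and the batch size is only an example:

```python
import numpy as np
from torch.utils.data import DataLoader

def numpy_collate(batch):
    # Recursively stack samples into NumPy arrays so batches can be fed
    # straight to JAX without any torch tensors involved.
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        return [numpy_collate(samples) for samples in zip(*batch)]
    return np.array(batch)

loader = DataLoader(dataset, batch_size=8, shuffle=True,
                    collate_fn=numpy_collate)
```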
git clone https://github.com/shiveshkhaitan/novel_view_synthesis_3d
cd novel_view_synthesis_3d
The package supports Docker installation. To enable GPU support in Docker, see the installation guide here.
docker build -f Dockerfile . -t 3dim
To start training:
docker run -it --rm --memory '16g' --shm-size '16g' --gpus all \
--mount type=bind,source=$PWD,target=/home/3dim 3dim \
bash -c 'python3 train.py'
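Inside the container, JAX picks up every visible GPU automatically; the usual mechanism for this is `jax.pmap`, roughly as sketched here. The `train_step` body is a placeholder for the repository's actual update function, not its real signature:

```python
import jax
import jax.numpy as jnp

def train_step(params, batch):
    # Placeholder for the real update: compute a loss, take gradients, and
    # average them across devices over the named 'batch' axis.
    grads = jax.grad(lambda p: jnp.mean((batch - p) ** 2))(params)
    grads = jax.lax.pmean(grads, axis_name='batch')
    return params - 1e-3 * grads

# Replicate the step across all local devices.
p_train_step = jax.pmap(train_step, axis_name='batch')

def shard(x):
    # (num_devices * per_device, ...) -> (num_devices, per_device, ...)
    n = jax.local_device_count()
    return x.reshape((n, -1) + x.shape[1:])
```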
A smaller model with the following hyperparameters is available here:
ch: int = 32
ch_mult = (1, 2,)
emb_ch: int = 32
num_res_blocks: int = 2
attn_resolutions = (8, 16, 32)
attn_heads: int = 4
batch_size: int = 8
image_sidelength: int = 64
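For orientation, here is how those hyperparameters might be gathered into a single config object, with comments on what each typically controls in a diffusion UNet; the dataclass itself is illustrative, not the repository's actual structure:

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class SmallModelConfig:
    # Illustrative container; the repo may organize these differently.
    ch: int = 32                                     # base UNet channel count
    ch_mult: Tuple[int, ...] = (1, 2)                # channel multiplier per resolution level
    emb_ch: int = 32                                 # width of timestep/pose embeddings
    num_res_blocks: int = 2                          # residual blocks per resolution level
    attn_resolutions: Tuple[int, ...] = (8, 16, 32)  # feature-map sizes that get self-attention
    attn_heads: int = 4                              # heads per attention layer
    batch_size: int = 8
    image_sidelength: int = 64                       # images are 64x64
```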
Currently, this model can successfully denoise noisy inputs; however, it is not yet powerful enough to generate novel views during sampling.