
PyTorch implementation for "WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis" (DGM4MICCAI 2024)


pfriedri/wdm-3d


WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis

License: MIT

This is the official PyTorch implementation of the paper WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis by Paul Friedrich, Julia Wolleb, Florentin Bieder, Alicia Durrer and Philippe C. Cattin.

If you find our work useful, please consider starring ⭐ this repository and citing 📝 our paper:

@inproceedings{friedrich2024wdm,
               title={Wdm: 3d wavelet diffusion models for high-resolution medical image synthesis},
               author={Friedrich, Paul and Wolleb, Julia and Bieder, Florentin and Durrer, Alicia and Cattin, Philippe C},
               booktitle={MICCAI Workshop on Deep Generative Models},
               pages={11--21},
               year={2024},
               organization={Springer}}

Paper Abstract

Due to the three-dimensional nature of CT- or MR-scans, generative modeling of medical images is a particularly challenging task. Existing approaches mostly apply patch-wise, slice-wise, or cascaded generation techniques to fit the high-dimensional data into the limited GPU memory. However, these approaches may introduce artifacts and potentially restrict the model's applicability for certain downstream tasks. This work presents WDM, a wavelet-based medical image synthesis framework that applies a diffusion model on wavelet decomposed images. The presented approach is a simple yet effective way of scaling diffusion models to high resolutions and can be trained on a single 40 GB GPU. Experimental results on BraTS and LIDC-IDRI unconditional image generation at a resolution of 128 x 128 x 128 show state-of-the-art image fidelity (FID) and sample diversity (MS-SSIM) scores compared to GANs, Diffusion Models, and Latent Diffusion Models. Our proposed method is the only one capable of generating high-quality images at a resolution of 256 x 256 x 256.
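The central idea of the abstract, applying the diffusion model to wavelet-decomposed images, can be illustrated with a single-level 3D Haar transform. The NumPy sketch below is written for this README and is not the repository's own DWT code, which may use a different wavelet, normalisation, or subband order. It shows how a 128 x 128 x 128 volume becomes eight 64 x 64 x 64 subbands holding the same number of values, and that the transform is exactly invertible. The function names `haar_dwt3d` and `haar_idwt3d` are hypothetical.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)


def _every_other(a, axis, start):
    # Select elements start, start + 2, start + 4, ... along `axis`.
    index = [slice(None)] * a.ndim
    index[axis] = slice(start, None, 2)
    return a[tuple(index)]


def haar_dwt3d(volume):
    """Single-level orthonormal 3D Haar DWT: (D, H, W) -> (8, D/2, H/2, W/2).

    Subband 0 is the low-pass (LLL) approximation; subbands 1-7 hold the
    high-frequency details. All side lengths must be even.
    """
    volume = np.asarray(volume, dtype=np.float64)
    if volume.ndim != 3 or any(s % 2 for s in volume.shape):
        raise ValueError("expected a 3D volume with even side lengths")
    bands = [volume]
    for axis in range(3):
        split = []
        for b in bands:
            even, odd = _every_other(b, axis, 0), _every_other(b, axis, 1)
            # Orthonormal Haar filters: low-pass first, then high-pass.
            split += [(even + odd) / SQRT2, (even - odd) / SQRT2]
        bands = split
    return np.stack(bands)


def haar_idwt3d(coeffs):
    """Inverse of haar_dwt3d: (8, d, h, w) -> (2d, 2h, 2w)."""
    bands = list(np.asarray(coeffs, dtype=np.float64))
    for axis in reversed(range(3)):
        merged = []
        # Consecutive (low, high) pairs undo the split along `axis`.
        for low, high in zip(bands[0::2], bands[1::2]):
            even, odd = (low + high) / SQRT2, (low - high) / SQRT2
            shape = list(low.shape)
            shape[axis] *= 2
            # Interleave even/odd samples back along `axis`.
            merged.append(np.stack([even, odd], axis=axis + 1).reshape(shape))
        bands = merged
    return bands[0]


vol = np.random.default_rng(0).random((128, 128, 128))
coeffs = haar_dwt3d(vol)
print(coeffs.shape)  # (8, 64, 64, 64)
```

The total number of values is unchanged (8 x 64³ = 128³ = 2,097,152). If the eight subbands are stacked as channels, as in this sketch, a network sees a 64 x 64 x 64 input with eight channels instead of a single 128 x 128 x 128 channel.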

Dependencies

We recommend using a conda environment to install the required dependencies. The commands below use mamba, a faster drop-in replacement for conda, to create and activate an environment called wdm:

mamba env create -f environment.yml
mamba activate wdm

Training & Sampling

To train a new model or to sample from an already trained one, adapt and run the script run.sh. All hyperparameters needed to reproduce our results are set automatically once the correct MODEL is selected in the general settings. To execute the script, use the following command:

bash run.sh

Supported settings (set in the run.sh file):

MODE: 'training', 'sampling'

MODEL: 'ours_unet_128', 'ours_unet_256', 'ours_wnet_128', 'ours_wnet_256'

DATASET: 'brats', 'lidc-idri'
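Because only certain combinations of these three values are supported, it can help to validate them before launching a long run. The helper below is hypothetical and not part of the repository; it only encodes the values listed above. It also assumes, from the naming pattern alone, that MODEL names have the form ours_&lt;backbone&gt;_&lt;resolution&gt;.

```python
VALID_MODES = {"training", "sampling"}
VALID_MODELS = {"ours_unet_128", "ours_unet_256", "ours_wnet_128", "ours_wnet_256"}
VALID_DATASETS = {"brats", "lidc-idri"}


def parse_settings(mode, model, dataset):
    """Hypothetical check of run.sh-style settings against the supported values."""
    for name, value, allowed in (
        ("MODE", mode, VALID_MODES),
        ("MODEL", model, VALID_MODELS),
        ("DATASET", dataset, VALID_DATASETS),
    ):
        if value not in allowed:
            raise ValueError(f"unsupported {name}={value!r}; expected one of {sorted(allowed)}")
    # Assumption: MODEL names follow the pattern ours_<backbone>_<resolution>.
    _, backbone, size = model.split("_")
    return {
        "mode": mode,
        "model": model,
        "dataset": dataset,
        "backbone": backbone,
        "img_size": int(size),
    }


print(parse_settings("sampling", "ours_unet_128", "brats")["img_size"])  # 128
```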

Conditional Image Synthesis / Image-to-Image Translation

To use WDM for conditional image synthesis or paired image-to-image translation, check out our repository pfriedri/cwdm, which implements our paper cWDM: Conditional Wavelet Diffusion Models for Cross-Modality 3D Medical Image Synthesis.

Pretrained Models

We have released pretrained models on Hugging Face.

Currently available models:

  • BraTS 128: BraTS, 128 x 128 x 128, U-Net backbone, 1.2M iterations
  • LIDC-IDRI 128: LIDC-IDRI, 128 x 128 x 128, U-Net backbone, 1.2M iterations

Data

To ensure good reproducibility, we trained and evaluated our network on two publicly available datasets:

  • BRATS 2023: Adult Glioma, a dataset containing routine clinically-acquired, multi-site multiparametric magnetic resonance imaging (MRI) scans of brain tumor patients. We only used the T1-weighted images for training. The data is available here.

  • LIDC-IDRI, a dataset containing multi-site, thoracic computed tomography (CT) scans of lung cancer patients. The data is available here.

The provided code expects the following data structure (you might need to adapt the DATA_DIR variable in run.sh):

data
└───BRATS
    └───BraTS-GLI-00000-000
        └───BraTS-GLI-00000-000-seg.nii.gz
        └───BraTS-GLI-00000-000-t1c.nii.gz
        └───BraTS-GLI-00000-000-t1n.nii.gz
        └───BraTS-GLI-00000-000-t2f.nii.gz
        └───BraTS-GLI-00000-000-t2w.nii.gz
    └───BraTS-GLI-00001-000
    └───BraTS-GLI-00002-000
    ...
└───LIDC-IDRI
    └───LIDC-IDRI-0001
        └───preprocessed.nii.gz
    └───LIDC-IDRI-0002
    └───LIDC-IDRI-0003
    ...
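Before adapting DATA_DIR, a small script can confirm that a local copy matches the layout above by listing the volumes it would find. This is a hypothetical helper for illustration, not the repository's data loader. In particular, the `modality` argument is an assumption: the README says only that T1-weighted images were used, without naming which BraTS file suffix is loaded.

```python
from pathlib import Path


def find_volumes(data_root, dataset, modality=None):
    """Return sorted paths of the input volumes for one dataset.

    dataset:  "brats" or "lidc-idri" (the DATASET values supported by run.sh)
    modality: BraTS file suffix to select, e.g. "t1n" or "t1c" (hypothetical
              parameter; the suffix the repository actually loads is not stated)
    """
    root = Path(data_root)
    if dataset == "brats":
        if modality is None:
            raise ValueError("choose a BraTS modality suffix, e.g. 't1n'")
        # data/BRATS/<case>/<case>-<modality>.nii.gz
        pattern = f"BRATS/*/*-{modality}.nii.gz"
    elif dataset == "lidc-idri":
        # data/LIDC-IDRI/<case>/preprocessed.nii.gz
        pattern = "LIDC-IDRI/*/preprocessed.nii.gz"
    else:
        raise ValueError(f"unknown dataset {dataset!r}")
    return sorted(root.glob(pattern))
```

For example, `len(find_volumes("./data", "lidc-idri"))` should equal the number of preprocessed LIDC-IDRI cases; an empty list usually means DATA_DIR points at the wrong level of the tree.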

We provide a script for preprocessing LIDC-IDRI. Run the following command, replacing DICOM_PATH with the path to the downloaded DICOM files and NIFTI_PATH with the directory in which the processed NIfTI files should be stored:

python utils/preproc_lidc-idri.py --dicom_dir DICOM_PATH --nifti_dir NIFTI_PATH

Evaluation

Because our evaluation code has slightly different dependencies, we provide a second .yml file to set up a separate evaluation environment. Use the following commands to create and activate it:

mamba env create -f eval/eval_environment.yml
mamba activate eval

FID

For computing the FID score, you need to specify the following variables and use them in the command below:

  • DATASET: brats or lidc-idri
  • IMG_SIZE: 128 or 256
  • REAL_DATA_DIR: path to your real data
  • FAKE_DATA_DIR: path to your generated (fake) data
  • PATH_TO_FEATURE_EXTRACTOR: path to the feature extractor weights, e.g. ./eval/pretrained/resnet_50_23dataset.pt
  • PATH_TO_ACTIVATIONS: path to the location where you want to save the activation means (mu) and covariances (sigma) in case you want to reuse them, e.g. ./eval/activations/
  • GPU_ID: ID of the GPU you want to use, e.g. 0
python eval/fid.py --dataset DATASET --img_size IMG_SIZE --data_root_real REAL_DATA_DIR --data_root_fake FAKE_DATA_DIR --pretrain_path PATH_TO_FEATURE_EXTRACTOR --path_to_activations PATH_TO_ACTIVATIONS --gpu_id GPU_ID
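FID compares Gaussian fits (mean mu and covariance sigma) of feature-extractor activations for real and generated volumes, which is why these statistics can be saved to PATH_TO_ACTIVATIONS and reused. The NumPy sketch below computes the standard Fréchet distance d² = ‖mu1 − mu2‖² + Tr(S1 + S2 − 2 (S1 S2)^(1/2)) from such statistics. It is an illustration, not the code in eval/fid.py, which may differ, for example, in how the matrix square root is computed; the function names are hypothetical.

```python
import numpy as np


def _sqrtm_psd(s):
    # Square root of a symmetric positive semi-definite matrix via eigh.
    w, v = np.linalg.eigh(s)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T


def activation_stats(features):
    """Mean and covariance of an (n_samples, n_features) activation matrix."""
    features = np.asarray(features, dtype=np.float64)
    return features.mean(axis=0), np.cov(features, rowvar=False)


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Squared Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.asarray(mu1, dtype=np.float64), np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    diff = mu1 - mu2
    # Tr((S1 S2)^(1/2)) equals Tr((S1^(1/2) S2 S1^(1/2))^(1/2)), and the inner
    # matrix is symmetric PSD, so its eigenvalues are real and non-negative.
    s1_half = _sqrtm_psd(sigma1)
    inner = s1_half @ sigma2 @ s1_half
    inner = (inner + inner.T) / 2.0  # remove round-off asymmetry
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


# Two 2D Gaussians: identity vs. 4 * identity covariance, means 1 apart per axis.
print(frechet_distance(np.zeros(2), np.eye(2), np.ones(2), 4 * np.eye(2)))
```

In this example each axis contributes (mean difference)² + (std difference)² = 1 + 1, so the printed distance is 4 up to floating-point error.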

Mean MS-SSIM

For computing the mean MS-SSIM, you need to specify the following variables and use them in the command below:

  • DATASET: brats or lidc-idri
  • IMG_SIZE: 128 or 256
  • SAMPLE_DIR: path to the generated (or real) data
python eval/ms_ssim.py --dataset DATASET --img_size IMG_SIZE --sample_dir SAMPLE_DIR
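The README does not specify how samples are paired in eval/ms_ssim.py. Diversity scores of this kind are commonly reported as the average similarity over pairs of samples, where a lower mean indicates more diverse samples. The sketch below shows only that pairwise averaging step under this assumption. It deliberately uses a simplified single-scale SSIM computed from whole-volume statistics as a stand-in similarity; real MS-SSIM instead evaluates SSIM in local windows over several downsampled scales. Both function names are hypothetical.

```python
import itertools

import numpy as np


def global_ssim(x, y, data_range=1.0, k1=0.01, k2=0.03):
    """Single-scale SSIM from whole-volume statistics (simplified stand-in)."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    c1, c2 = (k1 * data_range) ** 2, (k2 * data_range) ** 2
    mx, my = x.mean(), y.mean()
    vx, vy = x.var(), y.var()
    cov = ((x - mx) * (y - my)).mean()
    return ((2 * mx * my + c1) * (2 * cov + c2)) / ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2))


def mean_pairwise_similarity(samples, similarity=global_ssim):
    """Average `similarity` over all unordered pairs of samples."""
    pairs = list(itertools.combinations(range(len(samples)), 2))
    if not pairs:
        raise ValueError("need at least two samples")
    return float(np.mean([similarity(samples[i], samples[j]) for i, j in pairs]))
```

Passing a proper 3D MS-SSIM implementation as `similarity` keeps the averaging logic unchanged; with N samples it evaluates N(N − 1)/2 pairs, so large sample sets may call for a random subset of pairs instead.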

Implementation Details for Comparing Methods

All experiments were performed on a system with an AMD Epyc 7742 CPU and an NVIDIA A100 (40GB) GPU.

TODOs

We plan to add further functionality to our framework:

  • Add compatibility for more datasets like MRNet, ADNI, or fastMRI
  • Release pre-trained models
  • Extend the framework for 3D image inpainting
  • Extend the framework for 3D image-to-image translation (pfriedri/cwdm)

Acknowledgements

Our code is based on / inspired by the following repositories:

For computing FID scores we use a pretrained model (resnet_50_23dataset.pth) from:

Thanks for making these projects open-source.
