Official PyTorch implementation of the Detail-oriented Capsule Network (DECAPS) proposed in the paper Radiologist-Level COVID-19 Detection Using CT Scans with Detail-Oriented Capsule Networks.
Fig. 1. Detail-oriented Capsule Network architecture for COVID-19 detection from CT scans.
- Python (3.5 preferred; should also work with Python 2.7)
- NumPy
- PyTorch>=1.1
- torchvision>=0.3
- TensorFlow>=1.10 (for visualizations with TensorBoard)
- Matplotlib (for saving images)
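The minimum versions listed above can be checked from Python before training. This is a minimal sketch, not part of the repository: numeric_parts, meets_minimum and check_requirements are hypothetical helpers that read each module's __version__ attribute and compare only the leading numeric components, so local suffixes such as +cu101 are ignored. It avoids f-strings so it also runs on Python 3.5.

```python
import importlib

# Minimum versions from the requirements list (import names, not pip names).
MINIMUM_VERSIONS = {
    "torch": "1.1",
    "torchvision": "0.3",
    "tensorflow": "1.10",
}


def numeric_parts(version):
    """Turn '1.10.0+cu101' into (1, 10, 0) and '0.3.0a0' into (0, 3, 0)."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(version, minimum):
    """Compare component-wise, padding the shorter tuple with zeros."""
    have, need = numeric_parts(version), numeric_parts(minimum)
    width = max(len(have), len(need))
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


def check_requirements(minimums=MINIMUM_VERSIONS):
    """Return {module name: status string} for each required package."""
    report = {}
    for name, minimum in minimums.items():
        try:
            module = importlib.import_module(name)
        except ImportError:
            report[name] = "missing"
            continue
        found = getattr(module, "__version__", "0")
        status = "ok" if meets_minimum(found, minimum) else "too old"
        report[name] = "{} ({} >= {} required)".format(status, found, minimum)
    return report


if __name__ == "__main__":
    for name, status in sorted(check_requirements().items()):
        print(name, status)
```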
The COVID-19 CT dataset is currently available here. As of April 17, 2020, it contains a total of 746 chest CT images, divided into two classes: COVID-19 and non-COVID-19. Instructions for preparing the data for the model will be uploaded soon.
Most of the network hyper-parameters are defined in the config.py file. You may modify them, or keep the default values, which correspond to the DECAPS model proposed in the paper.
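The README does not show the contents of config.py. One common pattern that fits the --bs=16 override used in the training commands is an argparse parser whose defaults can be overridden on the command line. This is a hypothetical sketch: only --bs and --load_model_path appear in this README; --lr and --epochs, and every default value, are placeholders and not the settings used in the paper.

```python
import argparse


def build_parser():
    """Hypothetical config parser; defaults are placeholders, not the
    DECAPS paper's values."""
    parser = argparse.ArgumentParser(description="DECAPS options (sketch)")
    # Flags that appear in the README's commands.
    parser.add_argument("--bs", type=int, default=32, help="batch size")
    parser.add_argument("--load_model_path", type=str, default="",
                        help="path to a saved .ckpt file (used for inference)")
    # Hypothetical extra hyper-parameters, for illustration only.
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="learning rate")
    parser.add_argument("--epochs", type=int, default=100,
                        help="number of training epochs")
    return parser


if __name__ == "__main__":
    # e.g. `python this_script.py --bs=16` overrides only the batch size.
    options = build_parser().parse_args()
    print(vars(options))
```

Keeping every default in one parser means a run is fully described by its command line, which is why a flag like --bs can change a single value while the rest stay at their defaults.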
Training displays the training results and saves the trained model whenever the accuracy improves.
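Saving only on improvement amounts to tracking the best accuracy seen so far. A minimal sketch, assuming accuracy is measured once per epoch and that only a strictly higher value counts as an improvement; the README does not say which accuracy (training or validation) is monitored. BestCheckpoint, save_fn and the file-name template are hypothetical names, not taken from the repository.

```python
class BestCheckpoint:
    """Call save_fn only when the monitored accuracy improves.

    save_fn is a hypothetical callable taking a file path; in a PyTorch loop
    it could be: lambda path: torch.save(model.state_dict(), path)
    """

    def __init__(self, save_fn, path_template="model_epoch{epoch}.ckpt"):
        self.save_fn = save_fn
        self.path_template = path_template
        self.best_accuracy = float("-inf")
        self.best_path = None

    def update(self, epoch, accuracy):
        """Save and return True if accuracy beats the best value so far."""
        if accuracy <= self.best_accuracy:
            return False
        self.best_accuracy = accuracy
        self.best_path = self.path_template.format(epoch=epoch)
        self.save_fn(self.best_path)
        return True


if __name__ == "__main__":
    saved = []  # stands in for files written to disk
    tracker = BestCheckpoint(saved.append)
    for epoch, acc in enumerate([0.70, 0.82, 0.80, 0.85]):
        tracker.update(epoch, acc)
    print(saved)  # ['model_epoch0.ckpt', 'model_epoch1.ckpt', 'model_epoch3.ckpt']
```

Passing the save action in as a callable keeps the bookkeeping independent of PyTorch, so the same logic works whether the checkpoint holds a state_dict or a larger dictionary.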
- For training with the default setting:
python train.py
- For training with a different batch size:
python train.py --bs=16
- For running the test:
python inference.py --load_model_path=/path/to/saved/model
where load_model_path is the path to the desired .ckpt file.
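The README does not document what the saved .ckpt file contains. The sketch below assumes it holds either a bare state_dict or a dictionary wrapping one under a key such as "state_dict" or "model_state_dict"; those key names are guesses, and extract_state_dict and load_for_inference are hypothetical helpers. torch.load and Module.load_state_dict are standard PyTorch calls; torch is imported inside the loader so the pure-Python helper can be used without it.

```python
def extract_state_dict(checkpoint, keys=("state_dict", "model_state_dict")):
    """Return the parameter dict from a loaded checkpoint object.

    Assumes the checkpoint is either a bare state_dict or a dict wrapping one
    under one of `keys` (hypothetical key names).
    """
    if isinstance(checkpoint, dict):
        for key in keys:
            if key in checkpoint and isinstance(checkpoint[key], dict):
                return checkpoint[key]
    return checkpoint


def load_for_inference(model, load_model_path):
    """Load weights from load_model_path into model and switch to eval mode."""
    import torch  # imported here so extract_state_dict works without PyTorch

    checkpoint = torch.load(load_model_path, map_location="cpu")
    state_dict = extract_state_dict(checkpoint)
    # Checkpoints saved from nn.DataParallel prefix every key with "module."
    state_dict = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout / use running batch-norm statistics
    return model
```

map_location="cpu" lets a checkpoint saved on a GPU be opened on a machine without one; the model can be moved to a device afterwards.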