This is the official PyTorch implementation of gCIS.
Towards A General CT Image Segmentation (gCIS) Model for Anatomical Structures and Lesions
by Xi Ouyang, Dongdong Gu, Xuejian Li, Wenqi Zhou, QianQian Chen, Yiqiang Zhan, Xiang Sean Zhou, Feng Shi, Zhong Xue, and Dinggang Shen
We propose a general CT image segmentation (gCIS) model capable of performing a wide range of segmentation tasks on computed tomography (CT) images.
This code requires PyTorch 1.9+ (it has not been well tested on PyTorch 2.0, so please use a PyTorch 1.9+ version) and Python 3.6+. If using Anaconda, you can create a new environment for gCIS:
conda create -n gcis python=3.8
conda activate gcis
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia # please refer to the pytorch official document
Install the other dependencies (e.g., monai, nibabel, open_clip_torch) with:
pip install -r requirements.txt
Installing these open-source toolkits can take from a few minutes to tens of minutes, depending on your Internet speed.
Due to commercial and privacy issues with our partners, we are working on releasing part of the in-house data with the consent of all parties involved. Until then, you can try our model on the public datasets used in our paper:
- Head and neck organ-at-risk CT & MR segmentation (HaN-Seg) Challenge
- The Cancer Imaging Archive, TCIA
- Medical Segmentation Decathlon, MSD
- Segmentation of Organs-at-Risk and Gross Tumor Volume of NPC for Radiotherapy Planning, SegRap
- Multi-Atlas Labeling Beyond the Cranial Vault, BTCV
- Liver Tumor Segmentation (LiTS) Challenge
- Kidney Tumor Segmentation (KiTS) Challenge
- Whole abdominal ORgan Dataset, WORD
- CT-ORG dataset
- Multi-Modality Abdominal Multi-Organ Segmentation (AMOS) Challenge
We use the MONAI framework in this repo, so just follow the data preparation of the MONAI segmentation examples. The most important step is to prepare the dataset JSON files, containing the "training" and "validation" sets. For each case in this JSON file, a "prompt" item is required along with the "image" and "label" keys. Here is the format of the JSON file:
{
    "training": [
        {
            "image": "/path/to/image.nii.gz",
            "label": "/path/to/mask.nii.gz",
            "prompt": "xxx"
        },
        ...
    ],
    "validation": [
        {
            "image": "/path/to/image.nii.gz",
            "label": "/path/to/mask.nii.gz",
            "prompt": "xxx"
        },
        ...
    ]
}
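For reference, here is a minimal sketch of how such a JSON file could be assembled in Python. The image/label paths and prompt strings below are placeholders, not files from this repository.

import json

# Placeholder entries; replace with your own image/label paths and task prompts.
cases = [
    {"image": "/data/case001_ct.nii.gz", "label": "/data/case001_liver.nii.gz", "prompt": "liver"},
    {"image": "/data/case002_ct.nii.gz", "label": "/data/case002_liver.nii.gz", "prompt": "liver"},
]

dataset = {
    "training": cases[:-1],   # all but the last case for training
    "validation": cases[-1:], # hold out the last case for validation
}

with open("dataset.json", "w") as f:
    json.dump(dataset, f, indent=4)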
First, download the weights of the CLIP model used in our code. We have uploaded the open-source weights to Google Drive or Baidu Yun (Code: 86uf) for convenient downloading. Please download all the files in this shared folder and find the following code in main.py:
prompt_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32', device=torch.device('cuda:0'), cache_dir="/path/to/clip_weights")
Then, modify the /path/to/clip_weights path to the location of the CLIP weights folder downloaded to your computer.
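For reference, a prompt string can be turned into a text embedding with the loaded CLIP model roughly as follows. This is only a sketch; the exact usage inside main.py may differ, and the prompt text is just an example.

import torch
import open_clip

device = torch.device('cuda:0')
prompt_model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32-quickgelu', pretrained='laion400m_e32', device=device,
    cache_dir="/path/to/clip_weights")
tokenizer = open_clip.get_tokenizer('ViT-B-32-quickgelu')

with torch.no_grad():
    tokens = tokenizer(["liver"]).to(device)          # example prompt
    text_features = prompt_model.encode_text(tokens)  # (1, 512) text embedding
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)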
Training the model with automatic pathway modules from scratch:
python main.py --json_list=<json-path> --feature_size=48 --roi_x=128 --roi_y=128 --roi_z=128 --ap_num=6 --use_gradcheckpoint --batch_size=<batch-size> --max_epochs=<max-epochs> --save_checkpoint --logdir=v1-pretrain --use_prompt
feature_size controls the model size and should be divisible by 12. In addition, use_gradcheckpoint enables gradient checkpointing for memory-efficient training. use_prompt must not be removed when training the model with automatic pathway modules. ap_num is the number of pathways in each automatic pathway module. save_checkpoint is used to save the checkpoints. logdir is the model saving folder, which will be automatically created in the runs folder; it stores the TensorBoard logs and the checkpoints. This command uses AMP training by default, and you can add noamp to disable it.
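To illustrate the idea behind an automatic pathway module (this is only a conceptual sketch, not the code used in this repository): each module holds ap_num parallel candidate paths, and a gating branch driven by the prompt embedding picks one of them.

import torch
import torch.nn as nn

class AutomaticPathwayModule(nn.Module):
    """Conceptual sketch: pick one of `ap_num` parallel conv paths from a prompt embedding."""
    def __init__(self, channels, ap_num=6, prompt_dim=512):
        super().__init__()
        self.paths = nn.ModuleList(
            [nn.Conv3d(channels, channels, kernel_size=3, padding=1) for _ in range(ap_num)])
        self.gate = nn.Linear(prompt_dim, ap_num)  # one score per candidate pathway

    def forward(self, x, prompt_embedding):
        scores = self.gate(prompt_embedding)   # (1, ap_num)
        index = int(scores.argmax(dim=-1))     # selected pathway index
        return self.paths[index](x), index

# Usage sketch:
# feats = torch.randn(1, 48, 32, 32, 32); emb = torch.randn(1, 512)
# out, idx = AutomaticPathwayModule(48)(feats, emb)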
Training the model with automatic pathway modules from a pretrained model:
python main.py --json_list=<json-path> --feature_size=48 --roi_x=128 --roi_y=128 --roi_z=128 --ap_num=6 --use_gradcheckpoint --batch_size=<batch-size> --max_epochs=<max-epochs> --save_checkpoint --logdir=v1-pretrain --use_prompt --pretrained=<model-path>
Due to commercial and privacy issues with our partners, the model pretrained on our large dataset cannot be included. We recommend using open-source pretrained weights instead. Thanks to the great work of SwinUNETR, you can download their self-supervised pretrained weights here (the feature_size of this pretrained model is 48). Alternatively, you can download their models pretrained on the BTCV dataset: swin_base, swin_small, swin_tiny (the feature_size of these pretrained models is 48, 24, and 12, respectively). Use pretrained to pass the path of the pretrained model. These weights provide a good initialization point for training the Swin backbone.
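As a rough illustration of how such pretrained weights can be loaded into the backbone. This is a sketch that assumes a MONAI SwinUNETR-style network and a placeholder checkpoint path; the actual pretrained handling lives in main.py, and constructor arguments may vary across MONAI versions.

import torch
from monai.networks.nets import SwinUNETR

# Placeholder network; main.py builds the actual gCIS model.
model = SwinUNETR(img_size=(128, 128, 128), in_channels=1, out_channels=2, feature_size=48)

checkpoint = torch.load("/path/to/swinunetr_pretrained.pt", map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)  # some checkpoints wrap the weights

# strict=False copies all matching Swin backbone weights and leaves
# non-matching, task-specific layers randomly initialized.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")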
After pretraining the model with multiple automatic pathway modules on a large dataset, you can further finetune it for better performance. First, for each sub-pathway, select only the data of the tasks belonging to that sub-pathway and update its parameters while the other parameters of the model are fixed. Then, use all the data to update the parameters except the automatic pathway modules, keeping the pathway selection fixed. Since this process is quite complicated, we recommend directly fixing the Swin encoder part and finetuning the model on each single task. You can also use this method to quickly transfer the pretrained model to a new task with limited data.
First, you should find the selected pathways corresponding to the target task. To do so, find the following code in test.py:
prompt_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32', device=torch.device('cuda:0'), cache_dir="/path/to/clip_weights")
Then, modify the /path/to/clip_weights path to the location of the CLIP weights folder on your computer.
python test.py --model=<pretrain-model-path> --json_list=<json-path> --feature_size=48 --roi_x=128 --roi_y=128 --roi_z=128 --ap_num=6 --use_prompt --display_path
model should be set to the path of the pretrained model .pt file from the first step. ap_num should match the training setting of the pretrained model. You can keep only one case in the "validation" set of the JSON file and set its "prompt" to the target task. You will then get the pathway index of each automatic pathway module in the terminal, like:
...
decoder10 layer (position 1) path index: 2
decoder5 layer (position 2) path index: 1
decoder4 layer (position 3) path index: 3
decoder3 layer (position 4) path index: 5
decoder2 layer (position 5) path index: 3
decoder1 layer (position 6) path index: 6
out layer (position 7) path index: 6
...
From this example, you can build a pathway index list: [2, 1, 3, 5, 3, 6, 6]. Then you can finetune this pretrained model by:
python main.py --json_list=<json-path> --feature_size=48 --roi_x=128 --roi_y=128 --roi_z=128 --batch_size=<batch-size> --max_epochs=<max-epochs> --save_checkpoint --logdir=v1-finetune --finetune --ap_model=<pretrain-model-path> --path_selection="2,1,3,5,3,6,6"
path_selection refers to the selected pathways and contains the indexes of all 7 automatic pathway modules, separated by ",". finetune is used to fix the parameters of the Swin backbone. ap_model should be set to the path of the pretrained model .pt file from the first step. Since the Swin backbone is fixed, do not add use_gradcheckpoint here, to avoid gradient updating problems.
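As a small convenience, the printed pathway indexes can be collected into the path_selection string with a helper along these lines. This is a sketch that assumes the terminal output format shown above has been saved to a log file (a hypothetical pathway_log.txt).

import re

# Example: python test.py ... > pathway_log.txt, then:
with open("pathway_log.txt") as f:
    log = f.read()

# Lines look like: "decoder5 layer (position 2) path index: 1"
indexes = re.findall(r"path index:\s*(\d+)", log)
path_selection = ",".join(indexes)
print(path_selection)  # e.g. "2,1,3,5,3,6,6"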
Due to commercial and privacy issues with our partners, the model finetuned on our large dataset cannot be included.
Here, you can use the following scripts to get the mask result corresponding to the prompt input. You should also find the following code in test.py:
prompt_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32', device=torch.device('cuda:0'), cache_dir="/path/to/clip_weights")
Then, modify the /path/to/clip_weights path to the location of the CLIP weights folder on your computer.
python test.py --model=<pretrain-model-path> --json_list=<json-path> --feature_size=48 --roi_x=128 --roi_y=128 --roi_z=128 --ap_num=6 --use_prompt
model should be set to the path of the model .pt file from the first pretraining step. ap_num should match the training setting of the pretrained model. You should set the "prompt" of each case to the target task in the JSON file.
The CLIP model is not needed in this test script.
python test.py --model=<pretrain-model-path> --json_list=<json-path> --feature_size=48 --roi_x=128 --roi_y=128 --roi_z=128
model should be set to the path of the model file from the second finetuning step. The inference time varies with image size and hardware. For a coronary CTA image (spacing: 0.42 mm * 0.42 mm * 0.42 mm, size: 512 * 512 * 425), the inference time on an NVIDIA A4000 card is less than one minute (using cropped patches of 96 * 96 * 96).
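For reference, patch-based inference in MONAI is typically done with sliding_window_inference, roughly as follows. This is only a sketch with dummy stand-ins; test.py prepares the real model and image, and its arguments and post-processing may differ.

import torch
from monai.inferers import sliding_window_inference

# Dummy stand-ins so the sketch runs; test.py builds the actual model and loads the CT volume.
model = torch.nn.Conv3d(1, 2, kernel_size=1)
image = torch.randn(1, 1, 128, 128, 128)  # (batch, channel, D, H, W)

with torch.no_grad():
    logits = sliding_window_inference(
        inputs=image,
        roi_size=(96, 96, 96),   # cropped patch size mentioned above
        sw_batch_size=4,
        predictor=model,
        overlap=0.5,
    )
    mask = torch.argmax(logits, dim=1)  # discrete mask; test.py may post-process differently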
Due to commercial and privacy issues with our partners, we instead release weights for several downstream tasks on Google Drive or Baidu Yun (Code: vid7) for convenient downloading. In each folder of the released files, "params.pth" is the model weight, which has been trimmed from the multi-task model for the specific task, so all the model files share the same encoder weights. We also release some test cases for each task for quick verification. The "xxx.csv" file contains the information of the original images and masks of the test cases.
We thank these great open-source projects: MONAI, Open CLIP, and SwinUNETR.