Skip to content

mpolinowski/deep-3d-image-segmentation

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 

Repository files navigation

jupyter
jupytext kernelspec
formats text_representation
ipynb,md
extension format_name format_version jupytext_version
.md
markdown
1.3
1.14.4
display_name language name
Python 3 (ipykernel)
python
python3

3D Image Classification

Learn how to train a 3D convolutional neural network (3D CNN) to predict presence of pneumonia - based on Tutorial on 3D Image Classification by Hasib Zunair.

Dataset: MosMedData: Chest CT Scans with COVID-19 Related Findings Dataset :: This dataset contains anonymised human lung computed tomography (CT) scans with COVID-19 related findings, as well as without such findings. A small subset of studies has been annotated with binary pixel masks depicting regions of interests (ground-glass opacifications and consolidations). CT scans were obtained between 1st of March, 2020 and 25th of April, 2020, and provided by municipal hospitals in Moscow, Russia.

Verify GPU Support

# importing tensorflow
import tensorflow as tf

device_name = tf.test.gpu_device_name()
print('Active GPU :: {}'.format(device_name))
# Active GPU :: /device:GPU:0

Import Dependencies

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os
import random
from scipy import ndimage
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model
# helper functions
from helper import (read_scan,
                    normalize,
                    resize_volume,
                    process_scan,
                    rotate,
                    train_preprocessing,
                    validation_preprocessing,
                    plot_slices,
                    build_model)

Import Dataset

# download from https://github.com/hasibzunair/3D-image-classification-tutorial/releases/
data_dir = './dataset'
no_pneumonia = os.path.join(data_dir, 'no_viral_pneumonia')
with_pneumonia = os.path.join(data_dir, 'with_viral_pneumonia')

normal_scan_paths = [
    os.path.join(no_pneumonia, i)
    for i in os.listdir(no_pneumonia)
]
print('INFO :: CT Scans with normal lung tissue:', len(normal_scan_paths))

abnormal_scan_paths = [
    os.path.join(with_pneumonia, i)
    for i in os.listdir(with_pneumonia)
]
print('INFO :: CT Scans with abnormal lung tissue:', len(abnormal_scan_paths))

# INFO :: CT Scans with normal lung tissue: 100
# INFO :: CT Scans with abnormal lung tissue: 100

Visualize Dataset

img_normal = nib.load(normal_scan_paths[0])
img_normal_array = img_normal.get_fdata()

img_abnormal = nib.load(abnormal_scan_paths[0])
img_abnormal_array = img_abnormal.get_fdata()

plt.figure(figsize=(30,10))

for i in range(6):
    plt.subplot(2, 6, i+1)
    plt.imshow(img_normal_array[:, :, i], cmap='Blues')
    plt.axis('off')
    plt.title('Slice {} - Normal'.format(i))
    
    plt.subplot(2, 6, 6+i+1)
    plt.imshow(img_abnormal_array[:, :, i], cmap='Reds')
    plt.axis('off')
    plt.title('Slice {} - Abnormal'.format(i))

3D Image Classification

Data Pre-processing

Normalization

# Read and process the scans.
# Each scan is resized across height, width, and depth and rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
# For the CT scans having presence of viral pneumonia
# assign 1, for the normal ones assign 0.
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])

Train Test Split

X = np.concatenate((abnormal_scans, normal_scans), axis=0)
Y = np.concatenate((abnormal_labels, normal_labels), axis=0)

x_train, x_val, y_train, y_val = train_test_split(X, Y, test_size=0.3, random_state=42)
print('INFO :: Train / Test Samples - %d / %d' % (x_train.shape[0], x_val.shape[0]))
# INFO :: Train / Test Samples - 140 / 60

Data Augmentation

Data Loader

# Define data loaders.
train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))
batch_size = 2

# Augment the on the fly during training.
train_dataset = (
    train_loader.shuffle(len(x_train))
    .map(train_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)
# Only rescale.
validation_dataset = (
    validation_loader.shuffle(len(x_val))
    .map(validation_preprocessing)
    .batch(batch_size)
    .prefetch(2)
)

Visualizing Augmented Datasets

data = train_dataset.take(1)
images, labels = list(data)[0]
images = images.numpy()
image = images[0]
print("CT Scan Dims:", image.shape)
# CT Scan Dims: (128, 128, 64, 1)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")

# Visualize montage of slices.
# 4 rows and 10 columns for 100 slices of the CT scan.
plot_slices(4, 10, 128, 128, image[:, :, :40])

3D Image Classification

3D Image Classification

Building the Model

model = build_model(width=128, height=128, depth=64)
model.summary()
Model: "3dctscan"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_3 (InputLayer)        [(None, 128, 128, 64, 1)  0         
                             ]                                   
                                                                 
 conv3d_9 (Conv3D)           (None, 126, 126, 62, 64)  1792      
                                                                 
 max_pooling3d_8 (MaxPooling  (None, 63, 63, 31, 64)   0         
 3D)                                                             
                                                                 
 batch_normalization_8 (Batc  (None, 63, 63, 31, 64)   256       
 hNormalization)                                                 
                                                                 
 conv3d_10 (Conv3D)          (None, 61, 61, 29, 64)    110656    
                                                                 
 max_pooling3d_9 (MaxPooling  (None, 30, 30, 14, 64)   0         
 3D)                                                             
                                                                 
 batch_normalization_9 (Batc  (None, 30, 30, 14, 64)   256       
 hNormalization)                                                 
                                                                 
 conv3d_11 (Conv3D)          (None, 28, 28, 12, 128)   221312    
                                                                 
 max_pooling3d_10 (MaxPoolin  (None, 14, 14, 6, 128)   0         
 g3D)                                                            
                                                                 
 batch_normalization_10 (Bat  (None, 14, 14, 6, 128)   512       
 chNormalization)                                                
                                                                 
 conv3d_12 (Conv3D)          (None, 12, 12, 4, 256)    884992    
                                                                 
 max_pooling3d_11 (MaxPoolin  (None, 6, 6, 2, 256)     0         
 g3D)                                                            
                                                                 
 batch_normalization_11 (Bat  (None, 6, 6, 2, 256)     1024      
 chNormalization)                                                
                                                                 
 global_average_pooling3d_1   (None, 256)              0         
 (GlobalAveragePooling3D)                                        
                                                                 
 dense_2 (Dense)             (None, 512)               131584    
                                                                 
 dropout_1 (Dropout)         (None, 512)               0         
                                                                 
 dense_3 (Dense)             (None, 1)                 513       
                                                                 
=================================================================
Total params: 1,352,897
Trainable params: 1,351,873
Non-trainable params: 1,024
_________________________________________________________________

Compile the Model

initial_learning_rate = 0.0001

lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True)

model.compile(
    loss='binary_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=['acc']
)

Callbacks

cp_cb = keras.callbacks.ModelCheckpoint(
    './checkpoints/3dct_weights.h5',
    save_best_only=True
)

es_cb = keras.callbacks.EarlyStopping(
    monitor='val_acc',
    patience=15
)

Model Training

epochs = 100

model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    shuffle=True,
    verbose=2,
    callbacks=[cp_cb, es_cb]
)
# Epoch 46/100
# 70/70 - 22s - loss: 0.3383 - acc: 0.8429 - val_loss: 0.8225 - val_acc: 0.6833 - 22s/epoch - 313ms/step

Visualizing Model Performance

fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()

for i, metric in enumerate(["acc", "loss"]):
    ax[i].plot(model.history.history[metric])
    ax[i].plot(model.history.history["val_" + metric])
    ax[i].set_title("Model {}".format(metric))
    ax[i].set_xlabel("epochs")
    ax[i].set_ylabel(metric)
    ax[i].legend(["train", "val"])

3D Image Classification

Loading Best Training Weights

model.load_weights('./checkpoints/3dct_weights.h5')

Make Predictions

pred_dataset = './predictions'
pred_paths = [os.path.join(pred_dataset, i) for i in os.listdir(pred_dataset)]

z_val = np.array([process_scan(path) for path in pred_paths])

for i in range(len(z_val)):
    prediction = model.predict(np.expand_dims(z_val[i], axis=0))[0]
    scores = [1 - prediction[0], prediction[0]]
    class_names = ['normal', 'abnormal']
    
pred_image = nib.load(pred_paths[i])
pred_image_data = pred_image.get_fdata()

normal_class = class_names[0], round(100*scores[0], 2)
abnormal_class = class_names[1], round(100*scores[1], 2)
annotation = normal_class + abnormal_class

plt.imshow(pred_image_data[:,:, pred_image_data.shape[2]//2], cmap='gray')
plt.title(os.path.basename(pred_paths[i]))
plt.xlabel(annotation)
plt.show()

3D Image Classification

print(scores, class_names)