jupyter | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
Learn how to train a 3D convolutional neural network (3D CNN) to predict presence of pneumonia - based on Tutorial on 3D Image Classification by Hasib Zunair.
Dataset: MosMedData: Chest CT Scans with COVID-19 Related Findings Dataset :: This dataset contains anonymised human lung computed tomography (CT) scans with COVID-19 related findings, as well as without such findings. A small subset of studies has been annotated with binary pixel masks depicting regions of interests (ground-glass opacifications and consolidations). CT scans were obtained between 1st of March, 2020 and 25th of April, 2020, and provided by municipal hospitals in Moscow, Russia.
# importing tensorflow
import tensorflow as tf
device_name = tf.test.gpu_device_name()
print('Active GPU :: {}'.format(device_name))
# Active GPU :: /device:GPU:0
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import os
import random
from scipy import ndimage
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model
# helper functions
from helper import (read_scan,
normalize,
resize_volume,
process_scan,
rotate,
train_preprocessing,
validation_preprocessing,
plot_slices,
build_model)
# download from https://github.com/hasibzunair/3D-image-classification-tutorial/releases/
data_dir = './dataset'
no_pneumonia = os.path.join(data_dir, 'no_viral_pneumonia')
with_pneumonia = os.path.join(data_dir, 'with_viral_pneumonia')
normal_scan_paths = [
os.path.join(no_pneumonia, i)
for i in os.listdir(no_pneumonia)
]
print('INFO :: CT Scans with normal lung tissue:', len(normal_scan_paths))
abnormal_scan_paths = [
os.path.join(with_pneumonia, i)
for i in os.listdir(with_pneumonia)
]
print('INFO :: CT Scans with abnormal lung tissue:', len(abnormal_scan_paths))
# INFO :: CT Scans with normal lung tissue: 100
# INFO :: CT Scans with abnormal lung tissue: 100
img_normal = nib.load(normal_scan_paths[0])
img_normal_array = img_normal.get_fdata()
img_abnormal = nib.load(abnormal_scan_paths[0])
img_abnormal_array = img_abnormal.get_fdata()
plt.figure(figsize=(30,10))
for i in range(6):
plt.subplot(2, 6, i+1)
plt.imshow(img_normal_array[:, :, i], cmap='Blues')
plt.axis('off')
plt.title('Slice {} - Normal'.format(i))
plt.subplot(2, 6, 6+i+1)
plt.imshow(img_abnormal_array[:, :, i], cmap='Reds')
plt.axis('off')
plt.title('Slice {} - Abnormal'.format(i))
# Read and process the scans.
# Each scan is resized across height, width, and depth and rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
# For the CT scans having presence of viral pneumonia
# assign 1, for the normal ones assign 0.
abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
normal_labels = np.array([0 for _ in range(len(normal_scans))])
X = np.concatenate((abnormal_scans, normal_scans), axis=0)
Y = np.concatenate((abnormal_labels, normal_labels), axis=0)
x_train, x_val, y_train, y_val = train_test_split(X, Y, test_size=0.3, random_state=42)
print('INFO :: Train / Test Samples - %d / %d' % (x_train.shape[0], x_val.shape[0]))
# INFO :: Train / Test Samples - 140 / 60
# Define data loaders.
train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))
batch_size = 2
# Augment the on the fly during training.
train_dataset = (
train_loader.shuffle(len(x_train))
.map(train_preprocessing)
.batch(batch_size)
.prefetch(2)
)
# Only rescale.
validation_dataset = (
validation_loader.shuffle(len(x_val))
.map(validation_preprocessing)
.batch(batch_size)
.prefetch(2)
)
data = train_dataset.take(1)
images, labels = list(data)[0]
images = images.numpy()
image = images[0]
print("CT Scan Dims:", image.shape)
# CT Scan Dims: (128, 128, 64, 1)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")
# Visualize montage of slices.
# 4 rows and 10 columns for 100 slices of the CT scan.
plot_slices(4, 10, 128, 128, image[:, :, :40])
model = build_model(width=128, height=128, depth=64)
model.summary()
Model: "3dctscan"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_3 (InputLayer) [(None, 128, 128, 64, 1) 0
]
conv3d_9 (Conv3D) (None, 126, 126, 62, 64) 1792
max_pooling3d_8 (MaxPooling (None, 63, 63, 31, 64) 0
3D)
batch_normalization_8 (Batc (None, 63, 63, 31, 64) 256
hNormalization)
conv3d_10 (Conv3D) (None, 61, 61, 29, 64) 110656
max_pooling3d_9 (MaxPooling (None, 30, 30, 14, 64) 0
3D)
batch_normalization_9 (Batc (None, 30, 30, 14, 64) 256
hNormalization)
conv3d_11 (Conv3D) (None, 28, 28, 12, 128) 221312
max_pooling3d_10 (MaxPoolin (None, 14, 14, 6, 128) 0
g3D)
batch_normalization_10 (Bat (None, 14, 14, 6, 128) 512
chNormalization)
conv3d_12 (Conv3D) (None, 12, 12, 4, 256) 884992
max_pooling3d_11 (MaxPoolin (None, 6, 6, 2, 256) 0
g3D)
batch_normalization_11 (Bat (None, 6, 6, 2, 256) 1024
chNormalization)
global_average_pooling3d_1 (None, 256) 0
(GlobalAveragePooling3D)
dense_2 (Dense) (None, 512) 131584
dropout_1 (Dropout) (None, 512) 0
dense_3 (Dense) (None, 1) 513
=================================================================
Total params: 1,352,897
Trainable params: 1,351,873
Non-trainable params: 1,024
_________________________________________________________________
initial_learning_rate = 0.0001
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate,
decay_steps=100000,
decay_rate=0.96,
staircase=True)
model.compile(
loss='binary_crossentropy',
optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
metrics=['acc']
)
cp_cb = keras.callbacks.ModelCheckpoint(
'./checkpoints/3dct_weights.h5',
save_best_only=True
)
es_cb = keras.callbacks.EarlyStopping(
monitor='val_acc',
patience=15
)
epochs = 100
model.fit(
train_dataset,
validation_data=validation_dataset,
epochs=epochs,
shuffle=True,
verbose=2,
callbacks=[cp_cb, es_cb]
)
# Epoch 46/100
# 70/70 - 22s - loss: 0.3383 - acc: 0.8429 - val_loss: 0.8225 - val_acc: 0.6833 - 22s/epoch - 313ms/step
fig, ax = plt.subplots(1, 2, figsize=(20, 3))
ax = ax.ravel()
for i, metric in enumerate(["acc", "loss"]):
ax[i].plot(model.history.history[metric])
ax[i].plot(model.history.history["val_" + metric])
ax[i].set_title("Model {}".format(metric))
ax[i].set_xlabel("epochs")
ax[i].set_ylabel(metric)
ax[i].legend(["train", "val"])
model.load_weights('./checkpoints/3dct_weights.h5')
pred_dataset = './predictions'
pred_paths = [os.path.join(pred_dataset, i) for i in os.listdir(pred_dataset)]
z_val = np.array([process_scan(path) for path in pred_paths])
for i in range(len(z_val)):
prediction = model.predict(np.expand_dims(z_val[i], axis=0))[0]
scores = [1 - prediction[0], prediction[0]]
class_names = ['normal', 'abnormal']
pred_image = nib.load(pred_paths[i])
pred_image_data = pred_image.get_fdata()
normal_class = class_names[0], round(100*scores[0], 2)
abnormal_class = class_names[1], round(100*scores[1], 2)
annotation = normal_class + abnormal_class
plt.imshow(pred_image_data[:,:, pred_image_data.shape[2]//2], cmap='gray')
plt.title(os.path.basename(pred_paths[i]))
plt.xlabel(annotation)
plt.show()
print(scores, class_names)