Training scripts for modifying, fine-tuning, and evaluating pretrained image classification models with PyTorch on a new dataset of interest. This example uses models pretrained on ImageNet (1000 general object classes) to make predictions on images in the Food-101 dataset (101 food-specific classes).
The Food-101 dataset consists of 101 food categories, with 101,000 images in total. The dataset is split into predefined train and test sets. Each category includes 750 training images and 250 test images.
For training, 20% of the training dataset is held out and used for validation. All evaluation is performed on the test dataset.
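The hold-out step can be sketched with index lists. This is a minimal illustration, not the repository's code: the function name `split_train_val`, the fixed seed, and the choice of a uniform (non-stratified) random split are all assumptions.

```python
import random


def split_train_val(num_samples, val_fraction=0.2, seed=0):
    """Return (train_indices, val_indices) for a random hold-out split.

    Hypothetical helper: the split strategy used by the actual training
    script is not stated in this README.
    """
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    num_val = int(round(num_samples * val_fraction))
    return indices[num_val:], indices[:num_val]


# Food-101 has 101 classes x 750 training images = 75,750 training images.
train_idx, val_idx = split_train_val(101 * 750)
print(len(train_idx), len(val_idx))  # 60600 15150
```

Index lists like these could be passed to `torch.utils.data.Subset` to build the training and validation datasets.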
PyTorch implementation of AlexNet from the paper "ImageNet Classification with Deep Convolutional Neural Networks". The network is pretrained on ImageNet, and its final fully connected layer is replaced with a 101-unit fully connected layer.
Implementations of models from the ResNet convolutional neural network paper. For this task, PyTorch implementations of the ResNet18 and ResNet50 models, pretrained on ImageNet, are used.
Skip connections add a block's original input to the output of its convolutional layers.
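A skip connection can be illustrated with a small residual block. This is a simplified sketch, not the torchvision source: the class name `BasicResidualBlock` is hypothetical, and the block keeps the channel count and spatial size fixed so the input can be added directly.

```python
import torch
from torch import nn


class BasicResidualBlock(nn.Module):
    """Two 3x3 convolutions with the block input added back before the final ReLU."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        identity = x  # keep the original input for the skip connection
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + identity  # skip connection: add input to the conv output
        return self.relu(out)


block = BasicResidualBlock(channels=64)
x = torch.randn(1, 64, 56, 56)
y = block(x)
print(y.shape)  # torch.Size([1, 64, 56, 56])
```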
The final fully connected layer of the pretrained ResNet models is replaced with a 101-unit fully connected layer.
See TensorBoard for full training experiment results. Training was limited to 10 epochs; results for each model would likely improve with additional epochs.
On the 25,250-image test set, the best overall accuracy was 77.9%, achieved by the ResNet50 model:
Accuracy varies by class and is shown below:
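Per-class accuracy is the fraction of test images of each class that were predicted correctly. The sketch below shows one way to compute it; the function name and the toy labels and predictions are made up for illustration and are not results from this project.

```python
from collections import Counter


def per_class_accuracy(labels, predictions):
    """Return {class: fraction of that class's samples predicted correctly}."""
    total = Counter(labels)
    correct = Counter(l for l, p in zip(labels, predictions) if l == p)
    return {c: correct[c] / total[c] for c in total}


# Toy data only (not real model output).
labels = ["pizza", "pizza", "sushi", "sushi", "sushi", "ramen"]
preds = ["pizza", "ramen", "sushi", "sushi", "pizza", "ramen"]

acc = per_class_accuracy(labels, preds)
overall = sum(l == p for l, p in zip(labels, preds)) / len(labels)
print(acc["pizza"], acc["ramen"])  # 0.5 1.0
```

Overall accuracy weights every image equally, so it can hide classes that perform much worse than the average.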
With conda installed, create and activate the environment with the following bash commands:
conda env create -f environment.yml
conda activate py310_torch
python train.py --model resnet50 --workers 8 --gpu True --epochs 1 --warm_start True
Optional parameters:
- --model: Specifies model to train:
  - alexnet: AlexNet, pretrained on ImageNet
  - resnet18: ResNet18, pretrained on ImageNet
  - resnet50: ResNet50, pretrained on ImageNet
  - Can specify any new model by adding to model.py
- --workers: Specifies number of workers for dataloaders
- --gpu:
  - True: Runs on CUDA or MPS
  - False: Runs on CPU
- --epochs: Number of training cycles through full dataset
- --warm_start:
  - True: Loads pretrained model if prior training was run
  - False: Trains new model
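For orientation, the flags above could be parsed along the following lines. This is a hypothetical sketch, not the contents of train.py: the helper `str2bool`, the function `build_parser`, and all default values are assumptions. Because `--gpu` and `--warm_start` take the literal strings True or False, they need an explicit string-to-bool conversion.

```python
import argparse


def str2bool(value):
    """Convert 'True'/'False' style strings to bool (hypothetical helper)."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


def build_parser():
    # Defaults below are placeholders, not the repository's actual defaults.
    parser = argparse.ArgumentParser(description="Fine-tune a pretrained classifier")
    parser.add_argument("--model", default="resnet50", help="model name defined in model.py")
    parser.add_argument("--workers", type=int, default=0, help="dataloader workers")
    parser.add_argument("--gpu", type=str2bool, default=False, help="use CUDA or MPS")
    parser.add_argument("--epochs", type=int, default=1, help="passes over the dataset")
    parser.add_argument("--warm_start", type=str2bool, default=False,
                        help="load a previously trained model if one exists")
    return parser


cli = "--model resnet50 --workers 8 --gpu True --epochs 1 --warm_start True"
args = build_parser().parse_args(cli.split())
print(args.model, args.workers, args.gpu, args.epochs, args.warm_start)
# resnet50 8 True 1 True
```

Using plain `type=bool` would not work here, since any non-empty string, including "False", converts to True.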