
Training script for fine-tuning pretrained image classification models with PyTorch


trevorwitter/PyTorch-ImageClassifier-FineTuning


Fine-tuning Image Classification Models with PyTorch

Training scripts for modifying, fine-tuning, and evaluating pretrained image classification models with PyTorch so that they classify a new dataset of interest. This example uses models pretrained on ImageNet (1,000 general object classes) to classify images from the Food-101 dataset (101 food-specific classes).


Dataset

Food-101 Dataset

The Food-101 dataset consists of 101 food categories, with 101,000 images in total. The dataset is split into predefined train and test sets: each category has 750 training images and 250 test images (75,750 and 25,250 images overall).

For training, 20% of the training set is held out for validation. All evaluation is performed on the test set.
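A minimal sketch of such a hold-out split, assuming a simple (non-stratified) random split of sample indices with a fixed seed. `split_indices` is a hypothetical helper, not code from this repository; the resulting index lists could be wrapped with `torch.utils.data.Subset` to build the training and validation sets.

```python
import random


def split_indices(num_samples, val_fraction=0.2, seed=42):
    """Shuffle sample indices and hold out `val_fraction` of them for validation.

    Hypothetical helper: returns (train_indices, val_indices).
    """
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    num_val = int(num_samples * val_fraction)
    return indices[num_val:], indices[:num_val]


# Food-101 training split: 101 classes x 750 images per class = 75,750 images.
num_train_images = 101 * 750
train_idx, val_idx = split_indices(num_train_images)
print(len(train_idx), len(val_idx))  # 60600 15150
```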

Model Architectures

AlexNet

PyTorch implementation of AlexNet from the "ImageNet Classification with Deep Convolutional Neural Networks" paper. The network is pretrained on ImageNet, and its final fully connected layer is replaced with a 101-unit fully connected layer.

[Figure: AlexNet]

ResNet18, ResNet50

Implementation of models from the ResNet paper, "Deep Residual Learning for Image Recognition". For this task, the PyTorch implementations of ResNet18 and ResNet50, pretrained on ImageNet, are used.

[Figure: ResNet50 Architecture]

Skip connections add the original input to the output of the convolutional block.

[Figure: Skip Connection]
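To make the skip connection concrete, here is a simplified residual block in the style of ResNet18's basic block (ResNet50 instead uses a three-layer 1x1/3x3/1x1 bottleneck block). This sketch assumes the input and output shapes match, so the identity can be added directly; torchvision's own implementation also handles downsampling shortcuts. `BasicResidualBlock` is a hypothetical name.

```python
import torch
import torch.nn as nn


class BasicResidualBlock(nn.Module):
    """Simplified ResNet-style block computing ReLU(F(x) + x).

    F is two 3x3 convolutions with batch norm. The identity shortcut requires
    matching input/output shapes, so there is no downsampling here.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x  # the skip connection carries the original input forward
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)  # add the input to the block's output


block = BasicResidualBlock(64)
x = torch.randn(2, 64, 56, 56)
print(block(x).shape)  # torch.Size([2, 64, 56, 56])
```

If the convolutional path F learns to output zeros, the block reduces to ReLU(x), which is one intuition for why very deep residual networks remain trainable.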

The final fully connected layer of the pretrained ResNet models is replaced with a 101-unit fully connected layer.

Results

See TensorBoard for full training experiment results. Training was limited to 10 epochs; results for each model would likely improve with additional epochs.

On the 25,250-image test set, the best overall accuracy was 77.9%, achieved by the ResNet50 model:

[Figure: Overall Accuracy]

Accuracy varies by class and is shown below:

[Figure: Test Accuracy by Class]
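Overall and per-class accuracy can both be computed from paired true and predicted labels. A minimal pure-Python sketch; `accuracy_report` is a hypothetical helper, and the toy labels below are made up, not results from this repository.

```python
from collections import Counter


def accuracy_report(labels, predictions):
    """Return (overall accuracy, {class: accuracy}) for paired label sequences."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    total = Counter(labels)  # number of test images per class
    correct = Counter(y for y, p in zip(labels, predictions) if y == p)
    overall = sum(correct.values()) / len(labels)
    per_class = {c: correct[c] / n for c, n in total.items()}
    return overall, per_class


labels      = ["pizza", "pizza", "sushi", "sushi", "ramen"]
predictions = ["pizza", "sushi", "sushi", "sushi", "pizza"]
overall, per_class = accuracy_report(labels, predictions)
print(overall)    # 0.6
print(per_class)  # {'pizza': 0.5, 'sushi': 1.0, 'ramen': 0.0}
```

Because the Food-101 test set has exactly 250 images per class, the overall accuracy equals the unweighted mean of the 101 per-class accuracies.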

Quickstart

With conda installed, create and activate the environment with the following commands:

conda env create -f environment.yml
conda activate py310_torch

Training

python train.py --model resnet50 --workers 8 --gpu True --epochs 1 --warm_start True

Optional parameters:

  • --model
    • Specifies model to train:
      • alexnet: AlexNet, pretrained on ImageNet
      • resnet18: ResNet18, pretrained on ImageNet
      • resnet50: ResNet50, pretrained on ImageNet
      • Other models can be added by defining them in model.py
  • --workers: Number of worker processes for the data loaders
  • --gpu:
    • True: Runs on CUDA or MPS
    • False: Runs on CPU
  • --epochs: Number of training passes through the full training set
  • --warm_start:
    • True: Loads the previously trained model, if a prior training run exists
    • False: Trains a new model
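These flags take `True`/`False` strings. The following is a hedged sketch of how such flags could be parsed and mapped to a device; it is not the repository's train.py, and the defaults are hypothetical. Note that a plain `type=bool` in argparse would be a bug here, since `bool("False")` is `True`. The CUDA-before-MPS preference in `select_device` is also an assumption.

```python
import argparse

import torch


def str2bool(value):
    """Parse 'True'/'False'-style strings into real booleans."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Fine-tune a pretrained image classifier")
    # Hypothetical defaults; the real script's defaults may differ.
    parser.add_argument("--model", default="resnet50",
                        choices=["alexnet", "resnet18", "resnet50"])
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--gpu", type=str2bool, default=True)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--warm_start", type=str2bool, default=False)
    return parser.parse_args(argv)


def select_device(use_gpu):
    """Prefer CUDA, then Apple MPS, and fall back to CPU."""
    if use_gpu and torch.cuda.is_available():
        return torch.device("cuda")
    if use_gpu and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


args = parse_args(["--model", "resnet50", "--workers", "8", "--gpu", "True",
                   "--epochs", "1", "--warm_start", "True"])
print(args)
print(select_device(args.gpu))
```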
