
A capsule network (from the paper "Dynamic Routing Between Capsules") trained from scratch in PyTorch for handwritten digit classification.


HoangPham3003/Capsule-Network-in-Pytorch


Capsule Network in PyTorch

A PyTorch implementation of CapsNet based on the paper "Dynamic Routing Between Capsules" by Sara Sabour, Nicholas Frosst, and Geoffrey Hinton.

Requirements

  • Python 3
  • PyTorch
  • torchvision
  • Pillow
  • opencv-python
  • tqdm
  • Matplotlib (for displaying sample inferences after training)

Hyperparameters

The hyperparameters I used may differ slightly from those used in other implementations.

NUM_EPOCHS = 100
BATCH_SIZE = 128
- In the Encoder:
    LEARNING_RATE = 0.0003
    WEIGHT_DECAY = 0.0001
- In the Decoder:
    LEARNING_RATE = 0.001
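The hyperparameters above use separate learning rates for the encoder and the decoder, with weight decay applied only to the encoder. One way to express this is two optimizers stepped together after a single backward pass. This is a sketch: the choice of Adam is an assumption (only the rates are given here), and `encoder`/`decoder` are hypothetical stand-in modules, not the ones defined in CapsNet.py.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the encoder (conv + capsule layers) and the
# reconstruction decoder; the real modules in CapsNet.py differ.
encoder = nn.Sequential(nn.Conv2d(1, 256, kernel_size=9), nn.ReLU())
decoder = nn.Sequential(
    nn.Linear(16 * 10, 512), nn.ReLU(),
    nn.Linear(512, 1024), nn.ReLU(),
    nn.Linear(1024, 784), nn.Sigmoid(),
)

NUM_EPOCHS = 100
BATCH_SIZE = 128

# Assumption: Adam for both parts. Weight decay only on the encoder.
encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.0003, weight_decay=0.0001)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=0.001)


def train_step(loss: torch.Tensor) -> None:
    """One update: a single backward pass, then both optimizers step."""
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
```

A single optimizer with two parameter groups (one with `lr=0.0003, weight_decay=0.0001`, one with `lr=0.001`) would behave the same way; two optimizers simply mirror how the hyperparameters are listed.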

Model architecture

The model architecture follows the paper strictly and is defined in CapsNet.py.
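The two components that set CapsNet apart from a standard CNN are the squash nonlinearity and the routing-by-agreement procedure between the primary capsules and the digit capsules. The sketch below writes both from the paper's equations; the function names and the tensor layout of `u_hat` are assumptions and may not match CapsNet.py.

```python
import torch
import torch.nn.functional as F


def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    """Scale each vector's length into [0, 1) while keeping its direction."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm)
    return scale * s / torch.sqrt(sq_norm + eps)


def dynamic_routing(u_hat: torch.Tensor, num_iterations: int = 3) -> torch.Tensor:
    """Routing by agreement.

    u_hat: prediction vectors, shape (batch, num_in, num_out, out_dim),
           where u_hat[b, i, j] = W_ij @ u_i for input capsule i, output capsule j.
    Returns the output capsules v, shape (batch, num_out, out_dim).
    """
    batch, num_in, num_out, _ = u_hat.shape
    # Routing logits b_ij start at zero, so coupling is uniform at first.
    b = torch.zeros(batch, num_in, num_out, device=u_hat.device, dtype=u_hat.dtype)
    for it in range(num_iterations):
        # Coupling coefficients: each input capsule distributes over the outputs.
        c = F.softmax(b, dim=2)
        # Weighted sum of predictions for each output capsule.
        s = (c.unsqueeze(-1) * u_hat).sum(dim=1)
        v = squash(s)
        if it < num_iterations - 1:
            # Agreement: dot product between each prediction and the output.
            b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)
    return v
```

With the paper's MNIST setup, `u_hat` would have shape `(batch, 1152, 10, 16)`: 32 × 6 × 6 primary capsules predicting 10 digit capsules of dimension 16, with 3 routing iterations.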

Training Procedure

The training procedure can be followed in MNIST_CapsuleNetwork.ipynb.

Benchmarks

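The highest accuracy was 99.3%, reached at epoch 85 of 100. At that epoch the final loss was 0.0151, combining a margin loss of 0.0094 and a reconstruction loss of 11.4979 (the reconstruction loss is scaled by 0.0005 before being added, as in the paper: 0.0094 + 0.0005 × 11.4979 ≈ 0.0151). The best model was saved and can be found here. The loss and accuracy curves are shown below.

[Figure: loss and accuracy curves over training (result_plot)]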
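The two reported loss terms can be written from the paper's definitions: a margin loss on the digit-capsule lengths (`m_pos = 0.9`, `m_neg = 0.1`, `lam = 0.5`) and a sum-of-squared-errors reconstruction loss weighted by 0.0005. This is a sketch; the function names are hypothetical, and whether the repository averages over the batch exactly like this is an assumption.

```python
import torch
import torch.nn.functional as F

RECON_WEIGHT = 0.0005  # scaling of the reconstruction loss from the paper


def margin_loss(lengths, targets, m_pos=0.9, m_neg=0.1, lam=0.5):
    """lengths: (batch, 10) digit-capsule lengths; targets: (batch,) class indices."""
    t = F.one_hot(targets, num_classes=lengths.size(1)).to(lengths.dtype)
    # Penalize a short capsule for the true digit...
    present = t * F.relu(m_pos - lengths) ** 2
    # ...and (down-weighted by lam) long capsules for the other digits.
    absent = lam * (1.0 - t) * F.relu(lengths - m_neg) ** 2
    return (present + absent).sum(dim=1).mean()


def reconstruction_loss(reconstructions, images):
    """Sum of squared pixel errors per image, averaged over the batch."""
    flat = images.view(images.size(0), -1)
    return ((reconstructions - flat) ** 2).sum(dim=1).mean()


def total_loss(lengths, targets, reconstructions, images):
    return margin_loss(lengths, targets) + RECON_WEIGHT * reconstruction_loss(reconstructions, images)
```

The small weight keeps the 784-pixel reconstruction error from dominating the margin loss, which is why the final loss stays close to the margin loss.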

Inferences

A simple Tkinter app was created for running inference. Download and run run.py to try it.

[Figure: example inferences in the app (result_infer)]
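At inference time a CapsNet predicts the digit whose output capsule is longest, since each capsule's length is read as the probability that the digit is present. To show a reconstruction, the other capsules are masked out before the vector is fed to the decoder. The sketch below assumes digit capsules of shape `(batch, 10, 16)`; the function names are hypothetical and not taken from run.py.

```python
import torch
import torch.nn.functional as F


def predict_digit(digit_caps: torch.Tensor) -> torch.Tensor:
    """digit_caps: (batch, 10, 16). Returns the index of the longest capsule."""
    lengths = torch.linalg.vector_norm(digit_caps, dim=-1)  # (batch, 10)
    return lengths.argmax(dim=1)


def mask_for_decoder(digit_caps: torch.Tensor, predicted: torch.Tensor) -> torch.Tensor:
    """Zero every capsule except the predicted one and flatten to (batch, 160)."""
    mask = F.one_hot(predicted, num_classes=digit_caps.size(1)).to(digit_caps.dtype)
    return (digit_caps * mask.unsqueeze(-1)).flatten(start_dim=1)
```

During training the paper masks with the true label instead; using the predicted label is the natural choice when no label is available, as in an interactive app.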

References

Many thanks to the following authors for their CapsNet guides and implementations:

  1. TensorFlow implementation by @naturomics
  2. PyTorch implementation by @gram-ai
  3. Detailed guide in TensorFlow by @ageron
