This project demonstrates image classification with a Convolutional Neural Network (CNN) implemented in PyTorch. The CNN is trained to classify images of fruit into categories, and the torchvision library is used to preprocess the images and build the training dataset.
The project includes the following components:
- Data preprocessing and loading using `torchvision.datasets.ImageFolder`.
- Definition of a CNN model with PyTorch's `nn.Module`.
- Training the model using a custom training loop.
- Visualization of predicted class labels on test images.
Below is an example output showing predicted class labels for test images:
To run the project:
- Install the required libraries listed in `requirements.txt`.
- Organize your dataset in the `fruits` directory.
- Run the Python script `train_model.py` to train the CNN model.
- Use the trained model to predict class labels with `predict_label.py`.
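The training and prediction steps can be sketched as follows. This is not the contents of `train_model.py` or `predict_label.py`; it is a minimal, generic version that works with any `nn.Module` classifier and a `DataLoader` yielding `(images, labels)` batches. The function names `train` and `predict_class`, the choice of Adam with cross-entropy loss, and the default hyperparameters are all assumptions.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def train(model, loader, epochs=5, lr=1e-3, device="cpu"):
    """Minimal supervised training loop; returns the mean loss of each epoch."""
    model.to(device)
    criterion = nn.CrossEntropyLoss()  # expects raw logits and integer labels
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()  # enable training-mode behaviour (e.g. dropout, batch norm)
        running_loss, seen = 0.0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is exact for a short last batch.
            running_loss += loss.item() * images.size(0)
            seen += images.size(0)
        history.append(running_loss / seen)
        print(f"epoch {epoch + 1}: loss {history[-1]:.4f}")
    return history


@torch.no_grad()
def predict_class(model, image, class_names, device="cpu"):
    """Return the predicted class name for one preprocessed image tensor (C, H, W)."""
    model.to(device)
    model.eval()
    logits = model(image.unsqueeze(0).to(device))  # add a batch dimension
    return class_names[logits.argmax(dim=1).item()]
```

Between the two steps, the trained weights would typically be saved with `torch.save(model.state_dict(), "fruit_cnn.pth")` and restored with `model.load_state_dict(torch.load("fruit_cnn.pth"))`; the file name here is hypothetical. The image passed to `predict_class` must go through the same preprocessing used for training.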
Requirements:
- Python 3.x
- PyTorch
- torchvision
- matplotlib
This project is licensed under the Apache 2.0 License. See LICENSE for details.