Classifying Images from MNIST Databases using a CNN
Have you ever wondered how a computer can recognize images? I trained an AI model which can recognize handwritten numbers and even fashion items!
Curious to see what those images look like?
The images above are just a small sample of the 70,000 images in each dataset.
Convolutional Neural Network (CNN)
A CNN is a neural network that specializes in processing grid-like data, such as images. An image can be thought of as a grid in which each pixel is a numerical value. For example, each pixel in the images above can be represented as a number between 0 (black) and 1 (white), indicating how bright it is. CNNs are particularly good at classifying images because they can learn distinctive features and shapes. The architecture of a CNN is inspired by the visual cortex of the brain and is structured so that it can identify various features in images, such as edges, textures, and patterns.
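To make the "grid of numbers" idea concrete, here is a minimal pure-Python sketch. The 3x3 grid of 8-bit brightness values (0 = black, 255 = white) is made up for illustration. Dividing each value by 255 gives the 0-to-1 range described above, which is the same scaling torchvision's ToTensor transform applies to 8-bit grayscale images.

```python
# A tiny, made-up 3x3 grayscale "image": 0 is black, 255 is white.
# (MNIST images are 28x28 grids of values like these.)
pixels = [
    [0, 128, 255],
    [64, 255, 0],
    [0, 32, 0],
]

# Scale every 8-bit value into the range 0.0 to 1.0
normalized = [[value / 255 for value in row] for row in pixels]

for row in normalized:
    print(" ".join(f"{value:.2f}" for value in row))
```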
CNNs have a variety of applications and influence our daily lives in many different ways. Here are a few examples:
- Self-driving cars. Tesla is training an AI model on its custom-built supercomputer, using over 5 billion kilometres of driving data, to drive cars autonomously.
- Healthcare. CNNs help healthcare professionals detect and diagnose anomalies such as tumours, and help them prioritize the patients who need care most urgently.
- Quality of life. Google Photos uses CNNs to recognize and classify faces in your photo gallery, which makes it easy to find photos of particular people.
Building the model
To build this model, I followed a tutorial from NeuralNine on building a CNN using the Python programming language and the PyTorch library.
Python is a programming language that is widely used for machine learning. It is praised for its readability and for being easy to learn and use. Although Python itself is much slower than languages such as C and C++, it has a thriving ecosystem of libraries written in those faster languages that can be used from Python.
PyTorch is an open-source machine learning framework developed by Facebook’s AI team that is widely used for applications in deep learning, such as computer vision.
1. Install libraries
```shell
pip install torch torchvision matplotlib
```
Before writing any code, I had to install the libraries I would use:
- torch is the PyTorch library, which provides the framework for building a neural network model.
- torchvision has the datasets of handwritten digits and fashion items.
- matplotlib is a plotting library that helps display images and data.
2. Set up the data
```python
from torchvision import datasets
from torchvision.transforms import ToTensor

train_data = datasets.MNIST(
    root="data",
    train=True,
    transform=ToTensor(),
    download=True,
)
test_data = datasets.MNIST(
    root="data",
    train=False,
    transform=ToTensor(),
    download=True,
)
```
Here, I import the datasets and assign them to variables. The 70,000 images are split into training data (60,000 images) and testing data (10,000 images).
The CNN is trained on the training data and then tested on the testing data, which it has never seen. This checks that the CNN can classify new images, not just the ones it was trained on. If the same images were used for both training and testing, the model could score highly simply by memorizing them (overfitting), and we would have no way of knowing how well it classifies other images!
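The reasoning above can be illustrated with a deliberately silly "model". This pure-Python sketch is hypothetical and not part of the CNN: the memorize function just stores every training example it sees, and the data are made-up strings standing in for images. It scores perfectly on the data it has seen but can only guess on anything new, which is exactly the kind of failure a held-out test set is there to reveal.

```python
from collections import Counter


def memorize(train_pairs):
    """Hypothetical 'model' that just stores every (input, label) pair it sees."""
    table = dict(train_pairs)
    # For unseen inputs, fall back to the most common training label
    fallback = Counter(label for _, label in train_pairs).most_common(1)[0][0]
    return lambda x: table.get(x, fallback)


def accuracy(model, pairs):
    """Fraction of (input, label) pairs the model gets right."""
    correct = sum(model(x) == label for x, label in pairs)
    return correct / len(pairs)


# Made-up data: each "image" is just a string, each label a digit
train_pairs = [("img_a", 1), ("img_b", 7), ("img_c", 1), ("img_d", 3)]
test_pairs = [("img_e", 7), ("img_f", 3), ("img_g", 1)]

memorizer = memorize(train_pairs)
print(f"Training accuracy: {accuracy(memorizer, train_pairs):.0%}")
print(f"Test accuracy:     {accuracy(memorizer, test_pairs):.0%}")
```

Measured only on its training data, this "model" would look perfect; the test set exposes that it has learned nothing general.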
To use the fashion items dataset, use datasets.FashionMNIST(...) instead.
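The FashionMNIST labels are also the numbers 0 to 9, but each one stands for a clothing category. torchvision exposes the names as datasets.FashionMNIST.classes; the hard-coded list below mirrors that attribute so this sketch runs without downloading anything, and the label_name helper is hypothetical, just for illustration.

```python
# Class names in label order, mirroring torchvision's datasets.FashionMNIST.classes
FASHION_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_name(label: int) -> str:
    """Hypothetical helper: turn a label (0-9) into a readable class name."""
    return FASHION_CLASSES[label]


print(label_name(0))
print(label_name(9))
```

This is handy when displaying predictions, for example to title a plotted image with label_name(prediction) instead of a bare number.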
3. Load the data into batches
```python
from torch.utils.data import DataLoader

loaders = {
    "train": DataLoader(
        train_data,
        batch_size=100,
        shuffle=True,
        num_workers=1,
    ),
    "test": DataLoader(
        test_data,
        batch_size=100,
        shuffle=True,
        num_workers=1,
    ),
}
```
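This code loads the data in shuffled batches of 100 images. Instead of updating its weights only once after looking at all 60,000 training images, the CNN updates them after every batch, so it gets many chances to improve during each full pass over the data (an epoch). Shuffling changes which images end up together in each batch, so the model does not learn anything from the order of the data.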
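Here is a minimal pure-Python sketch of what shuffled batching does, using plain integers as stand-ins for images. The make_batches helper is hypothetical (DataLoader does this, and more, for you). With 60,000 training images and a batch size of 100, one epoch produces 600 batches, so the optimizer updates the weights 600 times per epoch.

```python
import random


def make_batches(items, batch_size, shuffle=True, seed=None):
    """Hypothetical helper: split items into batches, optionally shuffling first."""
    items = list(items)  # copy so the caller's data is not reordered
    if shuffle:
        random.Random(seed).shuffle(items)
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


# Stand-ins for the 60,000 training images and 10,000 test images
train_batches = make_batches(range(60_000), batch_size=100, seed=0)
test_batches = make_batches(range(10_000), batch_size=100, seed=0)

print(len(train_batches), "training batches per epoch")
print(len(test_batches), "test batches")
print("First batch starts with:", train_batches[0][:5])
```

The fixed seed here only makes the sketch repeatable; a DataLoader with shuffle=True draws a new order at the start of every epoch.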
4. The actual model
```python
from torch import nn, optim, Tensor
from torch.nn import functional as F


class CNN(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # 20 * 4 * 4 = 320
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.softmax(x, dim=1)
```
Here is where the model is actually created.
- The necessary libraries are imported.
- I define a class called CNN which inherits from nn.Module, PyTorch's base class for neural network models. The class is a blueprint of what the model should do.
- When the model is created, the __init__ method is called, which sets up the model's layers:
  - self.conv1: A 2D convolutional layer that takes a single-channel (grayscale) input image and produces 10 feature maps, using a 5x5 kernel.
  - self.conv2: Another 2D convolutional layer that takes the 10 feature maps from the previous layer and produces 20 feature maps, using a 5x5 kernel.
  - self.conv2_drop: A 2D dropout layer that, during training, randomly zeroes out entire feature maps (with a probability of 0.5 by default) to reduce overfitting.
  - self.fc1: A fully connected layer that takes the flattened feature maps (320 values) and produces 50 outputs.
  - self.fc2: Another fully connected layer that takes the 50 outputs from the previous layer and produces 10 outputs, one for each class.
- The forward() method defines how an input image is turned into the output, which is a probability for each class. This is where the computation happens, with the data flowing forward through the layers in order:
  - x = F.relu(F.max_pool2d(self.conv1(x), 2)): The input is passed through the first convolutional layer, then a 2x2 max pooling operation halves the height and width of the feature maps, and then a rectified linear unit (ReLU) activation function is applied.
  - x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)): The same steps are repeated for the second convolutional layer, except that dropout is applied to its output before the pooling and ReLU.
  - x = x.view(-1, 320): The feature maps are flattened into a one-dimensional vector of 320 values per image (20 feature maps of 4x4 pixels), which is the input for the fully connected layers.
  - x = F.relu(self.fc1(x)): The vector is passed through the first fully connected layer, then ReLU is applied.
  - x = F.dropout(x, training=self.training): Dropout is applied to the output of the first fully connected layer. Because training=self.training, it only randomly zeroes values (with a probability of 0.5 by default) while the model is in training mode; in evaluation mode, the values pass through unchanged.
  - x = self.fc2(x): The result is passed through the second fully connected layer, which produces 10 outputs.
  - return F.softmax(x, dim=1): A softmax function is applied to the 10 outputs, which converts them into probabilities that sum to one.
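The number 320 in fc1 and in x.view(-1, 320) comes from tracking the image size through each layer. This pure-Python sketch applies the standard output-size formula, (size + 2 * padding - kernel) // stride + 1, to a 28x28 MNIST image; the conv_output_size helper is illustrative and not a PyTorch function. Max pooling with a 2x2 window uses a stride of 2, since PyTorch's stride defaults to the kernel size.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Height/width after a convolution or pooling layer (square inputs)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28                                          # MNIST images are 28x28
size = conv_output_size(size, kernel=5)            # conv1, 5x5 kernel -> 24x24
size = conv_output_size(size, kernel=2, stride=2)  # 2x2 max pool      -> 12x12
size = conv_output_size(size, kernel=5)            # conv2, 5x5 kernel -> 8x8
size = conv_output_size(size, kernel=2, stride=2)  # 2x2 max pool      -> 4x4

channels = 20                                      # conv2 produces 20 feature maps
flattened = channels * size * size
print(f"{channels} feature maps of {size}x{size} -> {flattened} values")
```

If you change a kernel size or add a layer, rerunning this calculation gives the new input size for fc1.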
A ReLU is the simple function f(x) = max(0, x), which returns the larger of 0 and the input value: it returns the value if it is greater than 0, and 0 otherwise.
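Here is ReLU as a one-line Python function, plus a quick check of a related detail: the model's code applies max pooling before ReLU. Because ReLU never changes the order of values (it is non-decreasing), taking the maximum of a pooling window and then applying ReLU gives the same result as applying ReLU first and then taking the maximum. The window values below are made up.

```python
import random


def relu(x: float) -> float:
    """f(x) = max(0, x): negative values become 0, positive values pass through."""
    return max(0.0, x)


print([relu(v) for v in [-2.0, -0.5, 0.0, 1.5, 3.0]])

# A 2x2 max-pooling window flattened into a list (values are made up)
window = [-1.2, 0.4, -3.0, 0.9]
pool_then_relu = relu(max(window))
relu_then_pool = max(relu(v) for v in window)
print(pool_then_relu, relu_then_pool)

# The two orders agree for any window, because ReLU is non-decreasing
rng = random.Random(0)
for _ in range(1000):
    w = [rng.uniform(-5, 5) for _ in range(4)]
    assert relu(max(w)) == max(relu(v) for v in w)
```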
Here is a cool diagram showing how each of the steps works:
5. Training and testing the CNN
```python
import torch

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()


def train(epoch: int) -> None:
    model.train()  # training mode: dropout is active
    for batch_idx, (data, target) in enumerate(loaders["train"]):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()  # back propagate
        optimizer.step()
        if batch_idx % 20 == 0:
            idx_data = batch_idx * len(data)
            total_data = len(loaders["train"].dataset)
            progress = 100. * batch_idx / len(loaders["train"])
            num_loss = loss.item()
            print(f"Train {epoch=} [{idx_data:>5}/{total_data} ({progress:2.0f}%)]\t{num_loss:.6f}")


def test() -> float:
    model.eval()  # evaluation mode: dropout is disabled
    test_loss: float = 0.0
    correct: int = 0
    with torch.no_grad():
        for data, target in loaders["test"]:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_fn(output, target).item()  # mean loss of this batch
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    total_data = len(loaders["test"].dataset)
    test_loss /= len(loaders["test"])  # average the per-batch means over all batches
    calc: float = 100. * correct / total_data
    print(f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total_data} ({calc:.2f}%)\n")
    return calc
```
This is where the model is initialized, trained, and tested, using the train() and test() functions.
- First, I set the device to the GPU if one is available; otherwise, the CPU is used. A GPU is much faster at the large matrix multiplications that a neural network performs constantly.
- I create the CNN model and move it to that device.
- The Adam optimizer will update the model's parameters (the weights and biases of its layers) using a learning rate of 0.001.
- The loss function object computes the cross-entropy loss between the model's output and the target labels, which measures how well the model predicts the correct class for each input image. Note that nn.CrossEntropyLoss expects raw scores (logits) and applies a log-softmax internally, so because forward() already applies F.softmax, softmax is effectively applied twice. The model still learns, but the usual practice is to return x directly from forward() and apply softmax only when probabilities are needed.
- The train() function is used for training the CNN:
  - model.train(): This puts the model in training mode, which enables the dropout layers so that they randomly drop some units during the forward pass. Dropout acts as regularization: it stops the network from relying too heavily on any single unit, which reduces overfitting.
  - for batch_idx, (data, target) in enumerate(loaders["train"]): This loops over the training data loader, which provides batches of images and labels from the training dataset. The loop index is batch_idx, and the batch images and labels are data and target, respectively.
  - data, target = data.to(device), target.to(device): This moves the batch images and labels to the same device as the model, so that they can be processed by the model and the loss function.
  - optimizer.zero_grad(): This resets the gradients of the model parameters to zero. PyTorch adds new gradients to any existing ones, so without this step, gradients from the previous batch would be mixed into the current one.
  - output = model(data): This passes the batch through the model, which performs the forward pass and returns a tensor of class probabilities for each image.
  - loss = loss_fn(output, target): This passes the output tensor and the target labels to the loss function, which computes the cross-entropy loss value.
  - loss.backward(): This performs the backward pass (backpropagation), which computes the gradient of the loss with respect to each model parameter and stores it in that parameter's .grad attribute.
  - optimizer.step(): This performs the optimization step, which updates the model parameters using the gradients and the learning rate.
  - if batch_idx % 20 == 0: This part of the code prints the training progress and the current loss value every 20 batches.
- The test() function is used to evaluate the CNN on the test data. It has a few key differences compared to the train() function:
  - The model is set to evaluation mode instead of training mode, which disables dropout.
  - Gradients are not computed (torch.no_grad()), since the parameters are not updated when testing the model, which saves memory and computing time.
  - pred = output.argmax(dim=1, keepdim=True): This computes the predicted class for each image by finding the index of the largest value along the second dimension (the 10 classes) of the output tensor. keepdim=True keeps the result two-dimensional, with shape (batch size, 1).
  - The number of correct predictions is counted, and the average loss and accuracy are printed.
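To see what the loss function computes, here is a minimal pure-Python sketch of cross-entropy; the softmax and cross_entropy functions are illustrative, not PyTorch's implementation. nn.CrossEntropyLoss combines a log-softmax with a negative log-likelihood loss, so for one image it equals -log(softmax(z)[y]), where z holds the raw scores and y is the correct class. The scores below are made up. The sketch also shows what happens when the scores have already been through a softmax, as with this CNN's forward(): the probabilities become the "scores", so even a very confident, correct prediction keeps a loss well above zero.

```python
import math


def softmax(scores):
    """Turn raw scores into probabilities that sum to one."""
    m = max(scores)  # subtracting the max avoids overflow in math.exp
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(scores, target):
    """Cross-entropy for one example: -log(softmax(scores)[target])."""
    return -math.log(softmax(scores)[target])


# Made-up raw scores (logits) for 3 classes: very confident in class 0, which is correct
logits = [10.0, 0.0, 0.0]
loss_from_logits = cross_entropy(logits, 0)
# The same scores after they have already been through a softmax
loss_after_softmax = cross_entropy(softmax(logits), 0)
print(f"Loss from raw logits:        {loss_from_logits:.5f}")
print(f"Loss after an extra softmax: {loss_after_softmax:.5f}")

# With 10 classes, even a perfect one-hot probability vector cannot reach zero loss
perfect = [1.0] + [0.0] * 9
print(f"Lowest possible loss with softmax applied twice: {cross_entropy(perfect, 0):.4f}")
```

Softmax never changes which score is largest, so argmax (and therefore the accuracy calculation in test()) gives the same predictions whether forward() returns logits or probabilities; only the loss values and gradients differ.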
6. Running the functions
```python
accuracy_per_epoch = []
for epoch in range(1, 11):
    train(epoch)
    accuracy_per_epoch.append(test())
```
The model is trained for 10 epochs (full passes over the training data) and tested after each one, with each test accuracy stored in accuracy_per_epoch for plotting later.
7. Plotting the first 10 predictions
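```python
import matplotlib.pyplot as plt
import torch

model.eval()
with torch.no_grad():  # no gradients are needed just to make predictions
    for i in range(10):
        data, target = test_data[i]
        data = data.unsqueeze(0).to(device)  # add a batch dimension: (1, 1, 28, 28)
        output = model(data)
        prediction = output.argmax(dim=1, keepdim=True).item()
        print(f"{prediction=}, {target=}")
        image = data.squeeze(0).squeeze(0).cpu().numpy()  # back to (28, 28)
        plt.imshow(image, cmap="gray")
        plt.show()
```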
Here, I printed the model's predictions for the first 10 test images and displayed each image to see if it works!
8. Running it for even longer
```python
for epoch in range(11, 51):
    train(epoch)
    accuracy_per_epoch.append(test())
```
I chose to train it for 40 more epochs (50 in total) to make the model more accurate.
9. Plotting the accuracy over the test runs
```python
x_values = range(1, 51)
plt.plot(x_values, accuracy_per_epoch, marker="o", linestyle="-")
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.title("Accuracy per Epoch")
plt.show()
```
As you can see, the model is 98.5% accurate at classifying handwritten digits! The model took only around 10 minutes to train on a T4 GPU.
The CNN trained on the FashionMNIST dataset reached an accuracy of 86.7%, taking a similar amount of time to train and test.
Analysis
The CNN for the FashionMNIST dataset had a lower accuracy than the CNN trained on handwritten digits. This is likely because the FashionMNIST images are more complex, and some categories, such as shirts, T-shirts/tops, and pullovers, look very similar to one another.
CNNs “see” simple features, such as edges, in their first layers; deeper layers then pick up on more complex and finer details, which are combinations of those simple features.
When you look at the different images, the digits have fewer and simpler features, made up mostly of straight lines, curves, and loops. The fashion items are more complicated, with varied outlines, fills, and shapes, which makes them harder to classify.
One way to address this is to increase the depth of the neural network by adding more layers and neurons. However, a bigger network needs more computing power and more time to train and run, and the accuracy improvement flattens out as training goes on, in a roughly logarithmic curve.
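As a rough illustration of "adding more layers and neurons", here is a hypothetical deeper variant of the CNN. It is a sketch, not a tuned or tested architecture: the layer sizes (32, 64, and 128 feature maps, 3x3 kernels with padding=1, and a 256-unit hidden layer) are my own assumptions. It has 390,410 trainable parameters, compared with 21,840 in the original CNN, which is where the extra computing cost comes from. It also returns raw logits, which is what nn.CrossEntropyLoss expects.

```python
import torch
from torch import nn
from torch.nn import functional as F


class DeeperCNN(nn.Module):
    """Hypothetical deeper variant: three conv layers with more feature maps."""

    def __init__(self) -> None:
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the height and width unchanged
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # -> 32 x 14 x 14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # -> 64 x 7 x 7
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)  # -> 128 x 3 x 3 (7 // 2 = 3)
        x = torch.flatten(x, 1)                     # -> 1152 values per image
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)  # raw logits for nn.CrossEntropyLoss


def count_parameters(model: nn.Module) -> int:
    """Number of trainable values (weights and biases) in a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


deeper_model = DeeperCNN()
dummy = torch.zeros(4, 1, 28, 28)  # a batch of 4 blank 28x28 images
print(deeper_model(dummy).shape)   # torch.Size([4, 10])
print(count_parameters(deeper_model))
```

Because it returns logits, this variant would be trained with the same loss_fn, but predictions would need F.softmax only if you want probabilities; argmax works directly on the logits.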
Conclusion
I built a Convolutional Neural Network (CNN) which can classify images such as handwritten digits or fashion items. Due to the small image sizes, I was able to develop and run the model quickly; training the models took around 10 minutes on Google Colab. CNNs have many influential applications in many different areas, and I enjoyed learning more about them and building my own. Thanks for reading my article!
You can reach out to me here:
- Email: prajwal028@outlook.com
- LinkedIn: https://www.linkedin.com/in/prajwal-prashanth/
- Substack: https://substack.com/@prajwalprashanth