Image Classification Example with CNN
In this lesson, we will build a simple image classification model using TensorFlow and Keras, and practice classifying digit images (0-9) using the MNIST dataset.
1. Preparing the Data
First, let's load the MNIST dataset provided by TensorFlow. This dataset consists of 28×28 pixel grayscale images of handwritten digits.
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
# Load MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize data to range 0 to 1
x_train, x_test = x_train / 255.0, x_test / 255.0
# Add a channel dimension (Conv2D expects each image as height x width x channels)
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]
# Display a data sample
plt.imshow(x_train[0].squeeze(), cmap='gray')
plt.title(f"Label: {y_train[0]}")
plt.show()
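The effect of the normalization and dimension-expansion steps can be checked without downloading MNIST. The following is a minimal sketch that assumes a small random stand-in array (`fake_images`, a hypothetical name) in place of the real data:

```python
import numpy as np

# Hypothetical stand-in for MNIST: 4 random 28x28 images with uint8 pixels (0-255)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Same preprocessing as in the lesson: scale to [0, 1], then add a channel axis
scaled = fake_images / 255.0
with_channel = scaled[..., np.newaxis]

print(fake_images.shape, fake_images.dtype)    # (4, 28, 28) uint8
print(with_channel.shape, with_channel.dtype)  # (4, 28, 28, 1) float64
print(with_channel.min() >= 0.0, with_channel.max() <= 1.0)
```

The trailing axis of size 1 is the single grayscale channel. Together with the batch axis, it gives the four-dimensional array that the convolutional layers receive.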
2. Creating the CNN Model
A CNN model consists of convolutional layers (Conv2D), pooling layers (MaxPooling2D), and fully connected layers (Dense).
# Define the CNN model
model = keras.Sequential([
    keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPooling2D(2, 2),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D(2, 2),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
# Display the model architecture
model.summary()
# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
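To make sense of what `model.summary()` reports, the feature-map sizes and parameter counts can be worked out by hand. The sketch below does this in plain Python. It assumes the layer defaults used above (stride-1 convolutions with no padding, and pooling whose stride equals the pool size). The helper names are hypothetical and not part of Keras.

```python
def conv_out(size, kernel):
    """Side length after a stride-1 convolution without padding."""
    return size - kernel + 1

def pool_out(size, pool):
    """Side length after pooling with stride equal to the pool size."""
    return size // pool

def conv_params(kernel, in_channels, filters):
    """Weights plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters

def dense_params(inputs, units):
    """Weights plus one bias per unit."""
    return inputs * units + units

size = 28
size = conv_out(size, 3)   # 26 after Conv2D(32, (3, 3))
size = pool_out(size, 2)   # 13 after MaxPooling2D(2, 2)
size = conv_out(size, 3)   # 11 after Conv2D(64, (3, 3))
size = pool_out(size, 2)   # 5 after MaxPooling2D(2, 2)
flattened = size * size * 64

total_params = (
    conv_params(3, 1, 32)
    + conv_params(3, 32, 64)
    + dense_params(flattened, 128)
    + dense_params(128, 10)
)

print(flattened)     # 1600
print(total_params)  # 225034
```

Most of the parameters sit in the first Dense layer, because it connects all 1600 flattened features to 128 units.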
3. Training the Model
Now, let's train the model with the prepared data.
# Train the model
history = model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))
During training, you should see the loss decrease and the accuracy increase from one epoch to the next.
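The loss printed during training is the sparse categorical cross-entropy selected in `model.compile`. As a rough illustration of what that number measures, here is a minimal NumPy sketch. The probabilities are made up for the example, and the function omits the numerical-stability handling that a library implementation would include.

```python
import numpy as np

def sparse_categorical_crossentropy(probs, labels):
    """Mean negative log-probability assigned to the true class.

    probs:  (batch, classes) array whose rows sum to 1
    labels: (batch,) array of integer class ids
    """
    true_class_probs = probs[np.arange(len(labels)), labels]
    return float(-np.log(true_class_probs).mean())

# Made-up probabilities for 2 samples over 3 classes (illustration only)
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
labels = np.array([0, 2])

confident = sparse_categorical_crossentropy(probs, labels)
uniform = sparse_categorical_crossentropy(np.full((2, 3), 1 / 3), labels)

print(confident < uniform)  # True
```

The loss falls as the model puts more probability on the correct class, which is why a decreasing loss usually goes together with an increasing accuracy.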
4. Evaluating and Predicting with the Model
Next, evaluate the trained model and make predictions.
# Evaluate the model
loss, acc = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {acc:.4f}")
Perform predictions on specific samples to check the results.
# Predict sample data
sample = x_test[:5] # First 5 images
predictions = model.predict(sample)
predicted_labels = np.argmax(predictions, axis=1)
# Display prediction results
for i in range(5):
    plt.imshow(sample[i].squeeze(), cmap='gray')
    plt.title(f"Predicted: {predicted_labels[i]}, Actual: {y_test[i]}")
    plt.show()
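The `np.argmax` step above turns each row of class probabilities into a single digit label. The following sketch shows that conversion on made-up probability rows, so it runs without a trained model. The values and the `actual` labels are invented for illustration.

```python
import numpy as np

# Made-up softmax outputs for 3 samples over 10 digit classes (illustration only)
predictions = np.full((3, 10), 0.02)
predictions[0, 7] = 0.82
predictions[1, 2] = 0.82
predictions[2, 1] = 0.82

# Hypothetical true labels; the third one disagrees with the prediction
actual = np.array([7, 2, 3])

# The index of the largest probability in each row is the predicted digit
predicted_labels = np.argmax(predictions, axis=1)
accuracy = np.mean(predicted_labels == actual)

print(predicted_labels)    # [7 2 1]
print(f"{accuracy:.2f}")   # 0.67
```

Comparing the predicted labels with the true labels in this way is the same calculation behind the accuracy metric reported by `model.evaluate`.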
Now you have built a simple CNN model that classifies handwritten digit images. The same approach can be applied to other image datasets, as long as the input shape and the number of output classes are adjusted to match the new data.
In the next lesson, we will take a quiz to review what we've learned so far.