
cnn_mnist

Example Script: cnn_mnist.py

This script demonstrates how to use the neural network implementation from the SciREX library to perform classification on the MNIST dataset.

This example includes
  • Loading the MNIST dataset using tensorflow.keras.datasets
  • Training a convolutional neural network (CNN)
  • Evaluating and visualizing the results
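The script loads MNIST with `tensorflow.keras.datasets`. Its `mnist.load_data()` returns uint8 images of shape (N, 28, 28), but the first layer, `nn.Conv2d(1, 4, ...)`, expects one input channel. The sketch below shows one way to add that channel axis. It assumes a channel-first (C, H, W) layout, as in Equinox, and it also scales pixels to [0, 1]. Both are assumptions, since the script's own preprocessing is not shown on this page. `preprocess` is a hypothetical helper, and random arrays stand in for the real data so the sketch runs offline:

```python
import numpy as np

def preprocess(images, labels):
    """Hypothetical helper: convert raw MNIST arrays to a CNN-ready layout.

    images: uint8 array of shape (N, 28, 28) with values in [0, 255]
    labels: integer array of shape (N,)
    Returns float32 images of shape (N, 1, 28, 28) in [0, 1] and int32 labels.
    """
    x = images.astype(np.float32) / 255.0  # scale pixels to [0, 1]
    x = x[:, None, :, :]                   # add channel axis: (N, 28, 28) -> (N, 1, 28, 28)
    y = labels.astype(np.int32)
    return x, y

# Synthetic stand-in for the output of mnist.load_data().
fake_images = np.random.randint(0, 256, size=(5, 28, 28), dtype=np.uint8)
fake_labels = np.arange(5)
x, y = preprocess(fake_images, fake_labels)
print(x.shape, x.dtype, y.dtype)  # prints: (5, 1, 28, 28) float32 int32
```

To use the real dataset, run `from tensorflow.keras.datasets import mnist` and `(x_train, y_train), (x_test, y_test) = mnist.load_data()`, then call `preprocess(x_train, y_train)`.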
Key Features
  • Uses cross-entropy loss for training
  • Uses an accuracy metric for evaluation
  • Includes model checkpointing
  • Provides training history visualization
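The network ends in `log_softmax`, so cross-entropy reduces to the mean negative log-probability of each example's true class. Accuracy is the fraction of `argmax` predictions that match the labels. The sketch below shows both in plain `jax.numpy`. The function names are hypothetical and are not SciREX's own loss or metric API, which this page does not show:

```python
import jax.numpy as jnp

def cross_entropy_from_log_probs(log_probs, labels):
    """Hypothetical helper: mean negative log-likelihood.

    log_probs: (N, C) output of a log_softmax layer
    labels:    (N,) integer class indices
    """
    # Select each example's log-probability for its true class.
    true_class = jnp.take_along_axis(log_probs, labels[:, None], axis=1)[:, 0]
    return -jnp.mean(true_class)

def accuracy(log_probs, labels):
    """Hypothetical helper: fraction of argmax predictions equal to the labels."""
    return jnp.mean(jnp.argmax(log_probs, axis=-1) == labels)

# Two examples, three classes for brevity (MNIST has ten).
log_probs = jnp.log(jnp.array([[0.7, 0.2, 0.1],
                               [0.1, 0.8, 0.1]]))
labels = jnp.array([0, 2])
loss = cross_entropy_from_log_probs(log_probs, labels)  # (-ln 0.7 - ln 0.1) / 2 ≈ 1.3296
acc = accuracy(log_probs, labels)                        # 0.5: only the first is correct
```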
Authors
  • Lokesh Mohanty (lokeshm@iisc.ac.in)
Version Info
  • 04/01/2024: Initial version
  • 06/01/2024: Update imports

CNN

Bases: Network

Convolutional Neural Network for MNIST digit classification.

Architecture:
  • Conv2D: 1->4 channels, 4x4 kernel
  • MaxPool2D: 2x2 pooling
  • ReLU activation
  • Conv2D: 4->8 channels, 4x4 kernel
  • MaxPool2D: 2x2 pooling
  • ReLU activation
  • Flatten
  • Dense: 8*4*4 = 128 -> 10 units
  • LogSoftmax activation
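The dense layer's input size of `8*4*4` follows from the spatial size of a 28x28 MNIST image as it passes through the layers. The sketch below reproduces that arithmetic. It assumes the convolutions use stride 1 and no padding (Equinox's defaults) and that pooling uses a 2x2 window with stride 2. `conv_out` is a hypothetical helper:

```python
def conv_out(size, kernel, stride=1):
    """Output length of an unpadded ('valid') convolution or pooling window."""
    return (size - kernel) // stride + 1

h = 28                        # MNIST images are 28x28
h = conv_out(h, 4)            # Conv2D 4x4        -> 25
h = conv_out(h, 2, stride=2)  # MaxPool2D 2x2     -> 12
h = conv_out(h, 4)            # Conv2D 4x4        -> 9
h = conv_out(h, 2, stride=2)  # MaxPool2D 2x2     -> 4
print(h, 8 * h * h)           # prints: 4 128 (8 channels * 4 * 4)
```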

Source code in examples/dl/cnn_mnist.py
class CNN(Network):
    """
    Convolutional Neural Network for MNIST digit classification.

    Architecture:
    - Conv2D: 1->4 channels, 4x4 kernel
    - MaxPool2D: 2x2 pooling
    - ReLU activation
    - Conv2D: 4->8 channels, 4x4 kernel
    - MaxPool2D: 2x2 pooling
    - ReLU activation
    - Flatten
    - Dense: 8*4*4->10 units
    - LogSoftmax activation
    """

    layers: list

    def __init__(self):
        """Initialize the CNN architecture with predefined layers."""
        # key1 and key2 are PRNG keys defined at module level in the script.
        self.layers = [
            nn.Conv2d(1, 4, kernel_size=4, key=key1),  # First conv layer: 1->4 channels
            nn.MaxPool2d(2, 2),  # 2x2 pooling: reduces spatial dimensions
            nn.relu,  # Activation function
            # Second conv layer: 4->8 channels. Note: this reuses key1; JAX
            # convention is to give each layer its own freshly split key.
            nn.Conv2d(4, 8, kernel_size=4, key=key1),
            nn.MaxPool2d(2, 2),  # Further reduce spatial dimensions
            nn.relu,  # Activation function
            jnp.ravel,  # Flatten the (8, 4, 4) feature map into a 128-vector
            nn.Linear(8 * 4 * 4, 10, key=key2),  # Output layer: 10 classes
            nn.log_softmax,  # Log-probabilities; more stable than log(softmax(x))
        ]

    def __call__(self, x):
        """
        Forward pass through the network.
        """
        for layer in self.layers:
            x = layer(x)
        return x

    def predict(self, x):
        """
        Return the predicted class index: the argmax of the network's
        log-probabilities for input x.
        """
        return jnp.argmax(self(x), axis=-1)

__call__(x)

Forward pass through the network.
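`__call__` passes the input through `self.layers` in order. Any callable can therefore serve as a layer, which is why the list mixes layer objects (`nn.Conv2d`, `nn.Linear`) with plain functions (`nn.relu`, `jnp.ravel`). The sketch below runs the same loop with hypothetical stand-in functions:

```python
import jax.numpy as jnp

def run_layers(layers, x):
    """Apply each callable in order, the same loop CNN.__call__ uses."""
    for layer in layers:
        x = layer(x)
    return x

layers = [
    lambda x: x - 1.0,              # hypothetical shift, standing in for a layer object
    lambda x: jnp.maximum(x, 0.0),  # ReLU written out as a plain function
    jnp.ravel,                      # flatten, as in the CNN
]
out = run_layers(layers, jnp.array([[0.5, 2.0], [3.0, -1.0]]))
# [[-0.5, 1.0], [2.0, -2.0]] -> ReLU -> [[0.0, 1.0], [2.0, 0.0]] -> ravel
```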

__init__()

Initialize the CNN architecture with predefined layers.
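`__init__` passes `key1` to both `nn.Conv2d` layers. JAX's convention is never to reuse a PRNG key for draws that should be independent. Instead, a seed key is split into as many keys as needed. The sketch below assumes a seed of 0 and a hypothetical third key, `key3`, for the second convolution:

```python
import jax

# Hypothetical seed; the script defines its own key1 and key2 elsewhere.
key = jax.random.PRNGKey(0)
key1, key2, key3 = jax.random.split(key, 3)  # three independent keys

# Reusing one key reproduces the same "random" draws:
a = jax.random.normal(key1, (3,))
b = jax.random.normal(key1, (3,))  # identical to a
c = jax.random.normal(key3, (3,))  # a fresh key gives different values

# One key per parameterized layer, e.g. (key3 is hypothetical):
#   nn.Conv2d(1, 4, kernel_size=4, key=key1)
#   nn.Conv2d(4, 8, kernel_size=4, key=key3)
#   nn.Linear(8 * 4 * 4, 10, key=key2)
```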

predict(x)

Return the predicted class index: the argmax of the network's log-probabilities for input x.
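`predict` takes the `argmax` over the last axis of the network output. As far as this page shows, the forward pass handles one image at a time, and its `jnp.ravel` step would also flatten a batch axis. One way to predict for a batch is `jax.vmap`. The sketch below uses a hypothetical single-layer stand-in for the CNN so that it runs without SciREX:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for CNN.__call__: like the real model, it takes ONE
# image of shape (1, 28, 28), flattens it with jnp.ravel, and returns
# log-probabilities over 10 classes.
W = jax.random.normal(jax.random.PRNGKey(0), (28 * 28, 10))

def forward(x):
    return jax.nn.log_softmax(jnp.ravel(x) @ W)

def predict(x):
    # Same rule as CNN.predict: index of the largest log-probability.
    return jnp.argmax(forward(x), axis=-1)

images = jax.random.uniform(jax.random.PRNGKey(1), (32, 1, 28, 28))
single = predict(images[0])        # one image -> scalar class index
batch = jax.vmap(predict)(images)  # 32 images -> shape (32,)
```

Calling `forward(images)` directly would ravel all 32 images into a single 25,088-element vector, and the matrix product with `W` would fail. `vmap` instead maps the single-image function over the leading axis.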
