Bases: Network
Convolutional Neural Network for MNIST digit classification.
Architecture:
- Conv2D: 1->4 channels, 4x4 kernel
- MaxPool2D: 2x2 pooling
- ReLU activation
- Conv2D: 4->8 channels, 4x4 kernel
- MaxPool2D: 2x2 pooling
- ReLU activation
- Flatten
- Dense: 8*4*4->10 units
- LogSoftmax activation
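The `8*4*4` input size of the dense layer follows from the spatial sizes produced by the earlier layers. A minimal sketch of that arithmetic, assuming a 28x28 single-channel MNIST image, unpadded stride-1 convolutions, and non-overlapping 2x2 pooling (the docstring does not state these explicitly); the helper names are hypothetical:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution (floor division, as most frameworks do)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, window, stride):
    """Spatial size after max pooling with no padding."""
    return (size - window) // stride + 1


def flattened_features(size=28):
    """Trace the spatial size through the documented layer stack."""
    size = conv_out(size, kernel=4)   # 28 -> 25
    size = pool_out(size, 2, 2)       # 25 -> 12
    size = conv_out(size, kernel=4)   # 12 -> 9
    size = pool_out(size, 2, 2)       # 9 -> 4
    channels = 8                      # output channels of the second conv
    return channels * size * size


print(flattened_features())  # 128, i.e. 8 * 4 * 4
```

If the input size or kernel size changes, the dense layer's input dimension has to change with it.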
Source code in examples/dl/cnn_mnist.py
    class CNN(Network):
        """
        Convolutional Neural Network for MNIST digit classification.

        Architecture:
        - Conv2D: 1->4 channels, 4x4 kernel
        - MaxPool2D: 2x2 pooling
        - ReLU activation
        - Conv2D: 4->8 channels, 4x4 kernel
        - MaxPool2D: 2x2 pooling
        - ReLU activation
        - Flatten
        - Dense: 8*4*4->10 units
        - LogSoftmax activation
        """

        layers: list

        def __init__(self):
            """Initialize the CNN architecture with predefined layers."""
            self.layers = [
                nn.Conv2d(1, 4, kernel_size=4, key=key1),  # First conv layer: 1->4 channels
                nn.MaxPool2d(2, 2),  # Reduce spatial dimensions
                nn.relu,  # Activation function
                nn.Conv2d(
                    4, 8, kernel_size=4, key=key1
                ),  # Second conv layer: 4->8 channels (note: reuses key1)
                nn.MaxPool2d(2, 2),  # Further reduce dimensions
                nn.relu,  # Activation function
                jnp.ravel,  # Flatten for dense layer
                nn.Linear(8 * 4 * 4, 10, key=key2),  # Output layer: 10 classes
                nn.log_softmax,  # For numerical stability
            ]

        def __call__(self, x):
            """
            Forward pass through the network.
            """
            for layer in self.layers:
                x = layer(x)
            return x

        def predict(self, x):
            """
            Generate class predictions from model outputs.
            """
            return jnp.argmax(self(x), axis=-1)
__call__(x)
Forward pass through the network.
Source code in examples/dl/cnn_mnist.py
    def __call__(self, x):
        """
        Forward pass through the network.
        """
        for layer in self.layers:
            x = layer(x)
        return x
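The forward pass threads the input through each callable in `self.layers`, in order. A minimal sketch of the same pattern, with plain Python functions standing in for the layers (the stand-ins are hypothetical and unrelated to the library's `nn` module):

```python
def forward(layers, x):
    """Apply each callable in `layers` to the running value, in order."""
    for layer in layers:
        x = layer(x)
    return x


# Hypothetical stand-ins for layers: any one-argument callable works.
double = lambda v: [2 * e for e in v]
clip_negative = lambda v: [max(e, 0) for e in v]  # ReLU-like
total = sum

result = forward([double, clip_negative, total], [1, -2, 3])
print(result)  # 8
```

Because the loop only requires callables, parameterized layers and bare functions such as `jnp.ravel` can sit in the same list.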
__init__()
Initialize the CNN architecture with predefined layers.
Source code in examples/dl/cnn_mnist.py
    def __init__(self):
        """Initialize the CNN architecture with predefined layers."""
        self.layers = [
            nn.Conv2d(1, 4, kernel_size=4, key=key1),  # First conv layer: 1->4 channels
            nn.MaxPool2d(2, 2),  # Reduce spatial dimensions
            nn.relu,  # Activation function
            nn.Conv2d(
                4, 8, kernel_size=4, key=key1
            ),  # Second conv layer: 4->8 channels (note: reuses key1)
            nn.MaxPool2d(2, 2),  # Further reduce dimensions
            nn.relu,  # Activation function
            jnp.ravel,  # Flatten for dense layer
            nn.Linear(8 * 4 * 4, 10, key=key2),  # Output layer: 10 classes
            nn.log_softmax,  # For numerical stability
        ]
predict(x)
Generate class predictions from model outputs.
Source code in examples/dl/cnn_mnist.py
    def predict(self, x):
        """
        Generate class predictions from model outputs.
        """
        return jnp.argmax(self(x), axis=-1)
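Log-softmax is monotonic, so the index of the largest log-probability is also the index of the largest pre-activation score; taking `argmax` over the network output therefore yields the predicted digit. A dependency-free sketch of this, using the standard library instead of `jnp` (the function names and scores are made up for illustration):

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax: subtract the max before exponentiating."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def argmax(values):
    """Index of the largest element (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


scores = [0.1, 2.5, -1.0, 0.7]   # made-up scores for four classes
log_probs = log_softmax(scores)

print(argmax(log_probs))          # 1, same as argmax(scores)
```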