vae_mnist
Example Script: vae_mnist.py
This script demonstrates how to use the neural network implementation from the SciREX library to train a variational autoencoder (VAE) on the MNIST dataset.
This example includes:
- Loading the MNIST dataset using tensorflow.keras.datasets (see the loading sketch after this list)
- Training a variational autoencoder
- Evaluating and visualizing the results
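Loading MNIST via tensorflow.keras.datasets might look like the following; this is a minimal sketch, and the flattening and [0, 1] normalization choices here are assumptions rather than a quote of the script:

```python
import numpy as np
from tensorflow.keras.datasets import mnist

# Load the raw MNIST split: 60k training and 10k test images of shape (28, 28).
(x_train, _), (x_test, _) = mnist.load_data()

# Assumed preprocessing: scale pixels to [0, 1] and flatten each image to a
# 784-dimensional vector, a common input format for a fully connected VAE.
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0
```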
Key features:
- Encoder-decoder architecture
- Latent space sampling with reparameterization trick
- KL divergence loss for regularization
- Model checkpointing and visualization
Version Info
- 06/01/2024: Initial version
VAE
Bases: Network
Variational Autoencoder implementation.
The VAE consists of:
1. An encoder that maps inputs to latent space parameters
2. A sampling layer that uses the reparameterization trick (sketched below)
3. A decoder that reconstructs inputs from latent samples
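The sampling layer's reparameterization trick can be sketched in JAX as follows; this is a minimal sketch in which the names `mu`, `log_var`, and the explicit PRNG key are illustrative assumptions, not the class's actual interface:

```python
import jax
import jax.numpy as jnp

def sample_latent(key, mu, log_var):
    # Draw z = mu + sigma * eps with eps ~ N(0, I). Writing the sample this
    # way keeps it differentiable w.r.t. the encoder outputs mu and log_var,
    # so gradients can flow through the stochastic layer during training.
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * log_var) * eps
```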
Attributes:
Name | Type | Description
---|---|---
encoder | FCNN | Neural network for encoding
decoder | FCNN | Neural network for decoding
Source code in examples/dl/vae_mnist.py
__call__(x)
Forward pass through the VAE.
Parameters:
Name | Type | Description | Default
---|---|---|---
x | Array | Input image tensor | required
Returns:
Type | Description
---|---
jax.Array | Reconstructed image tensor
Source code in examples/dl/vae_mnist.py
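A plausible shape for this forward pass, composing the three stages listed above; everything here is a hedged sketch (the split of the encoder output into `mu` and `log_var`, the `latent_dim` parameter, and the callables `encode`/`decode` are assumptions, not the documented signature):

```python
import jax
import jax.numpy as jnp

def vae_forward(key, x, encode, decode, latent_dim):
    # Hypothetical composition of the three documented stages: encode,
    # split into latent mean / log-variance, sample with the
    # reparameterization trick, and decode back to image space.
    h = encode(x)
    mu, log_var = h[..., :latent_dim], h[..., latent_dim:]
    eps = jax.random.normal(key, mu.shape)
    z = mu + jnp.exp(0.5 * log_var) * eps
    return decode(z)
```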
__init__(encoderLayers, decoderLayers)
Initialize VAE with encoder and decoder architectures.
Parameters:
Name | Type | Description | Default
---|---|---|---
encoderLayers | list | Layer definitions for encoder | required
decoderLayers | list | Layer definitions for decoder | required
Source code in examples/dl/vae_mnist.py
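A usage sketch under assumptions: the exact layer-definition format expected by FCNN is not shown on this page, so the lists of layer widths and the 784/latent sizes below are placeholders, not the library's documented format:

```python
import jax.numpy as jnp

# Hypothetical layer definitions; treat these entries as placeholders.
latent_dim = 2
encoderLayers = [784, 256, 2 * latent_dim]  # encoder ends at 2*latent_dim (mu, log_var)
decoderLayers = [latent_dim, 256, 784]      # decoder maps z back to 784 pixels

vae = VAE(encoderLayers, decoderLayers)

x = jnp.zeros((32, 784))  # dummy batch of flattened MNIST images
x_recon = vae(x)          # __call__ returns a jax.Array reconstruction
```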
loss_fn(output, y)
Compute KL divergence loss between output and target.
Parameters:
Name | Type | Description | Default
---|---|---|---
output | Array | Model output | required
y | Array | Target values | required
Returns:
Type | Description
---|---
float | Loss value
Source code in examples/dl/vae_mnist.py
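For a diagonal Gaussian posterior and a standard normal prior, the KL divergence term that this loss describes is commonly computed in closed form as below. This is a sketch of the standard formula, not necessarily the library's exact implementation, and a full VAE objective usually also includes a reconstruction term:

```python
import jax.numpy as jnp

def kl_divergence(mu, log_var):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian,
    # summed over latent dimensions and averaged over the batch.
    kl = -0.5 * jnp.sum(1.0 + log_var - jnp.square(mu) - jnp.exp(log_var), axis=-1)
    return jnp.mean(kl)
```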