Example Script: vae-mnist.py

This script demonstrates how to train a variational autoencoder (VAE) on the MNIST dataset using the neural network implementation from the SciREX library.

This example includes
  • Loading the MNIST dataset using tensorflow.keras.datasets
  • Training a variational autoencoder
  • Evaluating and visualizing the results
Key features
  • Encoder-decoder architecture
  • Latent space sampling with reparameterization trick
  • KL divergence loss for regularization
  • Model checkpointing and visualization
Authors
  • Lokesh Mohanty (lokeshm@iisc.ac.in)
Version Info
  • 06/01/2024: Initial version

VAE

Bases: Network

Variational Autoencoder implementation.

The VAE consists of:
  1. An encoder that maps inputs to latent space parameters
  2. A sampling layer that uses the reparameterization trick
  3. A decoder that reconstructs inputs from latent samples
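The sampling step above relies on the reparameterization trick: instead of sampling z directly from N(mean, stddev²), one draws eps ~ N(0, I) and computes z = mean + stddev * eps, which keeps the path from the encoder outputs to z differentiable. A minimal NumPy sketch of the idea (illustrative only, not the SciREX implementation):

```python
import numpy as np

def reparameterize(mean, stddev, rng):
    """Sample z = mean + stddev * eps with eps ~ N(0, I).

    Drawing eps separately keeps gradients flowing through
    mean and stddev, which direct sampling would block.
    """
    eps = rng.standard_normal(mean.shape)
    return mean + stddev * eps

rng = np.random.default_rng(0)
mean = np.zeros(4)
stddev = 0.5 * np.ones(4)
z = reparameterize(mean, stddev, rng)
```

With stddev set to zero the sample collapses to the mean, which is a quick way to check the formula.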

Attributes:
  • encoder (FCNN): Neural network for encoding
  • decoder (FCNN): Neural network for decoding

Source code in examples/dl/vae_mnist.py
class VAE(Network):
    """
    Variational Autoencoder implementation.

    The VAE consists of:
    1. An encoder that maps inputs to latent space parameters
    2. A sampling layer that uses the reparameterization trick
    3. A decoder that reconstructs inputs from latent samples

    Attributes:
        encoder (FCNN): Neural network for encoding
        decoder (FCNN): Neural network for decoding
    """

    encoder: FCNN
    decoder: FCNN

    def __init__(self, encoderLayers, decoderLayers):
        """
        Initialize VAE with encoder and decoder architectures.

        Args:
            encoderLayers (list): Layer definitions for encoder
            decoderLayers (list): Layer definitions for decoder
        """
        self.encoder = FCNN(encoderLayers)
        self.decoder = FCNN(decoderLayers)

    def __call__(self, x):
        """
        Forward pass through the VAE.

        Args:
            x (jax.Array): Input image tensor

        Returns:
            jax.Array: Reconstructed image tensor
        """
        # Encode input to get latent parameters
        x = self.encoder(x)
        # Split into mean and log standard deviation
        mean, stddev = x[:-1], jnp.exp(x[-1])
        # Sample from latent space using reparameterization trick
        z = mean + stddev * jax.random.normal(jax.random.PRNGKey(0), mean.shape)
        # Decode latent sample
        return self.decoder(z)

__call__(x)

Forward pass through the VAE.

Parameters:
  • x (jax.Array): Input image tensor (required)

Returns:
  • jax.Array: Reconstructed image tensor

Source code in examples/dl/vae_mnist.py
def __call__(self, x):
    """
    Forward pass through the VAE.

    Args:
        x (jax.Array): Input image tensor

    Returns:
        jax.Array: Reconstructed image tensor
    """
    # Encode input to get latent parameters
    x = self.encoder(x)
    # Split into mean and log standard deviation
    mean, stddev = x[:-1], jnp.exp(x[-1])
    # Sample from latent space using reparameterization trick
    z = mean + stddev * jax.random.normal(jax.random.PRNGKey(0), mean.shape)
    # Decode latent sample
    return self.decoder(z)
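Note how the encoder output is split: the last element is treated as a log standard deviation shared across all latent dimensions and exponentiated into a positive stddev, while the remaining elements form the mean. A small NumPy sketch of that split (the array values here are made up for illustration):

```python
import numpy as np

# Suppose the encoder emits 5 values: 4 latent means plus one log-std.
encoder_out = np.array([0.1, -0.2, 0.3, 0.0, np.log(0.5)])

# Split as in __call__: all but the last element are the mean; the
# last, once exponentiated, is a single shared standard deviation.
mean, stddev = encoder_out[:-1], np.exp(encoder_out[-1])
```

Also note that the source above draws noise with a fixed jax.random.PRNGKey(0), so every forward pass samples the same eps; in a full training setup one would typically thread a fresh PRNG key into each call.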

__init__(encoderLayers, decoderLayers)

Initialize VAE with encoder and decoder architectures.

Parameters:
  • encoderLayers (list): Layer definitions for encoder (required)
  • decoderLayers (list): Layer definitions for decoder (required)
Source code in examples/dl/vae_mnist.py
def __init__(self, encoderLayers, decoderLayers):
    """
    Initialize VAE with encoder and decoder architectures.

    Args:
        encoderLayers (list): Layer definitions for encoder
        decoderLayers (list): Layer definitions for decoder
    """
    self.encoder = FCNN(encoderLayers)
    self.decoder = FCNN(decoderLayers)

loss_fn(output, y)

Compute KL divergence loss between output and target.

Parameters:
  • output (jax.Array): Model output (required)
  • y (jax.Array): Target values (required)

Returns:
  • float: Loss value

Source code in examples/dl/vae_mnist.py
def loss_fn(output, y):
    """
    Compute KL divergence loss between output and target.

    Args:
        output (jax.Array): Model output
        y (jax.Array): Target values

    Returns:
        float: Loss value
    """
    return jnp.abs(nn.kl_divergence(output.reshape(-1), y.reshape(-1)))
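As a sanity check, the KL divergence used above can be computed by hand for two small discrete distributions. The sketch below uses the standard definition KL(p ‖ q) = Σᵢ pᵢ log(pᵢ/qᵢ) in plain NumPy; the exact behavior of the `nn.kl_divergence` helper called in the source is not reproduced here.

```python
import numpy as np

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i) for discrete distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(p * np.log(p / q)))

p = np.array([0.5, 0.5])
q = np.array([0.9, 0.1])
kl = kl_divergence(p, q)  # larger when q diverges from p
```

KL divergence is zero exactly when the two distributions coincide, which is why it works as a regularizer pulling the latent distribution toward the prior.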