
MLP

Module: mlp.py

This module implements the Multi-Layer Perceptron (MLP) neural network architecture.

Key Classes

MLP: Multi-Layer Perceptron

Key Features
  • Built on top of the FCNN base class, inheriting all of its functionality
  • Efficient neural network implementation using Equinox modules
Authors
  • Lokesh Mohanty (lokeshm@iisc.ac.in)
Version Info
  • 02/01/2025: Initial version

MLP

Bases: FCNN

Multi-Layer Perceptron

Source code in scirex/core/dl/mlp.py
class MLP(FCNN):
    """
    Multi-Layer Perceptron
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        hidden_size: int = 0,
        depth: int = 0,
        activation: Callable = jax.nn.relu,
        final_activation: Callable = lambda x: x,
        random_seed: int = 0,
    ):
        """
        Constructor for Multi-Layer Perceptron

        Args:
            in_size: Input size
            out_size: Output size
            hidden_size: Width of each hidden layer (ignored when depth is 0)
            depth: Number of hidden layers
            activation: Activation function applied after each hidden layer
            final_activation: Activation function applied to the output
            random_seed: Seed for the PRNG key used to initialise the weights
        """
        key = jax.random.PRNGKey(random_seed)
        # One key per linear layer so that each layer is initialised independently
        keys = jax.random.split(key, depth + 1)
        if depth == 0:
            # No hidden layers: the network reduces to a single linear map
            self.layers = [eqx.nn.Linear(in_size, out_size, key=keys[0])]
        else:
            self.layers = [
                eqx.nn.Linear(in_size, hidden_size, key=keys[0]),
                activation,
            ]
            for i in range(depth - 1):
                self.layers += [
                    eqx.nn.Linear(hidden_size, hidden_size, key=keys[i + 1]),
                    activation,
                ]
            self.layers += [eqx.nn.Linear(hidden_size, out_size, key=keys[depth])]

        self.layers += [final_activation]
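
A minimal usage sketch (assuming FCNN's __call__ applies the entries of self.layers in sequence to a single, unbatched input, and that the class is importable from scirex.core.dl.mlp as shown in the source location above):

import jax
import jax.numpy as jnp

from scirex.core.dl.mlp import MLP  # import path assumed from the source location above

# Hypothetical example: 3 inputs -> two hidden layers of width 64 -> 1 output
model = MLP(in_size=3, out_size=1, hidden_size=64, depth=2, random_seed=42)

x = jnp.ones(3)   # a single (unbatched) input vector
y = model(x)      # forward pass through the Linear/activation layers
print(y.shape)    # expected: (1,)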

__init__(in_size, out_size, hidden_size=0, depth=0, activation=jax.nn.relu, final_activation=lambda x: x, random_seed=0)

Constructor for Multi-Layer Perceptron

Parameters:

Name              Type      Description                                Default
in_size           int       Input size                                 required
out_size          int       Output size                                required
hidden_size       int       Width of each hidden layer                 0
depth             int       Number of hidden layers                    0
activation        Callable  Activation after each hidden layer         jax.nn.relu
final_activation  Callable  Activation applied to the output           lambda x: x
random_seed       int       Seed for weight initialisation             0
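
A sketch of how these arguments map onto the architecture (based on the source below; with depth == 0 the hidden_size argument is ignored and the model is a single linear layer):

import jax

from scirex.core.dl.mlp import MLP  # import path assumed from the source location

# depth=0: a single Linear(4, 2) map, no hidden layers
linear_model = MLP(in_size=4, out_size=2)

# depth=2: Linear(4, 32) -> tanh -> Linear(32, 32) -> tanh -> Linear(32, 2) -> sigmoid
mlp_model = MLP(
    in_size=4,
    out_size=2,
    hidden_size=32,
    depth=2,
    activation=jax.nn.tanh,
    final_activation=jax.nn.sigmoid,
)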
Source code in scirex/core/dl/mlp.py
def __init__(
    self,
    in_size: int,
    out_size: int,
    hidden_size: int = 0,
    depth: int = 0,
    activation: Callable = jax.nn.relu,
    final_activation: Callable = lambda x: x,
    random_seed: int = 0,
):
    """
    Constructor for Multi-Layer Perceptron

    Args:
        in_size: Input size
        out_size: Output size
        hidden_size: Width of each hidden layer (ignored when depth is 0)
        depth: Number of hidden layers
        activation: Activation function applied after each hidden layer
        final_activation: Activation function applied to the output
        random_seed: Seed for the PRNG key used to initialise the weights
    """
    key = jax.random.PRNGKey(random_seed)
    # One key per linear layer so that each layer is initialised independently
    keys = jax.random.split(key, depth + 1)
    if depth == 0:
        # No hidden layers: the network reduces to a single linear map
        self.layers = [eqx.nn.Linear(in_size, out_size, key=keys[0])]
    else:
        self.layers = [
            eqx.nn.Linear(in_size, hidden_size, key=keys[0]),
            activation,
        ]
        for i in range(depth - 1):
            self.layers += [
                eqx.nn.Linear(hidden_size, hidden_size, key=keys[i + 1]),
                activation,
            ]
        self.layers += [eqx.nn.Linear(hidden_size, out_size, key=keys[depth])]

    self.layers += [final_activation]
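
Because the forward pass operates on a single unbatched input (the usual convention for Equinox modules), batched evaluation can be expressed with jax.vmap. A sketch under the same assumptions as the usage example above:

import jax
import jax.numpy as jnp

from scirex.core.dl.mlp import MLP  # import path assumed

model = MLP(in_size=3, out_size=1, hidden_size=64, depth=2)

xs = jnp.ones((128, 3))    # a batch of 128 input vectors
ys = jax.vmap(model)(xs)   # map the single-sample forward pass over the batch
print(ys.shape)            # expected: (128, 1)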