FNO block

Module: fno_block.py

This module implements a single block of the Fourier Neural Operator (FNO) model.

Classes:

Name        Description
FNOBlock1d  Single block of the FNO model

Dependencies

  • jax: For array processing
  • equinox: For neural network layers

Key Features

  • Spectral convolution
  • Bypass convolution
  • Activation function

Authors

Diya Nag Chaudhury

Version Info

29/Dec/2024: Initial version - Diya Nag Chaudhury

FNOBlock1d

Bases: Module

A single block of the FNO model.

This block applies a spectral convolution and a bypass convolution (kernel size 1) to the same input, sums the two outputs, and passes the sum through an activation function.
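The computation in `__call__` can be written as a single formula. The symbols below are notation introduced here, not taken from the source: $\sigma$ is the activation, $\mathcal{K}$ the spectral convolution, and $W$, $b$ the weight and bias of the kernel-size-1 bypass convolution. The second line assumes `SpectralConv1d` follows the usual FNO construction (transform, keep the lowest `modes` frequencies, multiply by learned complex weights $R$, transform back); its source is not shown on this page.

```latex
y = \sigma\bigl(\mathcal{K}(x) + W x + b\bigr),
\qquad
\mathcal{K}(x) = \mathcal{F}^{-1}\bigl(R \cdot \mathcal{F}(x)\big|_{k < \text{modes}}\bigr)
```

Because the bypass kernel has size one, $W x + b$ is the same linear map across channels applied independently at every grid point.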

Attributes:

  • spectral_conv: SpectralConv1d
  • bypass_conv: eqx.nn.Conv1d
  • activation: Callable

Methods:

  • __init__: Initializes the FNOBlock1d object
  • __call__: Applies the block to an input array

Source code in scirex/core/sciml/fno/layers/fno_block.py
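from typing import Callable

import equinox as eqx
import jax

# SpectralConv1d is defined elsewhere in the scirex FNO layers;
# its import is not shown in this listing.


class FNOBlock1d(eqx.Module):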
    """
    A single block of the FNO model.

    This block applies a spectral convolution and a bypass (kernel-size-1)
    convolution to the same input, sums the two outputs, and applies an
    activation function to the sum.

    Attributes:
        spectral_conv: SpectralConv1d
        bypass_conv: eqx.nn.Conv1d
        activation: Callable

    Methods:
        __init__: Initializes the FNOBlock1d object
        __call__: Calls the FNOBlock1d object
    """

    spectral_conv: SpectralConv1d
    bypass_conv: eqx.nn.Conv1d
    activation: Callable

    def __init__(
        self,
        in_channels,
        out_channels,
        modes,
        activation,
        *,
        key,
    ):
        spectral_conv_key, bypass_conv_key = jax.random.split(key)
        self.spectral_conv = SpectralConv1d(
            in_channels,
            out_channels,
            modes,
            key=spectral_conv_key,
        )
        self.bypass_conv = eqx.nn.Conv1d(
            in_channels,
            out_channels,
            1,  # Kernel size is one
            key=bypass_conv_key,
        )
        self.activation = activation

    def __call__(
        self,
        x,
    ):
        # Both paths see the same input; their outputs are summed
        # before the activation is applied.
        return self.activation(self.spectral_conv(x) + self.bypass_conv(x))
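
To make the data flow concrete, here is a minimal pure-JAX sketch of the same forward pass, written with plain functions instead of Equinox modules. It is an illustration under assumptions, not the scirex implementation: `spectral_conv_1d` is a hypothetical stand-in for `SpectralConv1d` (whose source is not shown on this page) that follows the usual FNO recipe, and the `params` dictionary, its keys, and the random initialisation are invented for the example. Inputs are unbatched arrays of shape `(channels, n_points)`.

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(x, weights, modes):
    """Hypothetical stand-in for SpectralConv1d (assumed, not the scirex code).

    x:       real array of shape (in_channels, n_points)
    weights: complex array of shape (in_channels, out_channels, modes)
    """
    n_points = x.shape[-1]
    x_hat = jnp.fft.rfft(x)                     # (in_channels, n_points // 2 + 1)
    low = x_hat[:, :modes]                      # keep the lowest `modes` frequencies
    mixed = jnp.einsum("im,iom->om", low, weights)  # mix channels per frequency
    out_channels = weights.shape[1]
    out_hat = jnp.zeros((out_channels, x_hat.shape[-1]), dtype=x_hat.dtype)
    out_hat = out_hat.at[:, :modes].set(mixed)  # higher frequencies stay zero
    return jnp.fft.irfft(out_hat, n=n_points)   # (out_channels, n_points)


def bypass_conv_1d(x, weight, bias):
    """Kernel-size-1 convolution: one linear map applied at every grid point."""
    return weight @ x + bias[:, None]


def fno_block_1d(x, params, modes, activation):
    """activation(spectral_conv(x) + bypass_conv(x)), mirroring __call__."""
    spectral = spectral_conv_1d(x, params["spectral"], modes)
    bypass = bypass_conv_1d(x, params["weight"], params["bias"])
    return activation(spectral + bypass)


# Illustrative sizes and random parameters (all assumptions for this sketch).
in_channels, out_channels, modes, n_points = 3, 5, 4, 32
k_re, k_im, k_w, k_x = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "spectral": jax.random.normal(k_re, (in_channels, out_channels, modes))
    + 1j * jax.random.normal(k_im, (in_channels, out_channels, modes)),
    "weight": jax.random.normal(k_w, (out_channels, in_channels)),
    "bias": jnp.zeros(out_channels),
}

x = jax.random.normal(k_x, (in_channels, n_points))
y = fno_block_1d(x, params, modes, jax.nn.relu)

# The block works on one sample at a time; map over a leading batch axis.
batched_block = jax.vmap(lambda xi: fno_block_1d(xi, params, modes, jax.nn.relu))
y_batch = batched_block(jnp.stack([x, 2.0 * x]))
```

The spectral path mixes information across the whole grid through its low-frequency content, while the bypass path acts pointwise, so the sum combines a global and a local linear map before the nonlinearity. Note that `modes` must not exceed `n_points // 2 + 1`, the number of frequencies `rfft` returns.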
