
Fourier Neural Operator (FNO) - Wave Equation Tutorial

Mathematical Formulation

The wave equation:

\[ \frac{\partial^2 u}{\partial t^2} = c^2\frac{\partial^2 u}{\partial x^2} \]

where:
  • \(u(x,t)\): displacement field
  • \(c\): wave speed
  • \(x\): spatial coordinate
  • \(t\): time
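As a quick sanity check of the formulation, any traveling profile \(u(x,t) = f(x - ct)\) satisfies the equation (d'Alembert's solution). The sketch below checks this with JAX automatic differentiation for an assumed profile \(f = \sin\). The helpers `u`, `u_tt`, `u_xx` and `residual` are illustrative only and are not part of the tutorial code.

```python
import jax
import jax.numpy as jnp

c = 1.0  # wave speed

def u(x, t):
    # Right-moving traveling wave u(x, t) = f(x - c t), with f = sin (assumed profile)
    return jnp.sin(x - c * t)

# Second derivatives via nested jax.grad; argnums selects the variable
u_tt = jax.grad(jax.grad(u, argnums=1), argnums=1)
u_xx = jax.grad(jax.grad(u, argnums=0), argnums=0)

def residual(x, t):
    # u_tt - c^2 u_xx vanishes for every solution of the wave equation
    return u_tt(x, t) - c**2 * u_xx(x, t)

xs = jnp.linspace(0.0, 2 * jnp.pi, 8)
ts = jnp.full_like(xs, 0.7)
res = jax.vmap(residual)(xs, ts)
print(jnp.max(jnp.abs(res)))  # essentially zero (floating-point round-off only)
```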

Parameters

Domain Parameters

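import jax.numpy as jnp

nx = 64  # Spatial points
nt = 100  # Time steps
L = 2 * jnp.pi  # Domain length (2π)
T = 2.0  # Final time
c = 1.0  # Wave speed
dx = L / nx  # Grid spacing on the periodic grid (assumed definition; used by the solver)
dt = T / nt  # Time step (assumed definition; used by the solver)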
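The explicit central-difference scheme used by the solver is only stable when the Courant number \(C = c\,\Delta t/\Delta x\) does not exceed 1. A minimal check of the parameters above, assuming a periodic grid without the duplicated endpoint (\(\Delta x = L/n_x\)) and a uniform step \(\Delta t = T/n_t\):

```python
import math

nx, nt = 64, 100
L, T, c = 2 * math.pi, 2.0, 1.0

dx = L / nx   # periodic grid: x_j = j * dx, j = 0..nx-1 (assumed)
dt = T / nt   # uniform time step (assumed)

courant = c * dt / dx
print(f"dx = {dx:.4f}, dt = {dt:.4f}, Courant number = {courant:.3f}")
# dx = 0.0982, dt = 0.0200, Courant number = 0.204
assert courant <= 1.0, "CFL condition violated: reduce dt or increase dx"
```

With these settings the scheme runs at about a fifth of the stability limit, so `nt` could be reduced by roughly a factor of 4 before the condition is violated.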

Architecture Parameters

model = FNO1d(
    in_channels=3,  # Initial displacement + velocity + coordinate
    out_channels=1,  # Final displacement
    modes=16,
    width=64,
    activation=jax.nn.gelu,
    n_blocks=4
)
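The `modes` and `width` arguments are easiest to read off a single Fourier layer. The hidden state is transformed with an FFT, only the lowest `modes` frequencies are kept and mixed across the `width` channels by learned complex weights, and the result is transformed back. The sketch below illustrates that idea in plain JAX under an assumed channels-first layout `(width, nx)`. `spectral_conv1d` is a hypothetical helper, not the `FNO1d` implementation used in this tutorial.

```python
import jax
import jax.numpy as jnp

def spectral_conv1d(v, weights, modes):
    """One Fourier layer. v: (width, nx) real; weights: (width, width, modes) complex."""
    nx = v.shape[-1]
    v_hat = jnp.fft.rfft(v, axis=-1)                        # (width, nx // 2 + 1)
    # Mix channels on the lowest `modes` frequencies; all higher modes are truncated to zero
    out_low = jnp.einsum("iok,ik->ok", weights, v_hat[:, :modes])
    out_hat = jnp.zeros((weights.shape[1], v_hat.shape[-1]), dtype=v_hat.dtype)
    out_hat = out_hat.at[:, :modes].set(out_low)
    return jnp.fft.irfft(out_hat, n=nx, axis=-1)            # back to (width, nx), real-valued

width, modes, nx = 64, 16, 64
k_w, k_v = jax.random.split(jax.random.PRNGKey(0))
scale = 1.0 / (width * width)
weights = scale * (jax.random.normal(k_w, (width, width, modes))
                   + 1j * jax.random.normal(jax.random.fold_in(k_w, 1), (width, width, modes)))
v = jax.random.normal(k_v, (width, nx))
out = spectral_conv1d(v, weights, modes)
print(out.shape)  # (64, 64)
```

In a full FNO block this spectral term is typically added to a pointwise linear map of `v` before the activation (GELU here), and `n_blocks=4` such blocks are stacked between a lifting layer (3 input channels to `width`) and a projection layer (`width` to 1 output channel).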

Implementation

Initial Conditions Generator

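import jax
import jax.numpy as jnp

def generate_initial_conditions(key):
    """Sample a random initial displacement u0 and velocity v0 on the periodic grid."""
    k1, k2, k3, k4 = jax.random.split(key, 4)
    x = jnp.linspace(0.0, L, nx, endpoint=False)

    # Initial displacement: sum of Gaussian pulses (exact Gaussian form assumed)
    max_pulses = 2
    positions = jax.random.uniform(k1, (max_pulses,)) * L
    widths = jax.random.uniform(k2, (max_pulses,)) * 0.2 + 0.1
    amplitudes = jax.random.uniform(k3, (max_pulses,)) * 0.8 + 0.2
    u0 = jnp.sum(
        amplitudes[:, None]
        * jnp.exp(-((x[None, :] - positions[:, None]) ** 2) / (2 * widths[:, None] ** 2)),
        axis=0,
    )

    # Initial velocity: white noise smoothed by a 10-point moving average
    v0 = jax.random.normal(k4, (nx,)) * 0.1
    v0 = jnp.convolve(v0, jnp.ones(10) / 10, mode='same')

    return u0, v0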
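`FNO1d` is configured with `in_channels=3`, so each sample presumably stacks the initial displacement, the initial velocity and the grid coordinate. A minimal sketch of that stacking, assuming a channels-first layout `(3, nx)` and a periodic grid \(x_j = jL/n_x\). Both the layout and the helper name `make_fno_input` are assumptions, so check them against the `FNO1d` definition.

```python
import jax
import jax.numpy as jnp

nx = 64
L = 2 * jnp.pi

def make_fno_input(u0, v0):
    """Stack (u0, v0, x) into a (3, nx) array: displacement, velocity, coordinate channels."""
    x = jnp.linspace(0.0, L, nx, endpoint=False)   # periodic grid without the duplicate endpoint
    return jnp.stack([u0, v0, x], axis=0)

# Batched version: (batch, nx) and (batch, nx) -> (batch, 3, nx)
make_fno_input_batch = jax.vmap(make_fno_input)

u0 = jnp.zeros((8, nx))
v0 = jnp.ones((8, nx))
inputs = make_fno_input_batch(u0, v0)
print(inputs.shape)  # (8, 3, 64)
```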

Wave Equation Solver

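def solve_wave_equation(init_conditions):
    """Central-difference (leapfrog) scheme on a periodic grid; returns u with shape (nt, nx)."""
    u0, v0 = init_conditions
    r2 = (c * dt / dx) ** 2  # squared Courant number

    def lap(v):
        # Periodic second difference; jnp.roll wraps around the domain boundary
        return jnp.roll(v, 1) - 2 * v + jnp.roll(v, -1)

    u = jnp.zeros((nt, nx))
    u = u.at[0].set(u0)
    # Taylor-expanded first step u0 + dt*v0 + (c*dt)^2/2 * u_xx keeps second-order accuracy in time
    u = u.at[1].set(u0 + dt * v0 + 0.5 * r2 * lap(u0))

    for n in range(2, nt):
        u = u.at[n].set(2 * u[n-1] - u[n-2] + r2 * lap(u[n-1]))
    return u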
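The Python loop in the solver is fully unrolled when traced under `jit`. An equivalent leapfrog update can be written with `jax.lax.scan`, which also composes with `vmap` for solving a batch of initial conditions at once. The sketch below uses the same periodic stencil and Taylor-expanded first step; the names `laplacian`, `leapfrog_solve` and `max_error` are hypothetical. It also tests the second-order claim against the exact solution \(u = \sin x \cos ct\) for \(u_0 = \sin x\), \(v_0 = 0\): halving \(\Delta x\) and \(\Delta t\) together should cut the error by roughly a factor of 4.

```python
import jax
import jax.numpy as jnp

def laplacian(u, dx):
    # Periodic second difference (u[j+1] - 2 u[j] + u[j-1]) / dx^2, wrapping with jnp.roll
    return (jnp.roll(u, 1) - 2 * u + jnp.roll(u, -1)) / dx**2

def leapfrog_solve(u0, v0, c, dx, dt, n_steps):
    """Displacement after n_steps (>= 1) steps of the central-difference scheme."""
    # Taylor-expanded first step: u(dt) ~ u0 + dt*v0 + (c*dt)^2/2 * u_xx(0)
    u1 = u0 + dt * v0 + 0.5 * (c * dt) ** 2 * laplacian(u0, dx)

    def step(carry, _):
        u_prev, u_curr = carry
        u_next = 2 * u_curr - u_prev + (c * dt) ** 2 * laplacian(u_curr, dx)
        return (u_curr, u_next), None

    (_, u_final), _ = jax.lax.scan(step, (u0, u1), None, length=n_steps - 1)
    return u_final

def max_error(nx, c=1.0, courant=0.5):
    """Max-norm error against the exact solution sin(x) cos(c t) for u0 = sin(x), v0 = 0."""
    dx = 2 * jnp.pi / nx
    dt = courant * dx / c
    n_steps = nx // 2                      # final time n_steps * dt = pi / 2 for c = 1, courant = 0.5
    x = jnp.arange(nx) * dx
    u_num = leapfrog_solve(jnp.sin(x), jnp.zeros(nx), c, dx, dt, n_steps)
    u_exact = jnp.sin(x) * jnp.cos(c * n_steps * dt)
    return jnp.max(jnp.abs(u_num - u_exact))

# Halving dx and dt should reduce the error about 4x for a second-order scheme
ratio = max_error(32) / max_error(64)
print(ratio)  # roughly 4

# Batched solves over several initial conditions, e.g. when building a training set
nx, c = 64, 1.0
dx = 2 * jnp.pi / nx
dt = 0.5 * dx / c
batched_solve = jax.vmap(lambda u0, v0: leapfrog_solve(u0, v0, c, dx, dt, 100))
u_final = batched_solve(jnp.zeros((8, nx)), jnp.zeros((8, nx)))
print(u_final.shape)  # (8, 64)
```

Replacing the Taylor-expanded first step with the plain `u0 + dt * v0` makes this ratio drop toward 2, which is why the startup step matters for the second-order claim.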

Training Loop

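import equinox as eqx
import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(1e-3)  # assumed; init with optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def make_step(model, opt_state, batch):
    """One optimization step on a batch of (inputs, targets) pairs."""
    inputs, targets = batch

    def loss_fn(model):
        pred = jax.vmap(model)(inputs)          # apply the FNO to every sample
        return jnp.mean((pred - targets) ** 2)  # MSE loss

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state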
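`make_step` takes `batch` as an `(inputs, targets)` pair. One way to feed it is to reshuffle the dataset every epoch and slice it into fixed-size minibatches. The generator below is a hypothetical helper; its name, the batch size and the policy of dropping the last incomplete batch are assumptions, as is the target shape `(1, nx)` matching `out_channels=1`.

```python
import jax
import jax.numpy as jnp

def iterate_minibatches(key, inputs, targets, batch_size):
    """Yield shuffled (inputs, targets) minibatches; the last incomplete batch is dropped."""
    n = inputs.shape[0]
    perm = jax.random.permutation(key, n)
    for start in range(0, n - batch_size + 1, batch_size):
        idx = perm[start:start + batch_size]
        yield inputs[idx], targets[idx]

# Dummy data: 100 samples, 3 input channels, 1 output channel, 64 grid points
key = jax.random.PRNGKey(0)
inputs = jnp.zeros((100, 3, 64))
targets = jnp.zeros((100, 1, 64))
batches = list(iterate_minibatches(key, inputs, targets, batch_size=32))
print(len(batches), batches[0][0].shape)  # 3 (32, 3, 64)

# Typical use inside the training loop:
# for epoch in range(n_epochs):
#     key, subkey = jax.random.split(key)
#     for batch in iterate_minibatches(subkey, train_inputs, train_targets, 32):
#         loss, model, opt_state = make_step(model, opt_state, batch)
```

Keeping the batch size fixed means `eqx.filter_jit` compiles `make_step` once instead of retracing for a smaller final batch.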

Output Visualization

outputs/fno/wave/
├── example_prediction.png  # Initial conditions, true solution, prediction
├── training_loss.png      # MSE loss evolution
└── absolute_error.png     # Point-wise absolute error

Key Metrics

  • MSE loss tracked during training
  • Error evaluation on a held-out test set
  • Point-wise absolute error analysis
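The metrics above can be computed from batched predictions and reference solutions as sketched below. The relative \(L^2\) error is included as a common scale-independent test metric for neural operators; it is an illustrative addition, not something this tutorial states it reports, and the function name `evaluate` is hypothetical.

```python
import jax.numpy as jnp

def evaluate(pred, true):
    """pred, true: (batch, ...) arrays. Returns MSE, mean relative L2 error, point-wise |error|."""
    diff = pred - true
    mse = jnp.mean(diff ** 2)                         # same quantity as the training loss
    d = diff.reshape(diff.shape[0], -1)               # flatten each sample
    t = true.reshape(true.shape[0], -1)
    rel_l2 = jnp.linalg.norm(d, axis=1) / jnp.linalg.norm(t, axis=1)
    abs_err = jnp.abs(diff)                           # point-wise absolute error, e.g. for plotting
    return mse, jnp.mean(rel_l2), abs_err

true = jnp.ones((2, 1, 64))
pred = true + 0.1
mse, rel_l2, abs_err = evaluate(pred, true)
print(float(mse), float(rel_l2))  # approximately 0.01 and 0.1
```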

Features

  • Handles coupled initial conditions (displacement and velocity)
  • Three-channel input: displacement, velocity, and spatial coordinate
  • Periodic boundary conditions (implemented with jnp.roll)
  • Second-order accuracy in time and space (central differences)