Fourier Neural Operator (FNO) - Training Tutorial - 1D Advection Problem
Overview
This tutorial demonstrates implementing a Fourier Neural Operator (FNO) to solve the 1D advection equation using JAX and Equinox.
Mathematical Formulation
The 1D advection equation:
\[
\frac{\partial u}{\partial t} + v\frac{\partial u}{\partial x} = 0
\]
where:
- \(u(x,t)\): solution (wave amplitude)
- \(v\): wave velocity
- \(x\): spatial coordinate
- \(t\): time
Initial and Boundary Conditions
- Initial condition: \(u(x,0) = u_0(x)\)
- Periodic boundary conditions: \(u(0,t) = u(2\pi,t)\)
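Since the equation only transports the initial profile, the exact solution is \(u(x,t) = u_0(x - vt)\), wrapped around the periodic domain. The sketch below evaluates it with the Fourier shift theorem, which can serve as a reference when checking generated data. The function name `exact_advection` and the sample initial condition are illustrative assumptions, not part of the tutorial code.

```python
import jax.numpy as jnp

def exact_advection(u0, v, t, L=2 * jnp.pi):
    """Shift u0 by v*t on a periodic grid using the Fourier shift theorem."""
    nx = u0.shape[0]
    # Angular wavenumbers for a periodic domain of length L
    k = 2 * jnp.pi * jnp.fft.rfftfreq(nx, d=L / nx)
    u_hat = jnp.fft.rfft(u0)
    # u(x, t) = u0(x - v t)  <=>  each mode is multiplied by exp(-i k v t)
    return jnp.fft.irfft(u_hat * jnp.exp(-1j * k * v * t), n=nx)

# Illustrative initial condition on a periodic grid (endpoint excluded)
x = jnp.linspace(0, 2 * jnp.pi, 64, endpoint=False)
u0 = jnp.sin(x) + 0.5 * jnp.sin(3 * x)
u_T = exact_advection(u0, v=1.0, t=2.0)
```

This is exact only for band-limited profiles sampled on a uniform periodic grid, which holds for sums of a few integer-frequency sinusoids.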
Implementation Parameters
Domain Parameters
nx = 64 # Spatial resolution
nt = 100 # Temporal resolution
L = 2 * jnp.pi # Domain length
T = 2.0 # Final time
v = 1.0 # Wave velocity
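The data generation code later in this tutorial uses an explicit upwind scheme, which is stable only when the Courant number \(|v|\,\Delta t / \Delta x\) does not exceed 1. A quick check for the parameters above, assuming \(\Delta x = L/n_x\) and \(\Delta t = T/n_t\) as in the data generation function:

```python
import jax.numpy as jnp

nx, nt = 64, 100
L, T, v = 2 * jnp.pi, 2.0, 1.0

dx = L / nx              # grid spacing on the periodic domain
dt = T / nt              # time step
cfl = abs(v) * dt / dx   # Courant number; explicit upwind needs cfl <= 1
print(f"dx={dx:.4f}, dt={dt:.4f}, CFL={cfl:.3f}")
```

With these values the Courant number is about 0.204, comfortably inside the stability limit.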
Dataset Parameters
FNO Architecture
model = FNO1d(
    in_channels=2,
    out_channels=1,
    modes=16,
    width=64,
    activation=jax.nn.gelu,
    n_blocks=4
)
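The internals of `FNO1d` are not shown here. As a rough illustration of what the `modes` and `width` arguments control, the sketch below implements the spectral convolution at the core of a standard FNO block in plain JAX. The function name `spectral_conv1d`, the weight layout, and the initialisation scale are assumptions made for this example and may differ from the tutorial's actual class.

```python
import jax
import jax.numpy as jnp

def spectral_conv1d(weights, v):
    """Apply a Fourier-space linear map to the lowest `modes` frequencies.

    weights: complex array, shape (in_channels, out_channels, modes)
    v:       real array, shape (in_channels, nx)
    """
    in_ch, out_ch, modes = weights.shape
    nx = v.shape[-1]
    v_hat = jnp.fft.rfft(v, axis=-1)                  # (in_ch, nx//2 + 1)
    # Mix channels mode by mode; frequencies above `modes` are dropped
    out_low = jnp.einsum("im,iom->om", v_hat[:, :modes], weights)
    out_hat = jnp.zeros((out_ch, nx // 2 + 1), dtype=v_hat.dtype)
    out_hat = out_hat.at[:, :modes].set(out_low)
    return jnp.fft.irfft(out_hat, n=nx, axis=-1)      # (out_ch, nx)

# Shapes matching width=64, modes=16, nx=64 (initialisation scale is an assumption)
key_r, key_i = jax.random.split(jax.random.PRNGKey(0))
width, modes, nx = 64, 16, 64
scale = 1.0 / (width * width)
weights = scale * (jax.random.normal(key_r, (width, width, modes))
                   + 1j * jax.random.normal(key_i, (width, width, modes)))
out = spectral_conv1d(weights, jnp.ones((width, nx)))
```

In a full block this output would typically be added to a pointwise linear transform of the input and passed through the activation.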
Code Components
Data Generation
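def generate_advection_data(n_samples=1200, nx=64, nt=100, v=1.0):
    """Generate data for the 1D advection equation"""
    # Domain setup (periodic, so the endpoint x = L is excluded to keep spacing dx)
    L = 2 * jnp.pi
    dx = L / nx
    x = jnp.linspace(0, L, nx, endpoint=False)
    # Time domain (nt snapshots spaced dt apart, starting at t = 0)
    T = 2.0
    dt = T / nt
    t = jnp.linspace(0, T, nt, endpoint=False)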
Initial Condition Generator
def generate_initial_condition(key):
    """Generate random sinusoidal initial conditions"""
    k1, k2 = jax.random.split(key)
    n_waves = 3
    amplitudes = jax.random.uniform(k1, (n_waves,), minval=0.1, maxval=1.0)
    frequencies = jax.random.randint(k2, (n_waves,), minval=1, maxval=4)
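The excerpt above stops before the waves are combined. One plausible completion, assuming the components are summed as \(\sum_i a_i \sin(k_i x)\), is sketched below. Passing the grid `x` as an argument and the `n_waves` keyword are choices made for this example, not something the original code confirms.

```python
import jax
import jax.numpy as jnp

def generate_initial_condition(key, x, n_waves=3):
    """Random superposition of sine waves on the periodic grid x."""
    k1, k2 = jax.random.split(key)
    amplitudes = jax.random.uniform(k1, (n_waves,), minval=0.1, maxval=1.0)
    # randint's maxval is exclusive, so frequencies are drawn from {1, 2, 3}
    frequencies = jax.random.randint(k2, (n_waves,), minval=1, maxval=4)
    # Assumption: plain sum of a_i * sin(k_i * x); integer k_i keeps u0 periodic
    waves = amplitudes[:, None] * jnp.sin(frequencies[:, None] * x[None, :])
    return waves.sum(axis=0)

x = jnp.linspace(0, 2 * jnp.pi, 64, endpoint=False)
keys = jax.random.split(jax.random.PRNGKey(0), 8)
u0_batch = jax.vmap(lambda k: generate_initial_condition(k, x))(keys)  # (8, 64)
```

Integer frequencies matter here: they guarantee that every sample satisfies the periodic boundary condition on \([0, 2\pi)\).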
Advection Solver
def solve_advection(u0):
    """Solve using upwind scheme"""
    u = jnp.zeros((nt, nx))
    u = u.at[0].set(u0)
    for n in range(1, nt):
        if v > 0:
            u = u.at[n].set(
                u[n-1] - v * dt/dx * (u[n-1] - jnp.roll(u[n-1], 1))
            )
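The excerpt shows only the `v > 0` branch and relies on `nt`, `nx`, `v`, `dt`, and `dx` from an enclosing scope. A self-contained variant, sketched here with `jax.lax.scan` instead of a Python loop and with an assumed forward-difference branch for negative velocity, could look like the following. The explicit argument list is a choice for this example.

```python
import jax
import jax.numpy as jnp

def solve_advection(u0, v, dt, dx, nt):
    """First-order upwind scheme, periodic boundaries; returns shape (nt, nx)."""
    c = v * dt / dx  # Courant number; v is a Python float, so the branch is static

    def step(u, _):
        if v > 0:
            # Information travels to the right: backward difference
            u_next = u - c * (u - jnp.roll(u, 1))
        else:
            # Information travels to the left: forward difference (assumed branch)
            u_next = u - c * (jnp.roll(u, -1) - u)
        return u_next, u_next

    # nt - 1 steps after the initial snapshot
    _, later = jax.lax.scan(step, u0, None, length=nt - 1)
    return jnp.concatenate([u0[None, :], later], axis=0)

nx, nt = 64, 100
L, T, v = 2 * jnp.pi, 2.0, 1.0
x = jnp.linspace(0, L, nx, endpoint=False)
u = solve_advection(jnp.sin(x), v, dt=T / nt, dx=L / nx, nt=nt)
```

First-order upwind is numerically diffusive, so the stored "true" solutions will be slightly damped relative to the exact shifted profile.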
Training Loop
@eqx.filter_jit
def make_step(model, opt_state, batch):
    """Single training step"""
    def loss_fn(model):
        pred = jax.vmap(model)(batch[0])
        return jnp.mean((pred - batch[1])**2)
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state
Results Visualization
# Example prediction
plt.figure()
plt.plot(spatial_grid, test_x[0, 0], label='Initial condition')
plt.plot(spatial_grid, test_y[0, 0], label='True solution')
plt.plot(spatial_grid, test_pred[0, 0], '--', label='FNO prediction')
plt.xlabel('x')
plt.legend()  # without this call the labels above are never drawn
# Training loss
plt.figure()
plt.semilogy(losses)
plt.ylabel('MSE loss')
# Error analysis
plt.figure()
plt.plot(spatial_grid, jnp.abs(test_pred[0, 0] - test_y[0, 0]))
plt.xlabel('x')
plt.ylabel('Absolute error')
plt.show()
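The pointwise error plot covers a single sample. To summarise accuracy over a whole test set, a per-sample relative L2 error is a common choice for operator-learning models. The helper name `relative_l2_error` and the placeholder arrays below are illustrative. In practice `test_pred` and `test_y` would be passed in.

```python
import jax.numpy as jnp

def relative_l2_error(pred, target):
    """Per-sample relative L2 error; inputs have shape (n_samples, channels, nx)."""
    n = pred.shape[0]
    diff = jnp.linalg.norm((pred - target).reshape(n, -1), axis=1)
    ref = jnp.linalg.norm(target.reshape(n, -1), axis=1)
    return diff / ref

# Placeholder arrays standing in for test_pred and test_y
x = jnp.linspace(0, 2 * jnp.pi, 64, endpoint=False)
target = jnp.sin(x)[None, None, :]
pred = 1.1 * target
err = relative_l2_error(pred, target)
```

Averaging `err` over the test set gives a single scale-independent number that is easy to track between runs.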