Fourier Neural Operator (FNO) - Training Tutorial - Burgers Equation
Mathematical Formulation
The 1D Burgers equation:
\[
\frac{\partial u}{\partial t} + u\frac{\partial u}{\partial x} = \nu\frac{\partial^2 u}{\partial x^2}
\]
where:
- \(u(x,t)\): velocity field
- \(\nu\): viscosity coefficient
- \(x\): spatial coordinate
- \(t\): time
Implementation Parameters
Domain Parameters
x_domain = [0, 2π]
spatial_resolution = 8192 # Full resolution
training_resolution = 256 # Subsampled for training
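The two resolutions above differ by a factor of 32, so the training grid can be obtained by keeping every 32nd point of the full grid. The following is a minimal sketch of that subsampling; the sine wave is a stand-in sample, not the tutorial's data, and the variable names are chosen here for illustration.

```python
import jax.numpy as jnp

full_resolution = 8192
training_resolution = 256
stride = full_resolution // training_resolution  # 8192 / 256 = 32

# Stand-in for one full-resolution sample on [0, 2π] (hypothetical data).
x_full = jnp.linspace(0, 2 * jnp.pi, full_resolution)
u_full = jnp.sin(x_full)

# Keep every 32nd grid point to reach the training resolution.
x_train = x_full[::stride]
u_train = u_full[::stride]
```

The same stride appears in the validation slice `test_x[..., ::32]` used in the training step.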
Dataset Parameters
FNO Architecture
fno = FNO1d(
    in_channels=2,   # Initial condition + spatial coordinate
    out_channels=1,  # Solution at t=1
    modes=16,        # Number of Fourier modes
    width=64,        # Channel width
    activation=jax.nn.relu,
    n_blocks=4,
)
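To make the role of `modes` concrete, here is a minimal sketch of the kind of spectral convolution an FNO block is built around: transform to Fourier space, mix channels on the lowest `modes` frequencies, zero the rest, and transform back. It is an illustration under stated assumptions, not the internals of `FNO1d`; the function name `spectral_conv_1d`, the weight layout, and the random initialization are all hypothetical.

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(weights, x, modes):
    """Apply a truncated Fourier-space channel mixing to one sample.

    x:       real array of shape (in_channels, n_points)
    weights: complex array of shape (in_channels, out_channels, modes)
    """
    n_points = x.shape[-1]
    x_hat = jnp.fft.rfft(x, axis=-1)  # (in_channels, n_points // 2 + 1)
    # Mix channels on the lowest `modes` frequencies only.
    out_low = jnp.einsum("im,iom->om", x_hat[:, :modes], weights)
    # Higher frequencies are left at zero.
    out_hat = jnp.zeros((weights.shape[1], x_hat.shape[-1]), dtype=x_hat.dtype)
    out_hat = out_hat.at[:, :modes].set(out_low)
    return jnp.fft.irfft(out_hat, n=n_points, axis=-1)


# Hypothetical sizes and random weights, for illustration only.
in_channels, out_channels, modes = 2, 3, 16
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
weights = (
    jax.random.normal(key_re, (in_channels, out_channels, modes))
    + 1j * jax.random.normal(key_im, (in_channels, out_channels, modes))
)

# The same weights apply to inputs sampled at different resolutions.
y_coarse = spectral_conv_1d(weights, jnp.ones((in_channels, 256)), modes)
y_fine = spectral_conv_1d(weights, jnp.ones((in_channels, 1024)), modes)
```

Because the weights are indexed by frequency rather than by grid point, their shape does not depend on the input resolution, which is what makes evaluation on a finer grid possible without retraining.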
Implementation Steps
1. Data Loading
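import jax.numpy as jnp
import scipy.io

data = scipy.io.loadmat("burgers_data_R10.mat")
a, u = data["a"], data["u"]  # Initial conditions and solutions at t=1
# Add a channel dimension: (n_samples, n_points) -> (n_samples, 1, n_points)
a = a[:, jnp.newaxis, :]
u = u[:, jnp.newaxis, :]
# Spatial mesh on [0, 2π] at the data resolution
mesh = jnp.linspace(0, 2 * jnp.pi, u.shape[-1])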
2. Data Preprocessing
# Combine initial condition with mesh information
mesh_shape_corrected = jnp.repeat(mesh[jnp.newaxis, jnp.newaxis, :], u.shape[0], axis=0)
a_with_mesh = jnp.concatenate((a, mesh_shape_corrected), axis=1)
# Train-test split
train_x, test_x = a_with_mesh[:1000], a_with_mesh[1000:1200]
train_y, test_y = u[:1000], u[1000:1200]
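As a shape check on the mesh concatenation, the sketch below repeats that step on small synthetic arrays. It uses `jnp.broadcast_to` in place of `jnp.repeat`, which is a swapped-in alternative that produces the same array; the sample count and the zero-filled input are placeholders, not the tutorial's data.

```python
import jax.numpy as jnp

n_samples, n_points = 4, 8192  # hypothetical sample count

# Placeholder initial conditions with a channel axis: (n_samples, 1, n_points)
a = jnp.zeros((n_samples, 1, n_points))
mesh = jnp.linspace(0, 2 * jnp.pi, n_points)

# Give every sample its own copy of the mesh as a second channel.
mesh_batched = jnp.broadcast_to(mesh, (n_samples, 1, n_points))
a_with_mesh = jnp.concatenate((a, mesh_batched), axis=1)

# Same result as the jnp.repeat formulation.
mesh_repeated = jnp.repeat(mesh[jnp.newaxis, jnp.newaxis, :], n_samples, axis=0)
```

The resulting channel axis has size 2, matching `in_channels=2` in the architecture.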
3. Training Loop
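import equinox as eqx

# `loss_fn`, `optimizer`, `test_x` and `test_y` are assumed to be defined at module level.
@eqx.filter_jit
def make_step(model, state, x, y):
    # Training loss and its gradient with respect to the model parameters
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    # Validation loss on the test set, subsampled from 8192 to 256 points
    val_loss = loss_fn(model, test_x[..., ::32], test_y[..., ::32])
    updates, new_state = optimizer.update(grad, state, model)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_state, loss, val_loss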
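The training step calls a `loss_fn` that is not shown in this section. One plausible form is a mean squared error over a batch, with the model applied per sample through `jax.vmap`; the sketch below assumes that form and also includes a hypothetical shuffled mini-batch iterator named `batches`. The stand-in model and data are placeholders used only to exercise the functions.

```python
import jax
import jax.numpy as jnp


def loss_fn(model, x, y):
    """Mean squared error (assumed loss); `model` maps one sample to one output."""
    pred = jax.vmap(model)(x)  # apply the model independently to each sample
    return jnp.mean((pred - y) ** 2)


def batches(key, x, y, batch_size):
    """Yield shuffled mini-batches covering the dataset once (hypothetical helper)."""
    perm = jax.random.permutation(key, x.shape[0])
    for start in range(0, x.shape[0], batch_size):
        idx = perm[start:start + batch_size]
        yield x[idx], y[idx]


def stand_in_model(sample):
    """Placeholder for the FNO: returns the first input channel unchanged."""
    return sample[:1]


x = jnp.ones((8, 2, 256))   # (n_samples, in_channels, n_points)
y = jnp.zeros((8, 1, 256))  # (n_samples, out_channels, n_points)

loss = loss_fn(stand_in_model, x, y)
n_batches = len(list(batches(jax.random.PRNGKey(0), x, y, batch_size=4)))
```

In an epoch loop, each `(x_batch, y_batch)` pair from such an iterator would be passed to `make_step`, with a fresh key per epoch so that the shuffling order changes.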
4. Evaluation Metrics
def relative_l2_norm(pred, ref):
    diff_norm = jnp.linalg.norm(pred - ref)
    ref_norm = jnp.linalg.norm(ref)
    return diff_norm / ref_norm
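When `relative_l2_norm` is called on a whole batch, it pools all samples into a single norm, so samples with large amplitude dominate the result. A per-sample variant, averaged afterwards, weights each sample equally. The sketch below contrasts the two on hypothetical data; which convention the tutorial's reported error uses is not stated in this section.

```python
import jax
import jax.numpy as jnp


def relative_l2_norm(pred, ref):
    return jnp.linalg.norm(pred - ref) / jnp.linalg.norm(ref)


# Two hypothetical reference samples with very different amplitudes.
ref = jnp.stack([jnp.ones((1, 256)), 100.0 * jnp.ones((1, 256))])
pred = ref + 1.0  # hypothetical predictions, each off by 1.0 everywhere

pooled = relative_l2_norm(pred, ref)                # one norm over the whole batch
per_sample = jax.vmap(relative_l2_norm)(pred, ref)  # one error per sample
mean_per_sample = jnp.mean(per_sample)
```

Here the per-sample errors are about 1.0 and 0.01, while the pooled value is close to the error of the large-amplitude sample alone.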
Results Analysis
The implementation generates five visualizations:
- initial_vs_after.png: Initial condition vs solution at t=1
- loss.png: Training and validation loss curves
- prediction.png: Model prediction vs ground truth
- difference.png: Error analysis
- superresolution.png: Zero-shot superresolution capability
Key Performance Metrics
- Relative L2 error: ~1e-2
- Training duration: 200 epochs
- Resolution invariance demonstrated through superresolution
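One way to quantify the resolution invariance mentioned above is to evaluate the same model on the test data at several subsampling strides and compare the relative L2 errors. The following is a sketch under stated assumptions: `errors_by_resolution` is a hypothetical helper, and the stand-in model and sine-wave data are placeholders for the trained FNO and the Burgers test set.

```python
import jax
import jax.numpy as jnp


def relative_l2_norm(pred, ref):
    return jnp.linalg.norm(pred - ref) / jnp.linalg.norm(ref)


def errors_by_resolution(model, x, y, strides):
    """Map each evaluated resolution to the relative L2 error at that resolution."""
    errors = {}
    for stride in strides:
        x_sub, y_sub = x[..., ::stride], y[..., ::stride]
        pred = jax.vmap(model)(x_sub)  # model applied per sample
        errors[x_sub.shape[-1]] = float(relative_l2_norm(pred, y_sub))
    return errors


def stand_in_model(sample):
    """Placeholder for the trained FNO: returns the first input channel."""
    return sample[:1]


# Hypothetical data in which channel 0 of x already equals the target y.
n_samples, n_points = 4, 1024
mesh = jnp.linspace(0, 2 * jnp.pi, n_points)
y = jnp.broadcast_to(jnp.sin(mesh), (n_samples, 1, n_points))
x = jnp.concatenate((y, jnp.broadcast_to(mesh, (n_samples, 1, n_points))), axis=1)

errors = errors_by_resolution(stand_in_model, x, y, strides=(4, 2, 1))
```

With a trained model, errors that stay close to the training-resolution value at finer grids would support the zero-shot superresolution claim.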