PINN Solver

Overview

The PINN (Physics-Informed Neural Network) solver provides an alternative to the finite-difference level set method. Instead of discretizing the PDE on a grid, a neural network \(\varphi_\theta(x, y, t)\) is trained to satisfy the Hamilton-Jacobi fire spread equation:

\[ \frac{\partial \varphi}{\partial t} + F(x,y,t)\lvert\nabla \varphi\rvert = 0 \]

where \(F\) is the spread rate (m/min in the examples), supplied either by a FireSpreadModel or by any callable mapping \((t, x, y)\) to a real number.
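As a sanity check on the equation itself (not part of the solver), a constant spread rate admits an exact solution: the expanding circle \(\varphi = \sqrt{(x-x_0)^2 + (y-y_0)^2} - r_0 - F t\) has \(\lvert\nabla\varphi\rvert = 1\) and \(\partial_t\varphi = -F\), so the residual vanishes away from the center. The sketch below checks this with central differences; `circle_phi` and `hj_residual` are illustrative names, not Wildfires functions.

```julia
# Exact level set for a circular fire of initial radius r0 centered at (x0, y0),
# growing at constant spread rate F (illustrative helper, not Wildfires API).
circle_phi(x, y, t; x0=500.0, y0=500.0, r0=80.0, F=5.0) =
    hypot(x - x0, y - y0) - r0 - F * t

# Hamilton-Jacobi residual phi_t + F * |grad phi|, via central differences.
function hj_residual(phi, F, x, y, t; h=1e-4)
    phi_t = (phi(x, y, t + h) - phi(x, y, t - h)) / (2h)
    phi_x = (phi(x + h, y, t) - phi(x - h, y, t)) / (2h)
    phi_y = (phi(x, y + h, t) - phi(x, y - h, t)) / (2h)
    return phi_t + F * hypot(phi_x, phi_y)
end

r = hj_residual(circle_phi, 5.0, 650.0, 520.0, 3.0)
println(r)  # close to zero: only finite-difference and rounding error remain
```

A PINN minimizes the same residual, but with derivatives of the network obtained by automatic differentiation instead of finite differences.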

Hard Initial Condition Constraint

The PINN uses a solution decomposition that exactly satisfies the initial condition:

\[ \tilde\varphi(x,y,t) = \frac{\text{IC}(x,y)}{L} + \tau(t) \cdot \text{NN}_\theta(x_n, y_n, t_n) \]

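where \(L\) is a length scale that normalizes the initial condition \(\text{IC}(x,y)\), \((x_n, y_n, t_n)\) are the inputs rescaled for the network, and \(\tau(t) = (t - t_{\min}) / (t_{\max} - t_{\min})\) vanishes at \(t = t_{\min}\). This guarantees \(\tilde\varphi(x,y,t_{\min}) = \text{IC}(x,y)/L\) exactly, so no IC loss term is needed and all training effort goes into learning the PDE dynamics.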
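A minimal sketch of why the decomposition pins the initial condition. It assumes a stand-in function in place of the trained network, inputs rescaled to \([-1, 1]\), a 1000 m square domain, and \(L = 1000\); all of these are illustrative assumptions, and the package's actual normalization may differ. Whatever the network returns, the \(\tau(t)\) factor removes it at \(t_{\min}\).

```julia
# Hypothetical stand-in for the trained network NN_θ; any function works here.
fake_nn(xn, yn, tn) = 3.0 * sin(xn) + yn^2 - 7.0 * tn

# Rescale v from [lo, hi] to [-1, 1] (an assumed normalization).
rescale(v, lo, hi) = 2 * (v - lo) / (hi - lo) - 1

# Signed-distance initial condition: circle of radius 80 m centered at (500, 500).
ic(x, y) = hypot(x - 500.0, y - 500.0) - 80.0

# Decomposed solution; L, the domain limits, and tspan are assumed values.
function phi_tilde(x, y, t; L=1000.0, xlim=(0.0, 1000.0), ylim=(0.0, 1000.0),
                   tspan=(0.0, 10.0))
    tmin, tmax = tspan
    tau = (t - tmin) / (tmax - tmin)    # zero at t = tmin
    xn, yn, tn = rescale(x, xlim...), rescale(y, ylim...), rescale(t, tspan...)
    return ic(x, y) / L + tau * fake_nn(xn, yn, tn)
end

# At t = tmin the network term drops out, so the IC is matched exactly.
phi_tilde(650.0, 420.0, 0.0) == ic(650.0, 420.0) / 1000.0   # true
```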

Advantages

  • Mesh-free: Evaluate \(\varphi\) at any continuous \((x, y, t)\) – no grid interpolation needed
  • Continuous in time: Query the fire state at any \(t\) without stepping through intermediate time steps
  • Exact initial condition: Hard constraint guarantees perfect IC fit
  • Data assimilation ready: Optional observation loss term for incorporating satellite/sensor data

Requirements

The PINN solver is a package extension: it is activated only when the ML dependencies are loaded alongside Wildfires:

using Wildfires
using Lux, ComponentArrays, ForwardDiff, Zygote, Optimization, OptimizationOptimisers

Quick Start

using Random  # MersenneTwister (Julia standard library)

# Set up a fire scenario
grid = LevelSetGrid(20, 20, dx=50.0)
ignite!(grid, 500.0, 500.0, 80.0)

# Simple constant spread model for fast training
const_spread = (t, x, y) -> 5.0

# Configure PINN
config = PINNConfig(
    hidden_dims = [32, 32],
    n_interior = 500,
    n_boundary = 100,
    max_epochs = 3000,
    learning_rate = 1e-3,
    resample_every = 500,
)

# Train
sol = train_pinn(grid, const_spread, (0.0, 10.0);
                 config=config, rng=MersenneTwister(123), verbose=false)

sol  # PINNSolution{Lux}(epochs=3001, final_loss=0.0001157)

Training Loss

using CairoMakie  # any Makie backend provides Figure, Axis, lines!

fig = Figure(size=(600, 300))
ax = Axis(fig[1, 1], xlabel="Epoch", ylabel="Loss", yscale=log10,
    title="PINN Training Loss")
lines!(ax, sol.loss_history, color=:steelblue)
fig

PINN Predictions

The PINNSolution is callable: query \(\varphi\) at any point as sol(t, x, y):

# Single point evaluation
sol(5.0, 500.0, 500.0)  # -48.06566849824875 (negative: burned)

Evaluate on the full grid at a specific time with predict_on_grid:

fig = Figure(size=(800, 300))
for (col, t) in enumerate([0.0, 5.0, 10.0])
    φ = predict_on_grid(sol, grid, t)
    ax = Axis(fig[1, col], title="PINN t = $t min", aspect=DataAspect())
    heatmap!(ax, collect(xcoords(grid)), collect(ycoords(grid)), φ, colormap=:RdYlGn)
    contour!(ax, collect(xcoords(grid)), collect(ycoords(grid)), φ, levels=[0.0],
        color=:black, linewidth=2)
    hidedecorations!(ax)
end
fig

Or update a LevelSetGrid in place with predict_on_grid!:

grid_pinn = LevelSetGrid(20, 20, dx=50.0)
predict_on_grid!(grid_pinn, sol, 5.0)
grid_pinn  # LevelSetGrid{Float64} 20×20 (t=5.0, burned=68/400)

Comparison with Finite Differences

The PINN solution can be compared against the standard finite-difference solver. Note that PINNs are an approximate method – accuracy improves with larger networks, more collocation points, and longer training:

# Finite-difference reference with constant F = 5.0 m/min
grid_fd = LevelSetGrid(20, 20, dx=50.0)
ignite!(grid_fd, 500.0, 500.0, 80.0)
F = fill(5.0, size(grid_fd))
for _ in 1:20  # 20 steps of 0.5 min reach t = 10
    advance!(grid_fd, F, 0.5)
end

# PINN prediction at the same time
grid_pinn = LevelSetGrid(20, 20, dx=50.0)
predict_on_grid!(grid_pinn, sol, 10.0)

fig = Figure(size=(700, 300))

ax1 = Axis(fig[1, 1], title="Finite Differences (t=10)", aspect=DataAspect())
fireplot!(ax1, grid_fd)
hidedecorations!(ax1)

ax2 = Axis(fig[1, 2], title="PINN (t=10)", aspect=DataAspect())
fireplot!(ax2, grid_pinn)
hidedecorations!(ax2)

fig

Configuration

PINNConfig controls all training hyperparameters:

Parameter            Default        Description
-------------------  -------------  ------------------------------------------------------
hidden_dims          [64, 64, 64]   Hidden layer sizes
activation           :tanh          Activation function
n_interior           5000           PDE collocation points
n_boundary           500            Boundary condition points
lambda_pde           1.0            PDE loss weight
lambda_bc            1.0            BC loss weight
lambda_data          1.0            Data loss weight
learning_rate        1e-3           Adam learning rate
max_epochs           10000          Maximum training epochs
resample_every       500            Resample collocation points every N epochs
lbfgs_epochs         0              L-BFGS refinement iterations after Adam (0 = disabled)
importance_sampling  false          Concentrate collocation points near the fire front
float32              false          Use Float32 for NN weights (halves weight memory)

Tuning Hyperparameters

Network size (hidden_dims): Start with [32, 32] for quick experiments and scale up to [64, 64, 64] or larger for production. Wider/deeper networks fit complex fire fronts better but train slower. If the loss plateaus early, try increasing network capacity.

Collocation points (n_interior, n_boundary): More points improve accuracy but slow each epoch. A good rule of thumb is to use at least 10× as many interior points as grid cells for n_interior, with n_boundary around 10% of n_interior.
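Applied to the 20 × 20 grid from the Quick Start, the rule of thumb works out as below. `suggested_collocation` is a hypothetical helper for illustration, not a package function.

```julia
# Hypothetical helper: collocation budget from the rule of thumb above.
function suggested_collocation(nx, ny; factor=10, boundary_fraction=0.1)
    n_interior = factor * nx * ny                             # >= 10x the grid cells
    n_boundary = round(Int, boundary_fraction * n_interior)   # ~10% of n_interior
    return (; n_interior, n_boundary)
end

suggested_collocation(20, 20)   # (n_interior = 4000, n_boundary = 400)
```

The Quick Start's n_interior = 500 sits well below this budget, trading accuracy for fast training.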

Learning rate: The default 1e-3 works well with Adam for most cases. If the loss oscillates, reduce to 1e-4. If training is very slow, try 3e-3 briefly.

Loss weights (lambda_pde, lambda_bc, lambda_data): The defaults of 1.0 for all terms are a reasonable starting point. If the fire front bleeds through the boundaries, increase lambda_bc (e.g., 10.0). When using observations, increase lambda_data (e.g., 5.0 to 10.0) to tighten the fit to observed perimeters.

Resampling (resample_every): Periodically resampling collocation points prevents overfitting to a fixed set of training locations. The default of 500 epochs works well; lower values (e.g., 100) can help if the loss stagnates.
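A sketch of the resampling schedule, assuming uniform sampling over the space-time box of a 1000 m square domain (the package's own sampler may differ, for instance when importance sampling is enabled). `sample_interior` and `run_schedule` are hypothetical names, and the optimizer step is left as a placeholder comment.

```julia
using Random

# Draw n uniform collocation points (t, x, y) in the space-time box (assumed bounds).
function sample_interior(rng, n; tspan=(0.0, 10.0), xlim=(0.0, 1000.0), ylim=(0.0, 1000.0))
    u(lo, hi) = lo .+ (hi - lo) .* rand(rng, n)
    return (t = u(tspan...), x = u(xlim...), y = u(ylim...))
end

# Training-loop skeleton: swap in fresh points every `resample_every` epochs.
function run_schedule(rng; max_epochs=3000, resample_every=500, n=500)
    pts = sample_interior(rng, n)
    n_resamples = 0
    for epoch in 1:max_epochs
        # ... one optimizer step on the loss evaluated at `pts` would go here ...
        if epoch % resample_every == 0
            pts = sample_interior(rng, n)   # new locations, so no fixed training set
            n_resamples += 1
        end
    end
    return pts, n_resamples
end

pts, n_resamples = run_schedule(MersenneTwister(123))   # n_resamples == 6
```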

L-BFGS refinement (lbfgs_epochs): Once Adam has brought the parameters into the right neighborhood, L-BFGS can polish the solution using approximate second-order (curvature) information. Typical values are 100–500 iterations. This requires passing an L-BFGS optimizer explicitly through the lbfgs_optimizer keyword:

using OptimizationOptimJL  # provides LBFGS() via Optim.jl

config = PINNConfig(max_epochs=5000, lbfgs_epochs=200)
sol = train_pinn(grid, const_spread, (0.0, 10.0);
                 config=config, lbfgs_optimizer=OptimizationOptimJL.LBFGS())

Importance sampling (importance_sampling): When enabled, half the collocation points are sampled uniformly and half are concentrated near the initial fire front (where \(|\varphi|\) is small). This focuses training capacity on the region that matters most. Enable it when the loss plateaus or the fire front is poorly resolved relative to the domain size.
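A sketch of the half-uniform, half-near-front split, using rejection sampling to keep candidates where \(|\varphi_0|\) is below a band width. The band width, the rejection approach, the 1000 m domain, and the helper names are assumptions for illustration; the package may select near-front points differently.

```julia
using Random

# Assumed initial signed distance: circle of radius 80 m centered at (500, 500).
ic(x, y) = hypot(x - 500.0, y - 500.0) - 80.0

# Half the points uniform, half accepted only within `band` meters of the front.
function importance_sample(rng, n; band=50.0, lo=0.0, hi=1000.0)
    n_uniform = n ÷ 2
    xs = lo .+ (hi - lo) .* rand(rng, n_uniform)
    ys = lo .+ (hi - lo) .* rand(rng, n_uniform)
    while length(xs) < n
        x, y = lo + (hi - lo) * rand(rng), lo + (hi - lo) * rand(rng)
        if abs(ic(x, y)) < band          # reject candidates far from the front
            push!(xs, x)
            push!(ys, y)
        end
    end
    return xs, ys
end

xs, ys = importance_sample(MersenneTwister(1), 1000)
near = count(abs.(ic.(xs, ys)) .< 50.0)   # at least 500 by construction
```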

Float32 weights (float32): Halves the memory used by the NN weights and can speed up forward/backward passes. Collocation points remain Float64 for spatial accuracy. Try this when memory is tight or training is slow; it may slightly reduce precision.
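To put the weight memory in numbers, assume a fully connected network with three inputs \((x, y, t)\), one scalar output, and a bias per neuron (an assumption about the architecture). `mlp_param_count` is a hypothetical helper.

```julia
# Parameter count of a dense MLP: weights plus biases for each layer.
function mlp_param_count(n_in, hidden, n_out)
    dims = [n_in; hidden; n_out]
    return sum(dims[i] * dims[i+1] + dims[i+1] for i in 1:length(dims)-1)
end

n = mlp_param_count(3, [64, 64, 64], 1)   # 8641 for the default hidden_dims
bytes64 = n * sizeof(Float64)              # 69128
bytes32 = n * sizeof(Float32)              # 34564
```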

Loss Function

The total loss is a weighted sum of three terms:

\[ \mathcal{L} = \lambda_{\text{pde}} \mathcal{L}_{\text{pde}} + \lambda_{\text{bc}} \mathcal{L}_{\text{bc}} + \lambda_{\text{data}} \mathcal{L}_{\text{data}} \]

  • PDE residual (\(\mathcal{L}_{\text{pde}}\)): Enforces the Hamilton-Jacobi equation at random interior points
  • Boundary condition (\(\mathcal{L}_{\text{bc}}\)): Penalizes burned regions (\(\varphi < 0\)) at the domain boundary
  • Data loss (\(\mathcal{L}_{\text{data}}\)): Optional – fits to observed fire perimeter data

The initial condition is enforced exactly through the hard constraint decomposition (no IC loss term needed).
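A minimal sketch of how the three terms combine, assuming mean-squared penalties: the squared PDE residual, the squared negative part of \(\varphi\) at boundary points (so only burned boundary values are penalized), and the squared mismatch at observations. These functional forms are assumptions; the package's exact definitions may differ.

```julia
using Statistics: mean

# Weighted total loss under the assumed mean-squared forms described above.
function total_loss(residuals, phi_boundary, phi_pred_obs, phi_obs;
                    lambda_pde=1.0, lambda_bc=1.0, lambda_data=1.0)
    L_pde = mean(abs2, residuals)                    # PDE residual at interior points
    L_bc = mean(p -> min(p, 0.0)^2, phi_boundary)    # only phi < 0 (burned) counts
    L_data = isempty(phi_obs) ? 0.0 : mean(abs2, phi_pred_obs .- phi_obs)
    return lambda_pde * L_pde + lambda_bc * L_bc + lambda_data * L_data
end

# Residuals of ±0.1, one burned boundary value (-2.0), no observations: ≈ 0.01 + 2.0
total_loss([0.1, -0.1], [3.0, -2.0], Float64[], Float64[])
```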

Data Assimilation

Pass observations as a tuple of vectors (t, x, y, phi) to incorporate fire perimeter data:

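# Example: 50 observations at t=10 along a known fire boundary
# (synthetic: assumes the 80 m ignition circle at (500, 500) grew at 5 m/min to radius 130 m)
θ = range(0, 2π; length=51)[1:50]
obs_t = fill(10.0, 50)
obs_x = 500.0 .+ 130.0 .* cos.(θ)
obs_y = 500.0 .+ 130.0 .* sin.(θ)
obs_phi = zeros(50)  # on the fire front (phi = 0)

sol = train_pinn(grid, const_spread, (0.0, 20.0);
                 observations=(obs_t, obs_x, obs_y, obs_phi))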

Observation format details:

  • t: Time in minutes (same units as the simulation time span)
  • x, y: Spatial coordinates in meters (same coordinate system as the LevelSetGrid)
  • phi: Target level set value. Use 0.0 for points on the fire front, negative for burned, positive for unburned.
  • Observations at multiple times are supported; just vary the t values.
  • As few as 20–50 well-placed perimeter points can significantly improve the solution. Points along the fire front (phi = 0) are most informative.
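Observations from several times go into the same four vectors. The sketch below builds synthetic perimeters at two times, assuming a circular front whose radius grows from 80 m at 5 m/min; `circle_obs` is a hypothetical helper, and real data would come from mapped perimeters.

```julia
# Hypothetical helper: n points on a circular perimeter at time t, all with phi = 0.
function circle_obs(t, n; x0=500.0, y0=500.0, r0=80.0, F=5.0)
    r = r0 + F * t                            # assumed constant-rate growth
    θ = range(0, 2π; length=n + 1)[1:n]       # n distinct angles
    return (fill(t, n), x0 .+ r .* cos.(θ), y0 .+ r .* sin.(θ), zeros(n))
end

obs_a = circle_obs(5.0, 25)
obs_b = circle_obs(10.0, 25)
# Concatenate field by field into one (t, x, y, phi) tuple of vectors.
observations = Tuple(vcat(a, b) for (a, b) in zip(obs_a, obs_b))
```

The resulting tuple has the same (t, x, y, phi) layout as the observations keyword of train_pinn.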

References

  • Raissi, M., Perdikaris, P., & Karniadakis, G. E. (2019). Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations. Journal of Computational Physics, 378, 686–707.
  • Osher, S., & Sethian, J. A. (1988). Fronts propagating with curvature-dependent speed: Algorithms based on Hamilton-Jacobi formulations. Journal of Computational Physics, 79(1), 12–49.