Training with SGDBackend

This tutorial demonstrates how to use the SGDBackend to train a hybrid autoregressive model of the classic hare-lynx predator-prey system, in which the predation interaction is learned by a neural network while the birth and death processes keep their mechanistic form.

Note

The train function provided by HybridDynamicModels is an experimental feature, exposed for demonstration purposes. Users are encouraged to implement their own train function to gain more control over the training process; see Overloading the train function.

Importing necessary packages

In order to use the SGDBackend, we need to explicitly load Lux, Optimisers, and ComponentArrays. We additionally load Zygote for automatic differentiation, ParameterSchedulers for learning rate scheduling, Random for random number generation, and Plots, DataFrames, DelimitedFiles, and HTTP for data handling and visualization.

using Lux, Optimisers, ComponentArrays
using Zygote
using HybridDynamicModels
using ParameterSchedulers
using Random
using Plots
using DataFrames, DelimitedFiles, HTTP

const luxtype = Lux.f64
f64 (generic function with 1 method)

Data loading

Load the Lynx-Hare population dataset:

url = "http://people.whitman.edu/~hundledr/courses/M250F03/LynxHare.txt"
data = readdlm(IOBuffer(HTTP.get(url).body), ' ') |> luxtype
df_data = DataFrame(Year = data[:, 1], Hare = data[:, 2], Lynx = data[:, 3])

# Visualize observed data (hare and lynx)
plt_data = plot(df_data.Year, df_data.Hare, label = "Hare", xlabel = "Year",
    ylabel = "Population", title = "Observed Hare-Lynx Data")
plot!(plt_data, df_data.Year, df_data.Lynx, label = "Lynx")
display(plt_data)

Data preparation

Prepare training and test datasets:

tsteps = Vector(df_data.Year) |> luxtype

# Extract hare and lynx data
hare_lynx_data = Array(df_data[:, Not(:Year)])' |> luxtype
hare_lynx_data ./= maximum(hare_lynx_data)

# Data array: [hare, lynx]
data_array = hare_lynx_data |> luxtype

forecast_length = 20
test_idx = (size(data_array, 2) - forecast_length + 1):size(data_array, 2)

# Create training dataloader
dataloader_train = SegmentedTimeSeries(
    (data_array[:, Not(test_idx)], tsteps[Not(test_idx)]);
    segment_length = 4, shift = 2, batchsize = 20)
SegmentedTimeSeries
  Time series length: 71
  Segment length: 4
  Shift: 2 (50.0% overlap)
  Batch size: 20
  Total segments: 34
  Total batches: 1
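
Before defining the model, it can help to inspect a single training segment. The snippet below is an illustrative sketch (not part of the original workflow); it relies on the tokenize/tokens interface that also appears in the visualization code further down:

# Sketch: look at the first training segment
dl_tok = tokenize(dataloader_train)
seg_data, seg_tsteps = dl_tok[first(tokens(dl_tok))]
size(seg_data)  # expected (2, 4): hare and lynx over one 4-step segment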

Model definition

Define a hare-lynx predator-prey model where the predation interaction is learned by a neural network, while birth and death processes follow mechanistic rules.
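
In compact form, writing H and L for the (normalized) hare and lynx abundances and p(H, L) for the predation rate produced by the neural network, the dynamics encoded below read

    dH/dt = hare_birth * H - p(H, L) * H * L - hare_death * H
    dL/dt = p(H, L) * H * L - lynx_death * L

The corresponding layers and step function: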

# Neural network for hare-lynx predation interactions
hlsize = 2^4
neural_interactions = Chain(Dense(2, hlsize, relu),
                        Dense(hlsize, hlsize, relu),
                        Dense(hlsize, 1))  # Output: predation rate

# Learnable ecological parameters
mechanistic_params = ParameterLayer(
    init_value = (hare_birth = [0.8],
                  hare_death = [0.1],
                  lynx_death = [0.2]),
    constraint = NamedTupleConstraint((
        hare_birth = BoxConstraint([0.0], [2.0]),
        hare_death = BoxConstraint([0.001], [1.0]),
        lynx_death = BoxConstraint([0.001], [1.0]))))

# Hybrid ecosystem dynamics
function ecosystem_step(layers, u, ps, t)
    hare, lynx = max.(u, 0.)  # Clamp state variables at zero and unpack
    
    params = layers.mechanistic_params(ps.mechanistic_params)
    
    # Neural network: predation rate
    predation_input = [hare, lynx]
    predation_rate = layers.neural_interactions(predation_input, ps.neural_interactions)[1]
    
    # Mechanistic hare dynamics
    hare_birth = params.hare_birth[1] * hare
    hare_predation = -predation_rate * hare * lynx
    hare_natural_death = -params.hare_death[1] * hare
    
    # Mechanistic lynx dynamics
    lynx_predation_gain = predation_rate * hare * lynx  # Lynx gain from predation
    lynx_death = -params.lynx_death[1] * lynx
    
    # Return derivatives
    return [
        hare_birth + hare_predation + hare_natural_death,  # Hare
        lynx_predation_gain + lynx_death                   # Lynx
    ]
end

# Create autoregressive model
model = ARModel(
    (;neural_interactions, mechanistic_params),
    ecosystem_step;
    dt = tsteps[2] - tsteps[1],
);
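
Before configuring the optimizer, a quick forward pass can be used to check that the model returns trajectories of the expected shape. This is an illustrative sketch, assuming ARModel follows the standard Lux setup interface and the call signature used later in this tutorial:

# Sketch: single forward pass over the first four time steps
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
pred, _ = model((; u0 = data_array[:, 1], saveat = tsteps[1:4],
        tspan = (tsteps[1], tsteps[4])), ps |> luxtype, st)
size(pred)  # expected (2, 4): both species over four time steps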

Training configuration

Configure training with learning rate scheduling and callbacks:

# Learning rate schedule: step decay (multiply by 0.9 every 200 epochs)
lr_schedule = Step(1e-2, 0.9, 200)
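# Note: with ParameterSchedulers.Step the learning rate is piecewise constant,
# decaying by a factor of 0.9 every 200 epochs (0.01 for the first 200 epochs,
# then 0.009, then 0.0081, ...), as visible in the training log below.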

# Callback for monitoring and learning rate adjustment
function callback(loss, epoch, ts)
    if epoch % 20 == 0
        current_lr = lr_schedule(epoch)
        @info "Epoch $epoch: Loss = $loss, LR = $current_lr"
        Optimisers.adjust!(ts.optimizer_state, current_lr)
    end
end

# Training backend configuration
backend = SGDBackend(
    AdamW(eta = 1e-2, lambda = 1e-4),  # Optimizer with weight decay
    2000,                             # Number of epochs
    AutoZygote(),                     # Automatic differentiation
    MSELoss(),                        # Loss function
    callback                          # Training callback
)
HybridDynamicModelsLuxExt.SGDBackend(AdamW(eta=0.01, beta=(0.9, 0.999), lambda=0.0001, epsilon=1.0e-8, couple=true), 2000, AutoZygote(), GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}(Lux.LossFunctionImpl.l2_distance_loss, Statistics.mean), Main.var"##WeaveSandBox#237".callback)

Training

Train the model with initial condition inference:

@info "Starting training..."
result = train(backend, model, dataloader_train, InferICs(true));
Starting training...
Epoch 20: Loss = 0.03873773946887725, LR = 0.01
Epoch 40: Loss = 0.027979128959718824, LR = 0.01
Epoch 60: Loss = 0.022409160359493113, LR = 0.01
Epoch 80: Loss = 0.018925108275828755, LR = 0.01
Epoch 100: Loss = 0.016663934203832453, LR = 0.01
Epoch 120: Loss = 0.014817495316206603, LR = 0.01
Epoch 140: Loss = 0.01356381551281366, LR = 0.01
Epoch 160: Loss = 0.01269433585602035, LR = 0.01
Epoch 180: Loss = 0.0121249254682, LR = 0.01
Epoch 200: Loss = 0.01175334403578352, LR = 0.01
Epoch 220: Loss = 0.011506516928238637, LR = 0.009000000000000001
Epoch 240: Loss = 0.011328215335286376, LR = 0.009000000000000001
Epoch 260: Loss = 0.011210211106416299, LR = 0.009000000000000001
Epoch 280: Loss = 0.011114126196432757, LR = 0.009000000000000001
Epoch 300: Loss = 0.011041389525203397, LR = 0.009000000000000001
Epoch 320: Loss = 0.010977631405769572, LR = 0.009000000000000001
Epoch 340: Loss = 0.010924746045140836, LR = 0.009000000000000001
Epoch 360: Loss = 0.010854294104474357, LR = 0.009000000000000001
Epoch 380: Loss = 0.010811675976299234, LR = 0.009000000000000001
Epoch 400: Loss = 0.010827760959276966, LR = 0.009000000000000001
Epoch 420: Loss = 0.010748585264812035, LR = 0.008100000000000001
Epoch 440: Loss = 0.010715101099053773, LR = 0.008100000000000001
Epoch 460: Loss = 0.010687222062239553, LR = 0.008100000000000001
Epoch 480: Loss = 0.010656570000713796, LR = 0.008100000000000001
Epoch 500: Loss = 0.010676247069779784, LR = 0.008100000000000001
Epoch 520: Loss = 0.010612433673517996, LR = 0.008100000000000001
Epoch 540: Loss = 0.010716463204923645, LR = 0.008100000000000001
Epoch 560: Loss = 0.010582059609425704, LR = 0.008100000000000001
Epoch 580: Loss = 0.010555874315144423, LR = 0.008100000000000001
Epoch 600: Loss = 0.010535615417614257, LR = 0.008100000000000001
Epoch 620: Loss = 0.010611595236873762, LR = 0.007290000000000001
Epoch 640: Loss = 0.010547863299101373, LR = 0.007290000000000001
Epoch 660: Loss = 0.010516973078894359, LR = 0.007290000000000001
Epoch 680: Loss = 0.010496783442717683, LR = 0.007290000000000001
Epoch 700: Loss = 0.010471417323718524, LR = 0.007290000000000001
Epoch 720: Loss = 0.010459783061899647, LR = 0.007290000000000001
Epoch 740: Loss = 0.010447165230885971, LR = 0.007290000000000001
Epoch 760: Loss = 0.010440231276687317, LR = 0.007290000000000001
Epoch 780: Loss = 0.010445250783760084, LR = 0.007290000000000001
Epoch 800: Loss = 0.010426324614561581, LR = 0.007290000000000001
Epoch 820: Loss = 0.010604130188737461, LR = 0.006561
Epoch 840: Loss = 0.010428667651899261, LR = 0.006561
Epoch 860: Loss = 0.010383217053634384, LR = 0.006561
Epoch 880: Loss = 0.010372620287535769, LR = 0.006561
Epoch 900: Loss = 0.010359356210476314, LR = 0.006561
Epoch 920: Loss = 0.010381483544722323, LR = 0.006561
Epoch 940: Loss = 0.010417505759747314, LR = 0.006561
Epoch 960: Loss = 0.010368541610752093, LR = 0.006561
Epoch 980: Loss = 0.010362711800477781, LR = 0.006561
Epoch 1000: Loss = 0.010325603798168352, LR = 0.006561
Epoch 1020: Loss = 0.010316569642107419, LR = 0.005904900000000001
Epoch 1040: Loss = 0.010294112562181528, LR = 0.005904900000000001
Epoch 1060: Loss = 0.010303497583659636, LR = 0.005904900000000001
Epoch 1080: Loss = 0.010269862771782146, LR = 0.005904900000000001
Epoch 1100: Loss = 0.010280936915460545, LR = 0.005904900000000001
Epoch 1120: Loss = 0.01047574168487166, LR = 0.005904900000000001
Epoch 1140: Loss = 0.010273516203555579, LR = 0.005904900000000001
Epoch 1160: Loss = 0.010238538178947359, LR = 0.005904900000000001
Epoch 1180: Loss = 0.01023078775392528, LR = 0.005904900000000001
Epoch 1200: Loss = 0.01022328785943326, LR = 0.005904900000000001
Epoch 1220: Loss = 0.010208889399069877, LR = 0.00531441
Epoch 1240: Loss = 0.010196224927603203, LR = 0.00531441
Epoch 1260: Loss = 0.010187362041633853, LR = 0.00531441
Epoch 1280: Loss = 0.01018491324454769, LR = 0.00531441
Epoch 1300: Loss = 0.010202709220879066, LR = 0.00531441
Epoch 1320: Loss = 0.010183335586376466, LR = 0.00531441
Epoch 1340: Loss = 0.010302667929520674, LR = 0.00531441
Epoch 1360: Loss = 0.010196505653380105, LR = 0.00531441
Epoch 1380: Loss = 0.010166814283950555, LR = 0.00531441
Epoch 1400: Loss = 0.010168618083352222, LR = 0.00531441
Epoch 1420: Loss = 0.010148112460842292, LR = 0.004782969000000001
Epoch 1440: Loss = 0.010136957351447682, LR = 0.004782969000000001
Epoch 1460: Loss = 0.010144556036410901, LR = 0.004782969000000001
Epoch 1480: Loss = 0.010130710156542422, LR = 0.004782969000000001
Epoch 1500: Loss = 0.010116020633853567, LR = 0.004782969000000001
Epoch 1520: Loss = 0.01012517654879442, LR = 0.004782969000000001
Epoch 1540: Loss = 0.010141771974832289, LR = 0.004782969000000001
Epoch 1560: Loss = 0.010112986928497634, LR = 0.004782969000000001
Epoch 1580: Loss = 0.010121697194996006, LR = 0.004782969000000001
Epoch 1600: Loss = 0.010095370739226716, LR = 0.004782969000000001
Epoch 1620: Loss = 0.01008619396774289, LR = 0.004304672100000001
Epoch 1640: Loss = 0.010086602768547514, LR = 0.004304672100000001
Epoch 1660: Loss = 0.01007847832736047, LR = 0.004304672100000001
Epoch 1680: Loss = 0.010125348766076119, LR = 0.004304672100000001
Epoch 1700: Loss = 0.010078779028739831, LR = 0.004304672100000001
Epoch 1720: Loss = 0.010081668635631882, LR = 0.004304672100000001
Epoch 1740: Loss = 0.010060679702467022, LR = 0.004304672100000001
Epoch 1760: Loss = 0.010062449142988109, LR = 0.004304672100000001
Epoch 1780: Loss = 0.01006080967449304, LR = 0.004304672100000001
Epoch 1800: Loss = 0.010078060426619174, LR = 0.004304672100000001
Epoch 1820: Loss = 0.010052871156739526, LR = 0.003874204890000001
Epoch 1840: Loss = 0.010044562787516564, LR = 0.003874204890000001
Epoch 1860: Loss = 0.010041155926342082, LR = 0.003874204890000001
Epoch 1880: Loss = 0.01002997261334782, LR = 0.003874204890000001
Epoch 1900: Loss = 0.010030611816642528, LR = 0.003874204890000001
Epoch 1920: Loss = 0.01004132036491225, LR = 0.003874204890000001
Epoch 1940: Loss = 0.010026171640353196, LR = 0.003874204890000001
Epoch 1960: Loss = 0.010052227517796365, LR = 0.003874204890000001
Epoch 1980: Loss = 0.010027353914943433, LR = 0.003874204890000001
Epoch 2000: Loss = 0.010032788673307893, LR = 0.003874204890000001

Results visualization

Visualize the training fit and the test forecasts for the hare-lynx ecosystem. The result returned by train bundles the trained parameters (result.ps), the layer states (result.st), and the inferred initial conditions for each training segment (result.ics), all of which are used below:


# Plot colors: gold for hare, red for lynx
hare_color = "#ffd166"
lynx_color = "#ef476f"

# Function to plot training results
function plot_training_results(dataloader, result, model)
    plt = plot(title = "Training Results", xlabel = "Year",
        ylabel = "Population", legend = :topright)

    dataloader_tokenized = tokenize(dataloader)

    for tok in tokens(dataloader_tokenized)
        segment_data, segment_tsteps = dataloader_tokenized[tok]
        ics = result.ics[tok].u0

        pred, _ = model(
            (; u0 = ics, saveat = segment_tsteps,
                tspan = (segment_tsteps[1], segment_tsteps[end])),
            result.ps, result.st)

        # Plot observed data
        scatter!(plt, segment_tsteps, segment_data[1, :],
            label = (tok == 1 ? "Hare Data" : ""),
            color = hare_color, markersize = 4, alpha = 0.7)
        scatter!(plt, segment_tsteps, segment_data[2, :],
            label = (tok == 1 ? "Lynx Data" : ""),
            color = lynx_color, markersize = 4, alpha = 0.7)

        # Plot predictions
        plot!(plt, segment_tsteps, pred[1, :],
            label = (tok == 1 ? "Hare Predicted" : ""),
            color = hare_color, linewidth = 2)
        plot!(plt, segment_tsteps, pred[2, :],
            label = (tok == 1 ? "Lynx Predicted" : ""),
            color = lynx_color, linewidth = 2)
    end
    return plt
end

# Plot training results
plt_train = plot_training_results(dataloader_train, result, model)

Forecast on test data:

tsteps_test = tsteps[test_idx]
data_test = data_array[:, test_idx]
u0, t0 = result.ics[end]

preds, _ = model((; u0 = u0, tspan = (t0, tsteps_test[end]), saveat = tsteps_test),
                result.ps, result.st)

# Plot test predictions
plt_test = plot(title = "Test Predictions", xlabel = "Year", ylabel = "Population", legend = :topright)
scatter!(plt_test, tsteps_test, data_test[1, :], label = "Hare Data", color = hare_color, markersize = 4, alpha = 0.7)
scatter!(plt_test, tsteps_test, data_test[2, :], label = "Lynx Data", color = lynx_color, markersize = 4, alpha = 0.7)
plot!(plt_test, tsteps_test, preds[1, :], label = "Hare Predicted", color = hare_color, linewidth = 2)
plot!(plt_test, tsteps_test, preds[2, :], label = "Lynx Predicted", color = lynx_color, linewidth = 2)
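
To complement the visual comparison, the forecast error can also be summarized numerically. A minimal sketch, reusing the arrays defined above:

# Sketch: mean squared forecast error over the held-out years
test_mse = sum(abs2, preds .- data_test) / length(data_test)
@info "Test forecast MSE: $test_mse"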

Some final notes

  • When training a neural-network-based parametrization, it is usually best practice to monitor a validation loss to avoid overfitting. This can be implemented by creating a separate validation dataloader (see create_train_val_loaders) and overloading the train function so that the training loop computes the validation loss at regular intervals; a minimal sketch is given below, and the Overloading the train function tutorial walks through a complete example.
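
As a rough illustration, such a validation loss can be computed with the same segment interface used in the visualization code above. The sketch below assumes a held-out dataloader_val (e.g. built with create_train_val_loaders; check its docstring for the exact signature) and simply uses the first observation of each segment as its initial condition:

# Sketch: mean squared validation loss over all segments of a held-out dataloader
function validation_loss(model, ps, st, dataloader_val)
    dl_tok = tokenize(dataloader_val)
    total, n = 0.0, 0
    for tok in tokens(dl_tok)
        seg_data, seg_tsteps = dl_tok[tok]
        pred, _ = model((; u0 = seg_data[:, 1], saveat = seg_tsteps,
                tspan = (seg_tsteps[1], seg_tsteps[end])), ps, st)
        total += sum(abs2, pred .- seg_data) / length(seg_data)
        n += 1
    end
    return total / n
end

# e.g. validation_loss(model, result.ps, result.st, dataloader_val)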