Training with SGDBackend
This tutorial demonstrates how to use the SGDBackend to train a hybrid autoregressive model of the classic hare-lynx predator-prey system, in which the predation interaction is learned by a neural network while the birth and death processes remain mechanistically constrained.
The train function provided by HybridDynamicModels is an experimental feature, exposed for demonstration purposes. Users are encouraged to implement their own train function to gain more control over the training process; see Overloading the train function.
Importing necessary packages
In order to use the SGDBackend, we'll need to manually load Lux, Optimisers, and ComponentArrays. We additionally load Zygote for automatic differentiation, ParameterSchedulers for learning rate scheduling, Random for random number generation, and Plots, DataFrames, DelimitedFiles, and HTTP for data handling and visualization.
using Lux, Optimisers, ComponentArrays
using Zygote
using HybridDynamicModels
using ParameterSchedulers
using Random
using Plots
using DataFrames, DelimitedFiles, HTTP
const luxtype = Lux.f64
f64 (generic function with 1 method)
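Here, luxtype is an alias for Lux.f64, which converts numeric arrays to Float64; it is piped over the raw data arrays below so that the whole pipeline runs in double precision.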
Data loading
Load the Lynx-Hare population dataset:
url = "http://people.whitman.edu/~hundledr/courses/M250F03/LynxHare.txt"
data = readdlm(IOBuffer(HTTP.get(url).body), ' ') |> luxtype
df_data = DataFrame(Year = data[:, 1], Hare = data[:, 2], Lynx = data[:, 3])
# Visualize observed data (hare and lynx)
plt_data = plot(df_data.Year, df_data.Hare, label = "Hare", xlabel = "Year",
ylabel = "Population", title = "Observed Hare-Lynx Data")
plot!(plt_data, df_data.Year, df_data.Lynx, label = "Lynx")
display(plt_data)
Data preparation
Prepare training and test datasets:
tsteps = Vector(df_data.Year) |> luxtype
# Extract hare and lynx data
hare_lynx_data = Array(df_data[:, Not(:Year)])' |> luxtype
hare_lynx_data ./= maximum(hare_lynx_data)
# Data array: [hare, lynx]
data_array = hare_lynx_data |> luxtype
forecast_length = 20
test_idx = size(data_array, 2) - forecast_length + 1:size(data_array, 2)
# Create training dataloader
dataloader_train = SegmentedTimeSeries(
(data_array[:, Not(test_idx)], tsteps[Not(test_idx)]);
segment_length = 4, shift = 2, batchsize = 20)
SegmentedTimeSeries
Time series length: 71
Segment length: 4
Shift: 2 (50.0% overlap)
Batch size: 20
Total segments: 34
Total batches: 1
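The segment count follows directly from the loader settings: with 71 training time points, segments of length 4 shifted by 2 steps give ⌊(71 − 4)/2⌋ + 1 = 34 overlapping segments, each sharing half of its points with its neighbour.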
Model definition
Define a hare-lynx predator-prey model where the predation interaction is learned by a neural network, while birth and death processes follow mechanistic rules:
# Neural network for hare-lynx predation interactions
hlsize = 2^4
neural_interactions = Chain(Dense(2, hlsize, relu),
Dense(hlsize, hlsize, relu),
Dense(hlsize, 1)) # Output: predation rate
# Learnable ecological parameters
mechanistic_params = ParameterLayer(init_value = (
hare_birth = [0.8],
hare_death = [0.1],
lynx_death = [0.2] ),
constraint = NamedTupleConstraint((hare_birth = BoxConstraint([0.0], [2.0]),
hare_death = BoxConstraint([0.001], [1.0]),
lynx_death = BoxConstraint([0.001], [1.0]))
))
# Hybrid ecosystem dynamics
function ecosystem_step(layers, u, ps, t)
hare, lynx = max.(u, 0.0) # Unpack state variables, clipped at zero to avoid negative populations
params = layers.mechanistic_params(ps.mechanistic_params)
# Neural network: predation rate
predation_input = [hare, lynx]
predation_rate = layers.neural_interactions(predation_input, ps.neural_interactions)[1]
# Mechanistic hare dynamics
hare_birth = params.hare_birth[1] * hare
hare_predation = -predation_rate * hare * lynx
hare_natural_death = -params.hare_death[1] * hare
# Mechanistic lynx dynamics
lynx_predation_gain = predation_rate * hare * lynx # Lynx gain from predation
lynx_death = -params.lynx_death[1] * lynx
# Return derivatives
return [
hare_birth + hare_predation + hare_natural_death, # Hare
lynx_predation_gain + lynx_death # Lynx
]
end
# Create autoregressive model
model = ARModel(
(;neural_interactions, mechanistic_params),
ecosystem_step;
dt = tsteps[2] - tsteps[1],
);
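In equation form, ecosystem_step returns the derivatives

$$\frac{dH}{dt} = b_H H - p_\theta(H, L)\,H L - d_H H, \qquad \frac{dL}{dt} = p_\theta(H, L)\,H L - d_L L,$$

where H and L are the (rescaled) hare and lynx abundances, b_H, d_H, and d_L are the constrained mechanistic parameters, and p_θ is the predation rate produced by neural_interactions. The ARModel then advances the state autoregressively from these derivatives with the fixed time step dt = tsteps[2] - tsteps[1] (one year).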
Training configuration
Configure training with learning rate scheduling and callbacks:
# Learning rate schedule: exponential decay
lr_schedule = Step(1e-2, 0.9, 200)
# Callback for monitoring and learning rate adjustment
function callback(loss, epoch, ts)
if epoch % 20 == 0
current_lr = lr_schedule(epoch)
@info "Epoch $epoch: Loss = $loss, LR = $current_lr"
Optimisers.adjust!(ts.optimizer_state, current_lr)
end
end
# Training backend configuration
backend = SGDBackend(
AdamW(eta = 1e-2, lambda = 1e-4), # Optimizer with weight decay
2000, # Number of epochs
AutoZygote(), # Automatic differentiation
MSELoss(), # Loss function
callback # Training callback
)
HybridDynamicModelsLuxExt.SGDBackend(AdamW(eta=0.01, beta=(0.9, 0.999), lambda=0.0001, epsilon=1.0e-8, couple=true), 2000, AutoZygote(), GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}(Lux.LossFunctionImpl.l2_distance_loss, Statistics.mean), Main.var"##WeaveSandBox#237".callback)
Training
Train the model with initial condition inference; InferICs(true) treats the initial condition of each segment as a learnable quantity, which is returned in result.ics and reused for plotting and forecasting below:
@info "Starting training..."
result = train(backend, model, dataloader_train, InferICs(true));
Starting training...
Epoch 20: Loss = 0.03873773946887725, LR = 0.01
Epoch 40: Loss = 0.027979128959718824, LR = 0.01
Epoch 60: Loss = 0.022409160359493113, LR = 0.01
Epoch 80: Loss = 0.018925108275828755, LR = 0.01
Epoch 100: Loss = 0.016663934203832453, LR = 0.01
Epoch 120: Loss = 0.014817495316206603, LR = 0.01
Epoch 140: Loss = 0.01356381551281366, LR = 0.01
Epoch 160: Loss = 0.01269433585602035, LR = 0.01
Epoch 180: Loss = 0.0121249254682, LR = 0.01
Epoch 200: Loss = 0.01175334403578352, LR = 0.01
Epoch 220: Loss = 0.011506516928238637, LR = 0.009000000000000001
Epoch 240: Loss = 0.011328215335286376, LR = 0.009000000000000001
Epoch 260: Loss = 0.011210211106416299, LR = 0.009000000000000001
Epoch 280: Loss = 0.011114126196432757, LR = 0.009000000000000001
Epoch 300: Loss = 0.011041389525203397, LR = 0.009000000000000001
Epoch 320: Loss = 0.010977631405769572, LR = 0.009000000000000001
Epoch 340: Loss = 0.010924746045140836, LR = 0.009000000000000001
Epoch 360: Loss = 0.010854294104474357, LR = 0.009000000000000001
Epoch 380: Loss = 0.010811675976299234, LR = 0.009000000000000001
Epoch 400: Loss = 0.010827760959276966, LR = 0.009000000000000001
Epoch 420: Loss = 0.010748585264812035, LR = 0.008100000000000001
Epoch 440: Loss = 0.010715101099053773, LR = 0.008100000000000001
Epoch 460: Loss = 0.010687222062239553, LR = 0.008100000000000001
Epoch 480: Loss = 0.010656570000713796, LR = 0.008100000000000001
Epoch 500: Loss = 0.010676247069779784, LR = 0.008100000000000001
Epoch 520: Loss = 0.010612433673517996, LR = 0.008100000000000001
Epoch 540: Loss = 0.010716463204923645, LR = 0.008100000000000001
Epoch 560: Loss = 0.010582059609425704, LR = 0.008100000000000001
Epoch 580: Loss = 0.010555874315144423, LR = 0.008100000000000001
Epoch 600: Loss = 0.010535615417614257, LR = 0.008100000000000001
Epoch 620: Loss = 0.010611595236873762, LR = 0.007290000000000001
Epoch 640: Loss = 0.010547863299101373, LR = 0.007290000000000001
Epoch 660: Loss = 0.010516973078894359, LR = 0.007290000000000001
Epoch 680: Loss = 0.010496783442717683, LR = 0.007290000000000001
Epoch 700: Loss = 0.010471417323718524, LR = 0.007290000000000001
Epoch 720: Loss = 0.010459783061899647, LR = 0.007290000000000001
Epoch 740: Loss = 0.010447165230885971, LR = 0.007290000000000001
Epoch 760: Loss = 0.010440231276687317, LR = 0.007290000000000001
Epoch 780: Loss = 0.010445250783760084, LR = 0.007290000000000001
Epoch 800: Loss = 0.010426324614561581, LR = 0.007290000000000001
Epoch 820: Loss = 0.010604130188737461, LR = 0.006561
Epoch 840: Loss = 0.010428667651899261, LR = 0.006561
Epoch 860: Loss = 0.010383217053634384, LR = 0.006561
Epoch 880: Loss = 0.010372620287535769, LR = 0.006561
Epoch 900: Loss = 0.010359356210476314, LR = 0.006561
Epoch 920: Loss = 0.010381483544722323, LR = 0.006561
Epoch 940: Loss = 0.010417505759747314, LR = 0.006561
Epoch 960: Loss = 0.010368541610752093, LR = 0.006561
Epoch 980: Loss = 0.010362711800477781, LR = 0.006561
Epoch 1000: Loss = 0.010325603798168352, LR = 0.006561
Epoch 1020: Loss = 0.010316569642107419, LR = 0.005904900000000001
Epoch 1040: Loss = 0.010294112562181528, LR = 0.005904900000000001
Epoch 1060: Loss = 0.010303497583659636, LR = 0.005904900000000001
Epoch 1080: Loss = 0.010269862771782146, LR = 0.005904900000000001
Epoch 1100: Loss = 0.010280936915460545, LR = 0.005904900000000001
Epoch 1120: Loss = 0.01047574168487166, LR = 0.005904900000000001
Epoch 1140: Loss = 0.010273516203555579, LR = 0.005904900000000001
Epoch 1160: Loss = 0.010238538178947359, LR = 0.005904900000000001
Epoch 1180: Loss = 0.01023078775392528, LR = 0.005904900000000001
Epoch 1200: Loss = 0.01022328785943326, LR = 0.005904900000000001
Epoch 1220: Loss = 0.010208889399069877, LR = 0.00531441
Epoch 1240: Loss = 0.010196224927603203, LR = 0.00531441
Epoch 1260: Loss = 0.010187362041633853, LR = 0.00531441
Epoch 1280: Loss = 0.01018491324454769, LR = 0.00531441
Epoch 1300: Loss = 0.010202709220879066, LR = 0.00531441
Epoch 1320: Loss = 0.010183335586376466, LR = 0.00531441
Epoch 1340: Loss = 0.010302667929520674, LR = 0.00531441
Epoch 1360: Loss = 0.010196505653380105, LR = 0.00531441
Epoch 1380: Loss = 0.010166814283950555, LR = 0.00531441
Epoch 1400: Loss = 0.010168618083352222, LR = 0.00531441
Epoch 1420: Loss = 0.010148112460842292, LR = 0.004782969000000001
Epoch 1440: Loss = 0.010136957351447682, LR = 0.004782969000000001
Epoch 1460: Loss = 0.010144556036410901, LR = 0.004782969000000001
Epoch 1480: Loss = 0.010130710156542422, LR = 0.004782969000000001
Epoch 1500: Loss = 0.010116020633853567, LR = 0.004782969000000001
Epoch 1520: Loss = 0.01012517654879442, LR = 0.004782969000000001
Epoch 1540: Loss = 0.010141771974832289, LR = 0.004782969000000001
Epoch 1560: Loss = 0.010112986928497634, LR = 0.004782969000000001
Epoch 1580: Loss = 0.010121697194996006, LR = 0.004782969000000001
Epoch 1600: Loss = 0.010095370739226716, LR = 0.004782969000000001
Epoch 1620: Loss = 0.01008619396774289, LR = 0.004304672100000001
Epoch 1640: Loss = 0.010086602768547514, LR = 0.004304672100000001
Epoch 1660: Loss = 0.01007847832736047, LR = 0.004304672100000001
Epoch 1680: Loss = 0.010125348766076119, LR = 0.004304672100000001
Epoch 1700: Loss = 0.010078779028739831, LR = 0.004304672100000001
Epoch 1720: Loss = 0.010081668635631882, LR = 0.004304672100000001
Epoch 1740: Loss = 0.010060679702467022, LR = 0.004304672100000001
Epoch 1760: Loss = 0.010062449142988109, LR = 0.004304672100000001
Epoch 1780: Loss = 0.01006080967449304, LR = 0.004304672100000001
Epoch 1800: Loss = 0.010078060426619174, LR = 0.004304672100000001
Epoch 1820: Loss = 0.010052871156739526, LR = 0.003874204890000001
Epoch 1840: Loss = 0.010044562787516564, LR = 0.003874204890000001
Epoch 1860: Loss = 0.010041155926342082, LR = 0.003874204890000001
Epoch 1880: Loss = 0.01002997261334782, LR = 0.003874204890000001
Epoch 1900: Loss = 0.010030611816642528, LR = 0.003874204890000001
Epoch 1920: Loss = 0.01004132036491225, LR = 0.003874204890000001
Epoch 1940: Loss = 0.010026171640353196, LR = 0.003874204890000001
Epoch 1960: Loss = 0.010052227517796365, LR = 0.003874204890000001
Epoch 1980: Loss = 0.010027353914943433, LR = 0.003874204890000001
Epoch 2000: Loss = 0.010032788673307893, LR = 0.003874204890000001
Results visualization
Visualize training fit and test predictions for the hare-lynx ecosystem:
# Consistent plot colors for hare and lynx
hare_color = "#ffd166"
lynx_color = "#ef476f"
# Function to plot training results
function plot_training_results(dataloader, result, model)
plt = plot(title = "Training Results", xlabel = "Year",
ylabel = "Population", legend = :topright)
dataloader_tokenized = tokenize(dataloader)
for tok in tokens(dataloader_tokenized)
segment_data, segment_tsteps = dataloader_tokenized[tok]
ics = result.ics[tok].u0
pred, _ = model(
(; u0 = ics, saveat = segment_tsteps,
tspan = (segment_tsteps[1], segment_tsteps[end])),
result.ps, result.st)
# Plot observed data
scatter!(plt, segment_tsteps, segment_data[1, :],
label = (tok == 1 ? "Hare Data" : ""),
color = hare_color, markersize = 4, alpha = 0.7)
scatter!(plt, segment_tsteps, segment_data[2, :],
label = (tok == 1 ? "Lynx Data" : ""),
color = lynx_color, markersize = 4, alpha = 0.7)
# Plot predictions
plot!(plt, segment_tsteps, pred[1, :],
label = (tok == 1 ? "Hare Predicted" : ""),
color = hare_color, linewidth = 2)
plot!(plt, segment_tsteps, pred[2, :],
label = (tok == 1 ? "Lynx Predicted" : ""),
color = lynx_color, linewidth = 2)
end
return plt
end
# Plot training results
plt_train = plot_training_results(dataloader_train, result, model)
Forecast on test data:
tsteps_test = tsteps[test_idx]
data_test = data_array[:, test_idx]
u0, t0 = result.ics[end]
preds, _ = model((; u0 = u0, tspan = (t0, tsteps_test[end]), saveat = tsteps_test),
result.ps, result.st)
# Plot test predictions
plt_test = plot(title = "Test Predictions", xlabel = "Year", ylabel = "Population", legend = :topright)
scatter!(plt_test, tsteps_test, data_test[1, :], label = "Hare Data", color = hare_color, markersize = 4, alpha = 0.7)
scatter!(plt_test, tsteps_test, data_test[2, :], label = "Lynx Data", color = lynx_color, markersize = 4, alpha = 0.7)
plot!(plt_test, tsteps_test, preds[1, :], label = "Hare Predicted", color = hare_color, linewidth = 2)
plot!(plt_test, tsteps_test, preds[2, :], label = "Lynx Predicted", color = lynx_color, linewidth = 2)
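A simple way to quantify forecast accuracy is the mean squared error over the held-out window, computed from the arrays defined above:
# Mean squared error of the 20-year forecast
test_mse = sum(abs2, preds .- data_test) / length(data_test)
@info "Test forecast MSE: $test_mse"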
Some final notes
- When training a neural network-based parametrization, it is usually best practice to monitor a validation loss to avoid overfitting. This can be implemented by creating a separate validation dataloader (see create_train_val_loaders) and overloading the train function so that the training loop computes the validation loss at regular intervals; check out the Overloading the train function tutorial for an example. A rough sketch of such a validation check is shown below.
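The following is a minimal sketch only: it reuses the tokenized-dataloader access and the model calling convention from the plotting code above, and assumes a held-out SegmentedTimeSeries named dataloader_val built in the same way as dataloader_train, with each segment's first observation used as the initial condition.
# Hypothetical validation loss; dataloader_val is an assumed held-out loader.
function validation_loss(model, result, dataloader_val)
    total = 0.0
    n = 0
    dl_val = tokenize(dataloader_val)
    for tok in tokens(dl_val)
        segment_data, segment_tsteps = dl_val[tok]
        # Use the first observation of the segment as the initial condition
        pred, _ = model(
            (; u0 = segment_data[:, 1], saveat = segment_tsteps,
                tspan = (segment_tsteps[1], segment_tsteps[end])),
            result.ps, result.st)
        total += sum(abs2, pred .- segment_data)
        n += length(segment_data)
    end
    return total / n   # mean squared validation error
end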