# Overloading the `train` function

The default `train` function is opinionated and meant for demonstration purposes. For more advanced training, you should create your own training pipeline. This tutorial shows how to build a custom training pipeline that includes validation to prevent overfitting.
## Define a custom setup struct

Create a new struct that subtypes `AbstractSetup`. This struct will hold the validation dataloader.
```julia
using HybridDynamicModels
import HybridDynamicModels: AbstractSetup
import LuxCore: AbstractLuxLayer
import ConcreteStructs: @concrete
import Random

@concrete struct WithValidation <: AbstractSetup
    dataloader
end
```
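Because `train` dispatches on the type of its setup argument, any `WithValidation` instance will route calls to the custom method defined below. A quick check (the `nothing` placeholder stands in for a real dataloader):

```julia
setup = WithValidation(nothing)  # placeholder field value, for illustration only
setup isa AbstractSetup          # true, so `train` can dispatch on it
```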
## Implement the custom `train` method

Implement a `train` method that takes your custom setup. This method trains the model while monitoring the validation loss.

The method performs these steps:

- Prepare the training and validation data
- Set up the model with a feature wrapper
- Initialize the training state
- Train for multiple epochs, computing both training and validation loss
- Save the best model parameters based on validation loss
First, define the feature wrapper that turns a batch of segments into the per-segment features the model consumes:

```julia
# Build one (u0, saveat, tspan) named tuple per segment in the batch.
function feature_wrapper((batched_segments, batched_tsteps))
    return [(; u0 = batched_segments[:, 1, i],
               saveat = batched_tsteps[:, i],
               tspan = (batched_tsteps[1, i], batched_tsteps[end, i]))
            for i in 1:size(batched_tsteps, 2)]
end
```
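To see what the wrapper produces, here is a small sketch with made-up dimensions (2 state variables, 5 time steps, 3 segments per batch):

```julia
batched_segments = randn(2, 5, 3)
batched_tsteps = reduce(hcat, [collect(range(i, i + 1.0; length = 5)) for i in 0:2])

features = feature_wrapper((batched_segments, batched_tsteps))
length(features)    # 3: one named tuple per segment
features[1].u0      # 2-element initial condition of the first segment
features[1].tspan   # (0.0, 1.0)
```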
With the wrapper in place, the custom `train` method looks like this:

```julia
function HybridDynamicModels.train(backend::SGDBackend,
        model::AbstractLuxLayer,
        dataloader_train::SegmentedTimeSeries,
        experimental_setup::WithValidation,
        rng = Random.default_rng(),
        luxtype = Lux.f64)
    # Prepare the training and validation data
    dataloader_train = luxtype(dataloader_train)
    dataloader_valid = luxtype(experimental_setup.dataloader)
    @assert length(dataloader_train) == length(dataloader_valid) "The training and validation dataloaders must have the same number of segments"

    # Set up the model with the feature wrapper
    model_with_wrapper = Chain((; wrapper = Lux.WrappedFunction(feature_wrapper), model = model))

    # Initialize the training state
    ps, st = luxtype(Lux.setup(rng, model_with_wrapper))
    train_state = Training.TrainState(model_with_wrapper, ps, st, backend.opt)

    best_ps = ps.model
    best_st = st.model
    best_loss = luxtype(Inf)

    for epoch in 1:(backend.n_epochs)
        # Training pass
        train_loss = luxtype(0.0)
        for (batched_segments, batched_tsteps) in dataloader_train
            _, loss, _, train_state = Training.single_train_step!(
                backend.adtype,
                backend.loss_fn,
                ((batched_segments, batched_tsteps), batched_segments),
                train_state)
            train_loss += loss
        end

        # Validation pass with the current parameters
        valid_loss = luxtype(0.0)
        ps, st = train_state.parameters, train_state.states
        for (batched_segments, batched_tsteps) in dataloader_valid
            segment_pred, _ = model_with_wrapper((batched_segments, batched_tsteps), ps, st)
            valid_loss += backend.loss_fn(segment_pred, batched_segments)
        end

        println("Train loss: $train_loss")
        println("Validation loss: $valid_loss")

        # Keep the parameters with the lowest validation loss seen so far
        if valid_loss < best_loss
            best_ps = ps.model
            best_st = st.model
            best_loss = valid_loss
        end
    end

    return (; ps = best_ps, st = best_st)
end
```
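The same dispatch pattern extends to other training policies. As a hedged sketch (the `WithEarlyStopping` name and `patience` field are illustrative, not part of the package), a setup for early stopping could carry a patience budget, and the corresponding `train` method would break out of the epoch loop once validation loss stops improving:

```julia
@concrete struct WithEarlyStopping <: AbstractSetup
    dataloader
    patience::Int   # epochs to wait for a validation improvement before stopping
end

# Inside the epoch loop of the corresponding `train` method:
# if valid_loss < best_loss
#     best_loss, stall = valid_loss, 0
# else
#     stall += 1
#     stall >= experimental_setup.patience && break
# end
```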
## Training example

To use this custom pipeline, create the training and validation dataloaders, construct the custom setup, and call `train`.
```julia
using Lux, Optimisers, ComponentArrays
using Zygote

# Synthetic data: 2 state variables observed at 201 time steps
tsteps = range(0, stop = 20.0, length = 201)
data = randn(2, length(tsteps))

segment_length = 20
valid_length = 2
batchsize = 4

dataloader_train, dataloader_valid = create_train_val_loaders((data, tsteps);
    segment_length,
    valid_length,
    batchsize,
    partial_batch = true)

setup = WithValidation(dataloader_valid)

# Neural network driving the autoregressive step
nn = Chain(
    Dense(2, 16, relu),
    Dense(16, 16, relu),
    Dense(16, 2)
)

function ar_step(layers, u, ps, t)
    return layers.nn(u, ps.nn)
end

model = ARModel(
    (; nn),
    ar_step;
    dt = tsteps[2] - tsteps[1],
)

backend = SGDBackend(Adam(0.01),
    10,
    AutoZygote(),
    MSELoss())

train(backend, model, dataloader_train, setup);
```
```
Train loss: 2.523880237045662
Validation loss: 3.056198726001546
Train loss: 2.1676965256310634
Validation loss: 3.086886059796705
Train loss: 2.161537972025759
Validation loss: 2.4548079497617836
Train loss: 2.208427288956004
Validation loss: 2.0073858264457116
Train loss: 2.230083799643335
Validation loss: 1.8235838673918003
Train loss: 2.603968194347684
Validation loss: 1.8157940310334248
Train loss: 2.8612241508403367
Validation loss: 1.850207205028822
Train loss: 3.6497091230190506
Validation loss: 1.8883697704921003
Train loss: 4.324854071437454
Validation loss: 1.935994465987359
Train loss: 3.913046821833405
Validation loss: 1.9912229041817384
```
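The returned named tuple holds the parameters and states with the lowest validation loss. As a hedged sketch (assuming the standalone model accepts the same vector-of-feature-tuples input the wrapper builds), you could then simulate with the best parameters over the full series:

```julia
res = train(backend, model, dataloader_train, setup)

# Mimic the wrapper's output shape: a vector with one feature tuple
features = [(; u0 = data[:, 1],
               saveat = collect(tsteps),
               tspan = (tsteps[1], tsteps[end]))]
pred, _ = model(features, res.ps, res.st)
```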