Overloading the train function
The default train function is basic and intended for demonstration purposes. For more advanced training, you should create your own training pipeline. This tutorial shows how to build a custom pipeline that monitors validation loss to guard against overfitting.
Define a custom setup struct
Create a new struct that subtypes AbstractSetup. This struct will hold the validation dataloader.
using HybridDynamicModels
using Lux # provides Chain, WrappedFunction, and the Training API used below
import HybridDynamicModels: AbstractSetup
import LuxCore: AbstractLuxLayer
import ConcreteStructs: @concrete
import Random

@concrete struct WithValidation <: AbstractSetup
    dataloader
end
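The setup struct can carry whatever state your pipeline needs. As a purely hypothetical illustration (the patience field below is not part of the package), the same pattern could be extended to support early stopping:

@concrete struct WithEarlyStopping <: AbstractSetup
    dataloader
    patience::Int  # hypothetical: epochs to wait without improvement before stopping
end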
Implement the custom train method
Implement a train method that takes your custom setup. The method trains the model while monitoring the validation loss.
The method performs these steps:
- Prepare the training and validation data
- Set up the model with a feature wrapper
- Initialize the training state
- Train for multiple epochs, computing both training and validation loss
- Save the best model parameters based on validation loss
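The feature_wrapper helper below converts each batch from the dataloader into the per-sample NamedTuple format the model expects: an initial condition u0, the time points saveat, and the integration span tspan.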
function feature_wrapper((batched_segments, batched_tsteps))
    # Build one (; u0, saveat, tspan) NamedTuple per sample in the batch
    return [(; u0 = batched_segments[:, 1, i],
               saveat = batched_tsteps[:, i],
               tspan = (batched_tsteps[1, i], batched_tsteps[end, i]))
            for i in 1:size(batched_tsteps, 2)]
end
function HybridDynamicModels.train(backend::SGDBackend,
        model::AbstractLuxLayer,
        dataloader_train::SegmentedTimeSeries,
        experimental_setup::WithValidation,
        rng = Random.default_rng(),
        luxtype = Lux.f64)
    dataloader_train = luxtype(dataloader_train)
    dataloader_valid = luxtype(experimental_setup.dataloader)
    @assert length(dataloader_train) == length(dataloader_valid) "The training and validation dataloaders must have the same number of segments"

    model_with_wrapper = Chain((; wrapper = Lux.WrappedFunction(feature_wrapper), model = model))
    ps, st = luxtype(Lux.setup(rng, model_with_wrapper))
    train_state = Training.TrainState(model_with_wrapper, ps, st, backend.opt)

    best_ps = ps.model
    best_st = st.model
    best_loss = luxtype(Inf)

    for epoch in 1:(backend.n_epochs)
        # Training pass
        train_loss = luxtype(0.0)
        for (batched_segments, batched_tsteps) in dataloader_train
            _, loss, _, train_state = Training.single_train_step!(
                backend.adtype,
                backend.loss_fn,
                ((batched_segments, batched_tsteps), batched_segments),
                train_state)
            train_loss += loss
        end

        # Validation pass with the current parameters
        valid_loss = luxtype(0.0)
        ps, st = train_state.parameters, train_state.states
        for (batched_segments, batched_tsteps) in dataloader_valid
            segment_pred, _ = model_with_wrapper((batched_segments, batched_tsteps), ps, st)
            valid_loss += backend.loss_fn(segment_pred, batched_segments)
        end

        println("Train loss: $train_loss")
        println("Validation loss: $valid_loss")

        # Keep the parameters with the lowest validation loss
        if valid_loss < best_loss
            best_ps = ps.model
            best_st = st.model
            best_loss = valid_loss
        end
    end

    return (; ps = best_ps, st = best_st)
end
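Because Julia dispatches on the type of experimental_setup, calling train with a WithValidation instance selects this method automatically instead of the default training loop.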
Training example
To use this custom pipeline, create training and validation dataloaders, set up the custom configuration, and call the train function.
using Lux, Optimisers, ComponentArrays
using Zygote

tsteps = range(0, stop = 20.0, length = 201)
data = randn(2, length(tsteps))

# Split the series into training and validation dataloaders
segment_length = 20
valid_length = 2
batchsize = 4
dataloader_train, dataloader_valid = create_train_val_loaders((data, tsteps);
    segment_length,
    valid_length,
    batchsize,
    partial_batch = true)

setup = WithValidation(dataloader_valid)

# Neural network for the autoregressive step
nn = Chain(
    Dense(2, 16, relu),
    Dense(16, 16, relu),
    Dense(16, 2)
)

function ar_step(layers, u, ps, t)
    return layers.nn(u, ps.nn)
end

model = ARModel(
    (; nn),
    ar_step;
    dt = tsteps[2] - tsteps[1],
)

backend = SGDBackend(Adam(0.01),
    10,
    AutoZygote(),
    MSELoss())

# Returns the parameters and states with the lowest validation loss
res = train(backend, model, dataloader_train, setup);
Train loss: 3.7005512481893352
Validation loss: 2.1913953256249195
Train loss: 3.181255767841896
Validation loss: 1.9473052821992503
Train loss: 3.235959528823285
Validation loss: 1.7110689140445294
Train loss: 3.385303107874046
Validation loss: 1.6817806235334052
Train loss: 3.097253291728708
Validation loss: 1.77109502205731
Train loss: 2.8746393087007625
Validation loss: 1.715721152018439
Train loss: 2.712255107080331
Validation loss: 1.585378981176048
Train loss: 2.605672247837778
Validation loss: 1.5064687583008494
Train loss: 2.5266623138891835
Validation loss: 1.4380432306578121
Train loss: 2.4362662122957675
Validation loss: 1.3716984413838595
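With the returned parameters you can run the trained model directly. Below is a minimal sketch, assuming the ARModel call accepts the same (; u0, saveat, tspan) NamedTuple that feature_wrapper builds for the wrapped model inside the training pipeline; check the package documentation for the exact input format.

# Forecast sketch. The input format is an assumption based on what
# `feature_wrapper` passes to the wrapped model above.
u0 = data[:, 1]  # initial condition
forecast_input = (; u0,
    saveat = collect(tsteps),
    tspan = (tsteps[1], tsteps[end]))
pred, _ = model(forecast_input, res.ps, res.st)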