Overloading the train function
The default `train` function is opinionated and meant for demonstration purposes. For more advanced training, you should create your own training pipeline.
This tutorial shows how to build a custom training pipeline that includes validation to prevent overfitting.
Define a custom setup struct
Create a new struct that subtypes `AbstractSetup`. This struct will hold the validation dataloader.
using HybridDynamicModels
import HybridDynamicModels: AbstractSetup
import LuxCore: AbstractLuxLayer
import ConcreteStructs: @concrete
import Random
# Experimental setup that carries a validation dataloader. Because it subtypes
# `AbstractSetup`, a `HybridDynamicModels.train` method can dispatch on it to run
# a custom training loop instead of the package default.
@concrete struct WithValidation <: AbstractSetup
    dataloader  # validation dataloader, evaluated after every training epoch
end
Implement the custom train method
Implement a `train` method that dispatches on your custom setup type. This method trains the model while monitoring the validation loss.
The method performs these steps:
- Prepare the training and validation data
- Set up the model with a feature wrapper
- Initialize the training state
- Train for multiple epochs, computing both training and validation loss
- Save the best model parameters based on validation loss
"""
    feature_wrapper((batched_segments, batched_tsteps))

Convert one dataloader batch into a vector with one named tuple
`(; u0, saveat, tspan)` per segment.

`batched_segments` is indexed as `[:, time, segment]` and `batched_tsteps` as
`[time, segment]` (presumably state × time × segment and time × segment; confirm
against the dataloader). For segment `seg`:
- `u0` is the first time slice of the segment.
- `saveat` is that segment's column of time steps.
- `tspan` runs from the first to the last of those time steps.
"""
function feature_wrapper((batched_segments, batched_tsteps))
    nsegments = size(batched_tsteps, 2)
    return map(1:nsegments) do seg
        times = batched_tsteps[:, seg]
        (; u0 = batched_segments[:, 1, seg],
            saveat = times,
            tspan = (first(times), last(times)))
    end
end
"""
    HybridDynamicModels.train(backend::SGDBackend, model::AbstractLuxLayer,
                              dataloader_train::SegmentedTimeSeries,
                              experimental_setup::WithValidation,
                              rng = Random.default_rng(), luxtype = Lux.f64)

Train `model` with the settings in `backend` (`opt`, `n_epochs`, `adtype`,
`loss_fn`). After every epoch, compute the loss summed over the validation
dataloader stored in `experimental_setup`.

The model is wrapped in a `Chain` behind `feature_wrapper`, so each batch
`(batched_segments, batched_tsteps)` is turned into per-segment inputs. Each batch
uses `batched_segments` itself as the training target.

`rng` initialises parameters and states via `Lux.setup`. `luxtype` is applied to
both dataloaders, the initial parameters/states and the scalar loss
accumulators.

Returns `(; ps, st)` for the inner `model`, taken from the epoch with the lowest
validation loss. If `backend.n_epochs == 0`, the initial parameters/states are
returned.

Throws an `AssertionError` if the two dataloaders differ in length.
"""
function HybridDynamicModels.train(backend::SGDBackend,
        model::AbstractLuxLayer,
        dataloader_train::SegmentedTimeSeries,
        experimental_setup::WithValidation,
        rng = Random.default_rng(),
        luxtype = Lux.f64)
    dataloader_train = luxtype(dataloader_train)
    dataloader_valid = luxtype(experimental_setup.dataloader)
    # Both losses are sums over batches, so equal lengths keep the printed
    # training and validation totals on a comparable scale.
    @assert(length(dataloader_train) == length(dataloader_valid),
        "The training and validation dataloaders must have the same number of segments")

    model_with_wrapper = Chain((; wrapper = Lux.WrappedFunction(feature_wrapper),
        model = model))
    ps, st = luxtype(Lux.setup(rng, model_with_wrapper))
    train_state = Training.TrainState(model_with_wrapper, ps, st, backend.opt)

    # Snapshots must be deep copies. The optimiser step can update parameter
    # arrays in place, so keeping `ps.model` by reference would alias the live
    # parameters, and the "best" result would silently become the final epoch's.
    best_ps = deepcopy(ps.model)
    best_st = deepcopy(st.model)
    best_loss = luxtype(Inf)

    for _ in 1:(backend.n_epochs)
        train_loss = luxtype(0.0)
        for (batched_segments, batched_tsteps) in dataloader_train
            # Data is (input, target); the target is the segment itself.
            _, loss, _, train_state = Training.single_train_step!(
                backend.adtype,
                backend.loss_fn,
                ((batched_segments, batched_tsteps), batched_segments),
                train_state)
            train_loss += loss
        end

        ps, st = train_state.parameters, train_state.states
        # Evaluate with test-mode states so stochastic/normalisation layers
        # (e.g. dropout, batch norm) do not behave as during training.
        st_eval = Lux.testmode(st)
        valid_loss = luxtype(0.0)
        for (batched_segments, batched_tsteps) in dataloader_valid
            segment_pred, _ = model_with_wrapper(
                (batched_segments, batched_tsteps), ps, st_eval)
            valid_loss += backend.loss_fn(segment_pred, batched_segments)
        end

        println("Train loss: $train_loss")
        println("Validation loss: $valid_loss")

        # Track the best epoch by *validation* loss (not training loss).
        if valid_loss < best_loss
            best_ps = deepcopy(ps.model)
            best_st = deepcopy(st.model)
            best_loss = valid_loss
        end
    end
    return (; ps = best_ps, st = best_st)
end
Training example
To use this custom pipeline, create training and validation dataloaders, set up the custom configuration, and call the train function.
using Lux, Optimisers, ComponentArrays
using Zygote
# Synthetic dataset: 2 state variables filled with standard-normal noise,
# sampled at 201 evenly spaced time points on [0, 20].
tsteps = range(0, stop=20.0, length=201)
data = randn(2, length(tsteps))

# Build segmented training and validation dataloaders.
# NOTE(review): the exact meaning of `valid_length` is defined by
# `create_train_val_loaders`; check its docstring.
segment_length = 20
valid_length = 2
batchsize = 4
dataloader_train, dataloader_valid = create_train_val_loaders((data, tsteps);
    segment_length,
    valid_length,
    batchsize,
    partial_batch = true)

# Wrapping the validation loader in `WithValidation` makes dispatch select
# the custom `train` method.
setup = WithValidation(dataloader_valid)

# 2 → 16 → 16 → 2 multilayer perceptron with ReLU activations.
nn = Chain(
    Dense(2, 16, relu),
    Dense(16, 16, relu),
    Dense(16, 2)
)

# Autoregressive step: apply `nn` to the current state `u`
# (presumably producing the next state). The time argument `t` is unused.
function ar_step(layers, u, ps, t)
    return layers.nn(u, ps.nn)
end

model = ARModel(
    (;nn),
    ar_step;
    dt = tsteps[2] - tsteps[1],
)

# Adam with learning rate 0.01, 10 epochs, Zygote AD, mean-squared-error loss.
backend = SGDBackend(Adam(0.01),
    10,
    AutoZygote(),
    MSELoss())

# The trailing `;` suppresses REPL display of the returned `(; ps, st)`.
train(backend, model, dataloader_train, setup);
Train loss: 2.731907965742889
Validation loss: 1.6362279075798152
Train loss: 2.2684367360726574
Validation loss: 1.4873929536553188
Train loss: 2.015973224559767
Validation loss: 1.4265492271458395
Train loss: 1.8973450667199954
Validation loss: 1.4130430032099714
Train loss: 1.804951633544782
Validation loss: 1.414079411021941
Train loss: 1.7453043345098846
Validation loss: 1.4419443343084206
Train loss: 1.7648579896662069
Validation loss: 1.489666780019336
Train loss: 1.811041956554174
Validation loss: 1.5229388654323004
Train loss: 1.8379690374858826
Validation loss: 1.5598124491019405
Train loss: 1.82838874692473
Validation loss: 1.5703146445552376