Overloading the train function

The default train function is opiniated and meant for demonstration purposes. For more advanced training, you should create your own training pipeline.

This tutorial shows how to build a custom training pipeline that includes validation to prevent overfitting.

Define a custom setup struct

Create a new struct that inherits from AbstractSetup. This struct will hold the validation dataloader.

using HybridDynamicModels
import HybridDynamicModels: AbstractSetup
import LuxCore: AbstractLuxLayer
import ConcreteStructs: @concrete
import Random

@concrete struct WithValidation <: AbstractSetup
    dataloader
end

Implement the custom train method

Implement a train method that takes your custom setup. This method will train the model while monitoring validation loss.

The method performs these steps:

  1. Prepare the training and validation data
  2. Set up the model with a feature wrapper
  3. Initialize the training state
  4. Train for multiple epochs, computing both training and validation loss
  5. Save the best model parameters based on validation loss
function feature_wrapper((batched_segments, batched_tsteps))
    return [(; u0 = batched_segments[:, 1, i],
                saveat = batched_tsteps[:, i],
                tspan = (batched_tsteps[1, i], batched_tsteps[end, i])
            )
            for i in 1:size(batched_tsteps, 2)]
end

function HybridDynamicModels.train(backend::SGDBackend,
        model::AbstractLuxLayer,
        dataloader_train::SegmentedTimeSeries,
        experimental_setup::WithValidation,
        rng = Random.default_rng(),
        luxtype = Lux.f64)
    dataloader_train = luxtype(dataloader_train)
    dataloader_valid = luxtype(experimental_setup.dataloader)

    @assert length(dataloader_train)==length(dataloader_valid) "The training and validation dataloaders must have the same number of segments"

    model_with_wrapper = Chain((; wrapper = Lux.WrappedFunction(feature_wrapper), model = model))

    ps, st = luxtype(Lux.setup(rng, model_with_wrapper))

    train_state = Training.TrainState(model_with_wrapper, ps, st, backend.opt)
    best_ps = ps.model
    best_st = st.model

    best_loss = luxtype(Inf)
    for epoch in 1:(backend.n_epochs)
        train_loss = luxtype(0.0)
        for (batched_segments, batched_tsteps) in dataloader_train
            _, loss, _, train_state = Training.single_train_step!(
                backend.adtype,
                backend.loss_fn,
                ((batched_segments, batched_tsteps), batched_segments),
                train_state)
            train_loss += loss
        end

        valid_loss = 0.0
        ps, st = train_state.parameters, train_state.states

        for (batched_segments, batched_data) in dataloader_valid
            segment_pred, _ = model_with_wrapper((batched_segments, batched_data), ps, st)
            valid_loss += backend.loss_fn(segment_pred, batched_segments)
        end

        println("Train loss: $train_loss")
        println("Validation loss: $valid_loss")
        if valid_loss < best_loss
            best_ps = ps.model
            best_st = st.model
            best_loss = train_loss
        end
    end

    return (; ps = best_ps, st = best_st)
end

Training example

To use this custom pipeline, create training and validation dataloaders, set up the custom configuration, and call the train function.

using Lux, Optimisers, ComponentArrays
using Zygote

tsteps = range(0, stop=20.0, length=201)
data = randn(2, length(tsteps))

segment_length = 20
valid_length = 2
batchsize = 4

dataloader_train, dataloader_valid = create_train_val_loaders((data, tsteps);
                                                                segment_length,
                                                                valid_length,
                                                                batchsize,
                                                                partial_batch = true)

setup = WithValidation(dataloader_valid)

nn = Chain(
    Dense(2, 16, relu),
    Dense(16, 16, relu),
    Dense(16, 2)
)

function ar_step(layers, u, ps, t)
    return layers.nn(u, ps.nn)
end

model = ARModel(
    (;nn),
    ar_step;
    dt = tsteps[2] - tsteps[1],
)

backend = SGDBackend(Adam(0.01),
                    10,
                    AutoZygote(),
                    MSELoss())

train(backend, model, dataloader_train, setup);
Train loss: 2.523880237045662
Validation loss: 3.056198726001546
Train loss: 2.1676965256310634
Validation loss: 3.086886059796705
Train loss: 2.161537972025759
Validation loss: 2.4548079497617836
Train loss: 2.208427288956004
Validation loss: 2.0073858264457116
Train loss: 2.230083799643335
Validation loss: 1.8235838673918003
Train loss: 2.603968194347684
Validation loss: 1.8157940310334248
Train loss: 2.8612241508403367
Validation loss: 1.850207205028822
Train loss: 3.6497091230190506
Validation loss: 1.8883697704921003
Train loss: 4.324854071437454
Validation loss: 1.935994465987359
Train loss: 3.913046821833405
Validation loss: 1.9912229041817384