Overloading the train function

The default train function is basic and meant for demonstration purposes. For more advanced training, you should create your own training pipeline.

This tutorial shows how to build a custom training pipeline that includes validation to prevent overfitting.

Define a custom setup struct

Create a new struct that subtypes AbstractSetup (Julia types subtype abstract types; they do not inherit fields or implementation). This struct will hold the validation dataloader.

using HybridDynamicModels
import HybridDynamicModels: AbstractSetup
import LuxCore: AbstractLuxLayer
import ConcreteStructs: @concrete
import Random

# Experimental setup that carries a validation dataloader. Dispatching the
# custom `train` method on this type enables validation-monitored training.
@concrete struct WithValidation <: AbstractSetup
    dataloader  # validation dataloader; presumably a SegmentedTimeSeries — TODO confirm
end

Implement the custom train method

Implement a train method that takes your custom setup. This method will train the model while monitoring validation loss.

The method performs these steps:

  1. Prepare the training and validation data
  2. Set up the model with a feature wrapper
  3. Initialize the training state
  4. Train for multiple epochs, computing both training and validation loss
  5. Save the best model parameters based on validation loss
# Convert one batch `(segments, tsteps)` into a vector of per-sample named
# tuples holding the initial condition `u0`, the save times `saveat`, and the
# integration span `tspan` expected by the wrapped model.
function feature_wrapper((batched_segments, batched_tsteps))
    n_samples = size(batched_tsteps, 2)
    return map(1:n_samples) do i
        sample_tsteps = batched_tsteps[:, i]
        (; u0 = batched_segments[:, 1, i],
            saveat = sample_tsteps,
            tspan = (sample_tsteps[1], sample_tsteps[end]))
    end
end

"""
    train(backend::SGDBackend, model, dataloader_train, setup::WithValidation,
          rng = Random.default_rng(), luxtype = Lux.f64)

Train `model` on `dataloader_train` while monitoring the loss on the
validation dataloader carried by `setup`. Returns `(; ps, st)` — the
parameters and states of the epoch with the lowest validation loss.
"""
function HybridDynamicModels.train(backend::SGDBackend,
        model::AbstractLuxLayer,
        dataloader_train::SegmentedTimeSeries,
        experimental_setup::WithValidation,
        rng = Random.default_rng(),
        luxtype = Lux.f64)
    # Convert both dataloaders to the backend's working float type.
    dataloader_train = luxtype(dataloader_train)
    dataloader_valid = luxtype(experimental_setup.dataloader)

    @assert length(dataloader_train)==length(dataloader_valid) "The training and validation dataloaders must have the same number of segments"

    # Prepend the feature wrapper so each batch is expanded into
    # (u0, saveat, tspan) inputs before reaching the model.
    model_with_wrapper = Chain((; wrapper = Lux.WrappedFunction(feature_wrapper),
        model = model))

    ps, st = luxtype(Lux.setup(rng, model_with_wrapper))

    train_state = Training.TrainState(model_with_wrapper, ps, st, backend.opt)
    # deepcopy so a later in-place optimizer update cannot mutate the saved
    # best snapshot through aliasing.
    best_ps = deepcopy(ps.model)
    best_st = deepcopy(st.model)

    best_loss = luxtype(Inf)
    for epoch in 1:(backend.n_epochs)
        # --- training pass ---
        train_loss = luxtype(0.0)
        for (batched_segments, batched_tsteps) in dataloader_train
            _, loss, _, train_state = Training.single_train_step!(
                backend.adtype,
                backend.loss_fn,
                ((batched_segments, batched_tsteps), batched_segments),
                train_state)
            train_loss += loss
        end

        # --- validation pass (forward only, no gradient step) ---
        valid_loss = luxtype(0.0)  # keep the accumulator in the working float type
        ps, st = train_state.parameters, train_state.states

        for (batched_segments, batched_tsteps) in dataloader_valid
            segment_pred, _ = model_with_wrapper((batched_segments, batched_tsteps), ps, st)
            valid_loss += backend.loss_fn(segment_pred, batched_segments)
        end

        println("Train loss: $train_loss")
        println("Validation loss: $valid_loss")
        if valid_loss < best_loss
            best_ps = deepcopy(ps.model)
            best_st = deepcopy(st.model)
            # BUG FIX: track the best *validation* loss (previously this
            # stored `train_loss`, so the early-stopping baseline was wrong).
            best_loss = valid_loss
        end
    end

    return (; ps = best_ps, st = best_st)
end

Training example

To use this custom pipeline, create training and validation dataloaders, set up the custom configuration, and call the train function.

using Lux, Optimisers, ComponentArrays
using Zygote

# Synthetic dataset: 201 timesteps over [0, 20], 2 state variables of noise.
tsteps = range(0, stop=20.0, length=201)
data = randn(2, length(tsteps))

# Segmentation settings — assumes valid_length is the per-segment validation
# horizon used by create_train_val_loaders; TODO confirm against its docs.
segment_length = 20
valid_length = 2
batchsize = 4

# Build matching training/validation loaders over the same segments;
# partial_batch keeps the final, possibly smaller, batch.
dataloader_train, dataloader_valid = create_train_val_loaders((data, tsteps);
                                                                segment_length,
                                                                valid_length,
                                                                batchsize,
                                                                partial_batch = true)

# Custom setup value that selects the validation-aware `train` method.
setup = WithValidation(dataloader_valid)

# Small MLP on the 2-dimensional state.
nn = Chain(
    Dense(2, 16, relu),
    Dense(16, 16, relu),
    Dense(16, 2)
)

# One autoregressive step: the network maps the current state `u` to the next
# state; the time argument `t` is accepted but unused.
ar_step(layers, u, ps, t) = layers.nn(u, ps.nn)

# Autoregressive model driven by `ar_step`, stepping at the data's timestep.
model = ARModel(
    (;nn),
    ar_step;
    dt = tsteps[2] - tsteps[1],
)

# SGD backend: Adam optimizer, 10 epochs, Zygote AD, mean-squared-error loss.
backend = SGDBackend(Adam(0.01),
                    10,
                    AutoZygote(),
                    MSELoss())

# Dispatches to the custom method via `setup::WithValidation`.
train(backend, model, dataloader_train, setup);
Train loss: 3.7005512481893352
Validation loss: 2.1913953256249195
Train loss: 3.181255767841896
Validation loss: 1.9473052821992503
Train loss: 3.235959528823285
Validation loss: 1.7110689140445294
Train loss: 3.385303107874046
Validation loss: 1.6817806235334052
Train loss: 3.097253291728708
Validation loss: 1.77109502205731
Train loss: 2.8746393087007625
Validation loss: 1.715721152018439
Train loss: 2.712255107080331
Validation loss: 1.585378981176048
Train loss: 2.605672247837778
Validation loss: 1.5064687583008494
Train loss: 2.5266623138891835
Validation loss: 1.4380432306578121
Train loss: 2.4362662122957675
Validation loss: 1.3716984413838595