Overloading the train function

The default train function is opinionated and meant for demonstration purposes. For more advanced training, you should create your own training pipeline.

This tutorial shows how to build a custom training pipeline that includes validation to prevent overfitting.

Define a custom setup struct

Create a new struct that subtypes AbstractSetup. This struct will hold the validation dataloader.

using HybridDynamicModels
import HybridDynamicModels: AbstractSetup
import LuxCore: AbstractLuxLayer
import ConcreteStructs: @concrete
import Random

# Experimental setup that carries a held-out dataloader. Subtyping `AbstractSetup`
# lets a `train` method be selected by dispatch on this setup type.
@concrete struct WithValidation <: AbstractSetup
    # Validation dataloader; the custom `train` method reads it as
    # `experimental_setup.dataloader`.
    dataloader
end

Implement the custom train method

Implement a train method that takes your custom setup. This method will train the model while monitoring validation loss.

The method performs these steps:

  1. Prepare the training and validation data
  2. Set up the model with a feature wrapper
  3. Initialize the training state
  4. Train for multiple epochs, computing both training and validation loss
  5. Save the best model parameters based on validation loss
# Turn one batch `(batched_segments, batched_tsteps)` into a vector of named tuples,
# one per column of `batched_tsteps`. Each entry holds:
#   - `u0`:     the slice `batched_segments[:, 1, i]`,
#   - `saveat`: column `i` of `batched_tsteps`,
#   - `tspan`:  the first and last entries of that column.
function feature_wrapper((batched_segments, batched_tsteps))
    n_columns = size(batched_tsteps, 2)
    return map(1:n_columns) do col
        (; u0 = batched_segments[:, 1, col],
            saveat = batched_tsteps[:, col],
            tspan = (batched_tsteps[1, col], batched_tsteps[end, col]))
    end
end

"""
    HybridDynamicModels.train(backend::SGDBackend, model::AbstractLuxLayer,
                              dataloader_train::SegmentedTimeSeries,
                              experimental_setup::WithValidation,
                              rng = Random.default_rng(), luxtype = Lux.f64)

Train `model` on `dataloader_train` while monitoring the loss on the validation
dataloader stored in `experimental_setup`, and return the parameters and states of
the epoch with the lowest validation loss.

# Arguments
- `backend`: provides `opt`, `n_epochs`, `adtype` and `loss_fn`.
- `model`: the layer to train; it is chained after `feature_wrapper`.
- `dataloader_train`: training dataloader, iterated as
  `(batched_segments, batched_tsteps)` pairs.
- `experimental_setup`: holds the validation dataloader in its `dataloader` field.
- `rng`: random number generator passed to `Lux.setup`.
- `luxtype`: conversion applied to the dataloaders, parameters, states and loss
  accumulators.

# Returns
A named tuple `(; ps, st)` with the `model` entries of the best parameters and states.

Both losses are sums of the per-batch losses over one pass through the
corresponding dataloader.
"""
function HybridDynamicModels.train(backend::SGDBackend,
        model::AbstractLuxLayer,
        dataloader_train::SegmentedTimeSeries,
        experimental_setup::WithValidation,
        rng = Random.default_rng(),
        luxtype = Lux.f64)
    dataloader_train = luxtype(dataloader_train)
    dataloader_valid = luxtype(experimental_setup.dataloader)

    @assert length(dataloader_train)==length(dataloader_valid) "The training and validation dataloaders must have the same number of segments"

    model_with_wrapper = Chain((; wrapper = Lux.WrappedFunction(feature_wrapper), model = model))

    ps, st = luxtype(Lux.setup(rng, model_with_wrapper))

    train_state = Training.TrainState(model_with_wrapper, ps, st, backend.opt)

    # Snapshots are deep copies: `single_train_step!` is a mutating API and may
    # update the parameter arrays in place, which would otherwise silently
    # overwrite a previously stored "best" model.
    best_ps = deepcopy(ps.model)
    best_st = deepcopy(st.model)
    best_loss = luxtype(Inf)

    for epoch in 1:(backend.n_epochs)
        # Training pass: accumulate the loss over all batches.
        train_loss = luxtype(0.0)
        for (batched_segments, batched_tsteps) in dataloader_train
            _, loss, _, train_state = Training.single_train_step!(
                backend.adtype,
                backend.loss_fn,
                ((batched_segments, batched_tsteps), batched_segments),
                train_state)
            train_loss += loss
        end

        # Validation pass with the parameters obtained after this epoch.
        valid_loss = luxtype(0.0)
        ps, st = train_state.parameters, train_state.states

        for (batched_segments, batched_tsteps) in dataloader_valid
            segment_pred, _ = model_with_wrapper((batched_segments, batched_tsteps), ps, st)
            valid_loss += backend.loss_fn(segment_pred, batched_segments)
        end

        println("Train loss: $train_loss")
        println("Validation loss: $valid_loss")

        # Model selection is driven by the validation loss only, so the stored
        # reference value must be the validation loss as well.
        if valid_loss < best_loss
            best_ps = deepcopy(ps.model)
            best_st = deepcopy(st.model)
            best_loss = valid_loss
        end
    end

    return (; ps = best_ps, st = best_st)
end

Training example

To use this custom pipeline, create training and validation dataloaders, set up the custom configuration, and call the train function.

using Lux, Optimisers, ComponentArrays
using Zygote

# Synthetic dataset: 201 evenly spaced time points on [0, 20] and a 2 × 201 matrix
# of standard-normal noise standing in for observations.
tsteps = range(0, stop=20.0, length=201)
data = randn(2, length(tsteps))

# Settings forwarded by keyword to `create_train_val_loaders` below.
segment_length = 20
valid_length = 2
batchsize = 4

# Build the training and validation dataloaders from the same `(data, tsteps)` pair.
dataloader_train, dataloader_valid = create_train_val_loaders((data, tsteps);
                                                                segment_length,
                                                                valid_length,
                                                                batchsize,
                                                                partial_batch = true)

# Wrapping the validation dataloader in `WithValidation` selects the custom
# `train` method through dispatch.
setup = WithValidation(dataloader_valid)

# Fully connected network: 2 inputs -> 16 -> 16 -> 2 outputs, with `relu` on the
# first two layers.
nn = Chain(
    Dense(2, 16, relu),
    Dense(16, 16, relu),
    Dense(16, 2)
)

# Step function handed to `ARModel`: apply the `nn` entry of `layers` to the
# current value `u` using the `nn` entry of the parameters `ps`. The time
# argument `t` is accepted but not used.
function ar_step(layers, u, ps, t)
    network = layers.nn
    network_ps = ps.nn
    return network(u, network_ps)
end

# Model built from the network (passed as the named tuple `(; nn)`) and the
# `ar_step` function, with `dt` set to the spacing between consecutive `tsteps`.
model = ARModel(
    (;nn),
    ar_step;
    dt = tsteps[2] - tsteps[1],
)

# Backend configuration: Adam optimiser with learning rate 0.01, Zygote as the AD
# backend and a mean-squared-error loss. The `10` is presumably the number of
# epochs (`train` loops over `backend.n_epochs`) — confirm against `SGDBackend`.
backend = SGDBackend(Adam(0.01),
                    10,
                    AutoZygote(),
                    MSELoss())

# Dispatches to the `WithValidation` method of `train`; the trailing `;` suppresses
# display of the returned named tuple.
train(backend, model, dataloader_train, setup);
Train loss: 2.731907965742889
Validation loss: 1.6362279075798152
Train loss: 2.2684367360726574
Validation loss: 1.4873929536553188
Train loss: 2.015973224559767
Validation loss: 1.4265492271458395
Train loss: 1.8973450667199954
Validation loss: 1.4130430032099714
Train loss: 1.804951633544782
Validation loss: 1.414079411021941
Train loss: 1.7453043345098846
Validation loss: 1.4419443343084206
Train loss: 1.7648579896662069
Validation loss: 1.489666780019336
Train loss: 1.811041956554174
Validation loss: 1.5229388654323004
Train loss: 1.8379690374858826
Validation loss: 1.5598124491019405
Train loss: 1.82838874692473
Validation loss: 1.5703146445552376