Overloading the train function

The default train function is opinionated and meant for demonstration purposes. For more advanced training, you should build your own training pipeline.

This tutorial shows how to build a custom training pipeline that includes validation to prevent overfitting.

Define a custom setup struct

Create a new struct that subtypes AbstractSetup. This struct holds the validation dataloader.

using HybridDynamicModels
using Lux
import HybridDynamicModels: AbstractSetup
import LuxCore: AbstractLuxLayer
import ConcreteStructs: @concrete
import Random

@concrete struct WithValidation <: AbstractSetup
    dataloader
end
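
The @concrete macro from ConcreteStructs.jl adds a type parameter for the dataloader field, so the struct stays concretely typed without writing the parameter out by hand.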

Implement the custom train method

Implement a train method that dispatches on your custom setup. This method trains the model while monitoring the validation loss.

The method performs these steps:

  1. Prepare the training and validation data
  2. Set up the model with a feature wrapper
  3. Initialize the training state
  4. Train for multiple epochs, computing both training and validation loss
  5. Save the best model parameters based on validation loss
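
First, define a feature wrapper that turns each batch from the dataloader into the (; u0, saveat, tspan) named tuples the model consumes:
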
function feature_wrapper((batched_segments, batched_tsteps))
    # For each segment in the batch, build the named tuple the model expects:
    # the initial condition, the points to save at, and the time span.
    return [(; u0 = batched_segments[:, 1, i],
                saveat = batched_tsteps[:, i],
                tspan = (batched_tsteps[1, i], batched_tsteps[end, i]))
            for i in 1:size(batched_tsteps, 2)]
end

function HybridDynamicModels.train(backend::SGDBackend,
        model::AbstractLuxLayer,
        dataloader_train::SegmentedTimeSeries,
        experimental_setup::WithValidation,
        rng = Random.default_rng(),
        luxtype = Lux.f64)
    # Convert both dataloaders to the requested precision.
    dataloader_train = luxtype(dataloader_train)
    dataloader_valid = luxtype(experimental_setup.dataloader)

    @assert length(dataloader_train) == length(dataloader_valid) "The training and validation dataloaders must have the same number of segments"

    # Prepend the feature wrapper so raw batches are converted before they
    # reach the model.
    model_with_wrapper = Chain((; wrapper = Lux.WrappedFunction(feature_wrapper), model = model))

    ps, st = luxtype(Lux.setup(rng, model_with_wrapper))

    train_state = Training.TrainState(model_with_wrapper, ps, st, backend.opt)

    # Track the parameters and states with the lowest validation loss.
    best_ps = ps.model
    best_st = st.model
    best_loss = luxtype(Inf)
    for epoch in 1:(backend.n_epochs)
        # One pass over the training segments, updating the parameters.
        train_loss = luxtype(0.0)
        for (batched_segments, batched_tsteps) in dataloader_train
            _, loss, _, train_state = Training.single_train_step!(
                backend.adtype,
                backend.loss_fn,
                ((batched_segments, batched_tsteps), batched_segments),
                train_state)
            train_loss += loss
        end

        # One pass over the validation segments, without updating parameters.
        valid_loss = luxtype(0.0)
        ps, st = train_state.parameters, train_state.states

        for (batched_segments, batched_tsteps) in dataloader_valid
            segment_pred, _ = model_with_wrapper((batched_segments, batched_tsteps), ps, st)
            valid_loss += backend.loss_fn(segment_pred, batched_segments)
        end

        println("Train loss: $train_loss")
        println("Validation loss: $valid_loss")
        # Keep the parameters that achieve the lowest validation loss.
        if valid_loss < best_loss
            best_ps = ps.model
            best_st = st.model
            best_loss = valid_loss
        end
    end

    return (; ps = best_ps, st = best_st)
end
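
Returning the parameters recorded at the lowest validation loss is a simple form of early stopping: even if later epochs overfit the training segments, the checkpoint that generalized best is the one returned.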

Training example

To use this custom pipeline, create training and validation dataloaders, set up the custom configuration, and call the train function.

using Lux, Optimisers, ComponentArrays
using Zygote

tsteps = range(0, stop=20.0, length=201)
data = randn(2, length(tsteps))

segment_length = 20
valid_length = 2
batchsize = 4

dataloader_train, dataloader_valid = create_train_val_loaders((data, tsteps);
                                                                segment_length,
                                                                valid_length,
                                                                batchsize,
                                                                partial_batch = true)

setup = WithValidation(dataloader_valid)

nn = Chain(
    Dense(2, 16, relu),
    Dense(16, 16, relu),
    Dense(16, 2)
)

# One autoregressive step: predict the next state from the current one.
function ar_step(layers, u, ps, t)
    return layers.nn(u, ps.nn)
end

model = ARModel(
    (;nn),
    ar_step;
    dt = tsteps[2] - tsteps[1],
)

backend = SGDBackend(Adam(0.01),   # optimiser
                     10,           # number of epochs
                     AutoZygote(), # automatic differentiation backend
                     MSELoss())    # loss function

train(backend, model, dataloader_train, setup);
Train loss: 1.1552638008175688e15
Validation loss: 6.802254272371301
Train loss: 62879.79786303385
Validation loss: 6.837702604791636
Train loss: 8.25958965539453e7
Validation loss: 6.972318183088622
Train loss: 2.338439181347938e9
Validation loss: 7.051718011424168
Train loss: 1.5914002675883633e10
Validation loss: 7.151404031075922
Train loss: 3.8594020015898026e10
Validation loss: 7.0730953623304735
Train loss: 3.4888667778198586e10
Validation loss: 6.869624457300776
Train loss: 1.692476618874209e10
Validation loss: 6.68011565554518
Train loss: 6.497231044060102e9
Validation loss: 6.516080968156058
Train loss: 2.6908328012798615e9
Validation loss: 6.386999526514899
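
Once training finishes, the returned parameters can be used for inference with the original model. A minimal sketch, assuming ARModel accepts a single (; u0, saveat, tspan) named tuple like those produced by the feature wrapper (the names res and x below are illustrative):

res = train(backend, model, dataloader_train, setup)

# Forecast over the full time range from the first data point, using the
# parameters that achieved the lowest validation loss.
x = (; u0 = data[:, 1], saveat = collect(tsteps), tspan = (tsteps[1], tsteps[end]))
pred, _ = model(x, res.ps, res.st)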