LoRA (Finetuning)#

This is an example of fine-tuning Gemma with LoRA. It is best to read the finetuning colab first, since this tutorial builds on it.

See the LoRA sampling tutorial if you just want to do inference with LoRA.

!pip install -q gemma

# Common imports
import os
import optax
import treescope

# Gemma imports
from kauldron import kd
from gemma import gm

By default, JAX does not use the full GPU memory, but this can be overridden. See GPU memory allocation:

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
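The variable is read when JAX sets up its GPU backend, so it should be in the environment before the first JAX computation runs. A minimal sketch of a wrapper that formats and sets the value; `set_gpu_mem_fraction` is a hypothetical helper for illustration, not part of JAX or Gemma:

```python
import os


def set_gpu_mem_fraction(fraction: float) -> str:
    """Hypothetical helper: store the fraction as a two-decimal string."""
    if fraction <= 0.0:
        raise ValueError("fraction must be positive")
    value = f"{fraction:.2f}"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = value
    return value


# Call this before running any JAX computation.
set_gpu_mem_fraction(1.0)
```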

Config updates#

If you’re familiar with the finetuning tutorial, switching to LoRA requires only three changes to the trainer.

For an end-to-end example, see lora.py config.

1. Model#

Wrap the model in gm.nn.LoRA. This applies model surgery that replaces all linear and other compatible layers with LoRA layers.

model = gm.nn.LoRA(
    rank=4,
    model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True),
)

Internally, this uses the gemma.peft mini-library to perform model surgery.
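To make the idea concrete, here is a minimal NumPy sketch of what a LoRA-wrapped linear layer computes. It is not the gemma.peft implementation: the names `lora_dense`, `a` and `b` are hypothetical, the usual scaling factor is omitted, and the initialization (random `a`, zero `b`) follows the convention of the LoRA paper, which is an assumption here.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank = 16, 32, 4

# Frozen pretrained weight (random here, for illustration only).
w = rng.normal(size=(d_in, d_out))
# Low-rank factors. Assumption: `a` is random and `b` starts at zero.
a = rng.normal(size=(d_in, rank))
b = np.zeros((rank, d_out))


def lora_dense(x, w, a, b):
    """Base projection plus the low-rank correction (x @ a) @ b."""
    return x @ w + (x @ a) @ b


x = rng.normal(size=(2, d_in))
y = lora_dense(x, w, a, b)

# With b == 0 the wrapped layer reproduces the original layer exactly.
assert np.allclose(y, x @ w)

full_params = w.size           # d_in * d_out
lora_params = a.size + b.size  # rank * (d_in + d_out)
print(full_params, lora_params)
```

The trainable parameter count grows with `rank * (d_in + d_out)` instead of `d_in * d_out`, which is why a small rank such as 4 keeps the number of trained weights low.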

2. Checkpoint#

Wrap the init transform in gm.ckpts.SkipLoRA. The wrapper is required because the parameter structure differs with and without LoRA.

Only the pretrained weights are loaded from the checkpoint; the LoRA weights keep their random initialization.

init_transform = gm.ckpts.SkipLoRA(
    wrapped=gm.ckpts.LoadCheckpoint(
        path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
    ),
)

Note: If you’re loading the weights directly with gm.ckpts.load_params, you can use peft.split_params and peft.merge_params instead. See LoRA sampling for an example.
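The split/merge idea can be sketched on plain nested dicts. This is only an illustration, not the peft implementation: `split_lora` and `merge_lora` are hypothetical names, and the sketch assumes that LoRA weights live under a key named `"lora"` inside each wrapped layer.

```python
def split_lora(params):
    """Hypothetical helper: return (original, lora) trees from a nested dict."""
    original, lora = {}, {}
    for key, value in params.items():
        if key == "lora":
            lora[key] = value
        elif isinstance(value, dict):
            sub_original, sub_lora = split_lora(value)
            original[key] = sub_original
            if sub_lora:
                lora[key] = sub_lora
        else:
            original[key] = value
    return original, lora


def merge_lora(original, lora):
    """Hypothetical helper: recombine the two trees without mutating them."""
    merged = dict(original)
    for key, value in lora.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = merge_lora(merged[key], value)
        else:
            merged[key] = value
    return merged


params = {"layer_0": {"w": 1.0, "lora": {"a": 2.0, "b": 3.0}}, "embed": 4.0}
original, lora = split_lora(params)
restored = merge_lora(original, lora)
```

Here `original` has the same structure as a checkpoint without LoRA, so pretrained weights can be loaded into it before the LoRA tree is merged back.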

3. Optimizer#

Finally, we add a mask to the optimizer (with kd.optim.partial_updates), so only the LoRA weights are trained.

optimizer = kd.optim.partial_updates(
    optax.adafactor(learning_rate=0.005),
    # We only optimize the LoRA weights. The rest of the model is frozen.
    mask=kd.optim.select("lora"),
)
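Conceptually, a partial update builds a boolean mask over the parameter tree and applies the optimizer step only where the mask is true. The sketch below shows this on nested dicts with a plain SGD step standing in for Adafactor. It is not how kd.optim.partial_updates is implemented: `lora_mask` and `masked_sgd` are hypothetical helpers, and the `"lora"` key layout is an assumption.

```python
def lora_mask(tree, under_lora=False):
    """True for every leaf whose path contains a 'lora' key."""
    if isinstance(tree, dict):
        return {
            key: lora_mask(value, under_lora or key == "lora")
            for key, value in tree.items()
        }
    return under_lora


def masked_sgd(params, grads, mask, lr):
    """Apply an SGD step to masked leaves and leave the others frozen."""
    if isinstance(params, dict):
        return {
            key: masked_sgd(params[key], grads[key], mask[key], lr)
            for key in params
        }
    return params - lr * grads if mask else params


params = {"dense": {"w": 1.0, "lora": {"a": 0.5, "b": 0.0}}}
grads = {"dense": {"w": 1.0, "lora": {"a": 1.0, "b": 1.0}}}

mask = lora_mask(params)
new_params = masked_sgd(params, grads, mask, lr=0.1)
```

The frozen weight `w` is returned unchanged even though it has a non-zero gradient, while `a` and `b` move.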

Training#

Data pipeline#

As in the finetuning example, we recreate the tokenizer:

tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)
[<_Gemma2SpecialTokens.BOS: 2>, 1596, 603, 671, 3287, 13060]

And the data pipeline:

ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://www.tensorflow.org/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]

treescope.show(ex)
Disabling pygrain multi-processing (unsupported in colab).

We can decode an example from the batch to inspect the model input and check it is properly formatted:

text = tokenizer.decode(ex['input'][0])

print(text)
<start_of_turn>user
As far as battle mode, 64 is the best.<end_of_turn>
<start_of_turn>model
En ce qui concerne le mode bataille, 64 est le meilleur.
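The decoded input ends with the response and has no closing tag, which is consistent with the usual next-token setup where the target is the input shifted by one position and the loss mask covers only the response. A minimal sketch of that construction follows, assuming this layout. `make_example` is a hypothetical helper, the token ids are made up, and the actual gm.data.Seq2SeqTask may differ in its details.

```python
def make_example(prompt_ids, response_ids, max_length, pad_id=0):
    """Hypothetical sketch: next-token input, target and loss mask."""
    sequence = prompt_ids + response_ids
    inputs = sequence[:-1]
    targets = sequence[1:]
    # Position i predicts sequence[i + 1]; count it only when that
    # token belongs to the response.
    mask = [0] * (len(prompt_ids) - 1) + [1] * len(response_ids)
    # Truncate to max_length, then pad up to max_length.
    inputs, targets, mask = (x[:max_length] for x in (inputs, targets, mask))
    pad = max_length - len(inputs)
    return {
        "input": inputs + [pad_id] * pad,
        "target": targets + [pad_id] * pad,
        "loss_mask": mask + [0] * pad,
    }


# Made-up ids: 2 = BOS, 10..11 = prompt, 20..21 = response, 1 = end marker.
example = make_example([2, 10, 11], [20, 21, 1], max_length=8)
```

Padding positions get a zero loss mask, so they do not contribute to the training loss.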

Trainer#

We then create the trainer, reusing the model, init_transform and optimizer created above:

trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=init_transform,
    # Training parameters
    num_train_steps=500,
    train_losses={
        "loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=optimizer,
)
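For intuition, the masked loss can be written in a few lines of NumPy. This is a sketch, not the Kauldron implementation: `masked_cross_entropy` is a hypothetical function, and averaging over the unmasked positions is an assumption about the normalization.

```python
import numpy as np


def masked_cross_entropy(logits, labels, mask):
    """Mean softmax cross-entropy over positions where mask == 1."""
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability assigned to the correct token at each position.
    token_ll = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    mask = mask.astype(log_probs.dtype)
    return -(token_ll * mask).sum() / np.maximum(mask.sum(), 1.0)


# Batch of 1, sequence of 3, vocabulary of 4, uniform predictions.
logits = np.zeros((1, 3, 4))
labels = np.array([[0, 1, 2]])
mask = np.array([[1, 1, 0]])
loss = masked_cross_entropy(logits, labels, mask)
```

With uniform logits over 4 tokens, each unmasked position contributes `log(4)`, and the masked position is ignored whatever its logits are.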

Training can be launched with the .train() method.

Note that the trainer, like the model, is immutable, so it stores neither the state nor the params. Instead, the state containing the trained parameters is returned.

state, aux = trainer.train()
Configuring ...
Initializing ...
Starting training loop at step 0
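The same functional pattern can be shown with a toy example using frozen dataclasses. `ToyTrainer` and `ToyState` are hypothetical and unrelated to the Kauldron classes; the toy `train` also returns only a state, without the auxiliary output.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyState:
    step: int
    params: tuple


@dataclasses.dataclass(frozen=True)
class ToyTrainer:
    num_train_steps: int
    lr: float

    def train(self, state):
        """Return a new state; neither the trainer nor the input is mutated."""
        for _ in range(self.num_train_steps):
            # Gradient of sum(p ** 2) / 2 with respect to p is p itself.
            new_params = tuple(p - self.lr * p for p in state.params)
            state = dataclasses.replace(
                state, step=state.step + 1, params=new_params
            )
        return state


init_state = ToyState(step=0, params=(1.0, -2.0))
toy_trainer = ToyTrainer(num_train_steps=3, lr=0.5)
final_state = toy_trainer.train(init_state)
```

Because nothing is stored on the trainer, the trained parameters are only available through the returned state, and the initial state can be reused for another run.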

Checkpointing#

# TODO(epot): Doc on:
# * saving only LoRA params
# * Fuse params
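Until that documentation exists, the general idea of fusing can be sketched in NumPy: the low-rank product is added into the frozen weight so that inference needs a single matrix. This is the generic LoRA identity, not the library's API; all names below are hypothetical and any scaling factor is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8))  # frozen pretrained weight
a = rng.normal(size=(8, 2))  # trained low-rank factor
b = rng.normal(size=(2, 8))  # trained low-rank factor

# Fold the low-rank update into a single weight matrix.
w_fused = w + a @ b

x = rng.normal(size=(3, 8))
unfused_out = x @ w + (x @ a) @ b
fused_out = x @ w_fused
```

Both paths give the same output up to floating-point error, so the fused weight can replace the wrapped layer at inference time.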

Evaluation#

Here, we only perform a qualitative evaluation by sampling a prompt.

For more on evaluation:

  • See the sampling tutorial for details on running inference.

  • To add evals during training, see the Kauldron evaluator documentation.

sampler = gm.text.ChatSampler(
    model=model,
    params=state.params,
    tokenizer=tokenizer,
)

We test a sentence, using the same formatting used during fine-tuning:

sampler.chat('I\'m feeling happy today!')
"Je me sens heureux aujourd'hui !"

The model correctly translated our prompt to French!
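For reference, the turn formatting visible in the decoded training example can be reproduced with a small helper. `format_user_turn` is a hypothetical function written from that decoded text; the sampler applies its own formatting internally, so this is only meant to show what the model sees.

```python
def format_user_turn(prompt: str) -> str:
    """Hypothetical helper: wrap a prompt in the turn markers shown above."""
    return (
        f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


formatted = format_user_turn("I'm feeling happy today!")
print(formatted)
```

The string ends right after the model turn marker, which is where the model starts generating its answer.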