LoRA (Sampling)



An example of using LoRA with Gemma for inference. For an example of fine-tuning with LoRA, see the LoRA finetuning example.

!pip install -q gemma
# Common imports
import os
import jax
import jax.numpy as jnp
import treescope

# Gemma imports
from gemma import gm
from gemma import peft  # Parameter-efficient fine-tuning (PEFT) utilities

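By default, JAX preallocates 75% of the total GPU memory when the first JAX operation runs. This can be overridden by setting XLA_PYTHON_CLIENT_MEM_FRACTION before that first operation. See GPU memory allocation: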

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Initializing the model

To use Gemma with LoRA, wrap any Gemma model in gm.nn.LoRA. The rank argument sets the rank of the low-rank adapter matrices:

model = gm.nn.LoRA(
    rank=4,
    model=gm.nn.Gemma3_4B(text_only=True),
)
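For intuition, here is a minimal NumPy sketch of what a LoRA adapter does to a single dense layer, following the formulation of the original LoRA paper: the frozen weight W is augmented with a trainable low-rank product A @ B, where rank (4 here, as above) is the inner dimension. This is an illustration, not the code of gm.nn.LoRA; in particular, the zero initialization of B and the absence of a scaling factor are assumptions made for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank = 16, 32, 4  # rank plays the role of `rank=4` above

# Frozen, pre-trained weight of a single dense layer.
w = rng.normal(size=(d_in, d_out))

# LoRA factors (sketch assumption): A is small random, B starts at zero,
# so the adapter does not change the layer's output at initialization.
a = rng.normal(scale=0.01, size=(d_in, rank))
b = np.zeros((rank, d_out))


def lora_dense(x, w, a, b):
    """Compute x @ (W + A @ B) without materializing the full A @ B update."""
    return x @ w + (x @ a) @ b


x = rng.normal(size=(2, d_in))  # (batch, d_in)
y = lora_dense(x, w, a, b)
print(y.shape)  # (2, 32)
print(np.allclose(y, x @ w))  # True, because B is all zeros
```

Only A and B (rank * (d_in + d_out) values) would be trained; W stays frozen.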

Initialize the weights:

token_ids = jnp.zeros((1, 256), dtype=jnp.int32)  # Dummy input of shape (batch_size, seq_length)

params = model.init(
    jax.random.key(0),
    token_ids,
)

params = params['params']

Inspect the structure and shapes of the params. The LoRA weights have been added alongside the original ones.

treescope.show(params)
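The added LoRA weights are small compared to the weights they adapt. For one weight matrix of shape (d_in, d_out), the adapter adds rank * (d_in + d_out) parameters, versus d_in * d_out for the full matrix. The dimensions below are hypothetical and chosen only to illustrate the arithmetic; they are not read from the Gemma config.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Return (full_params, lora_params) for one d_in x d_out weight matrix."""
    full = d_in * d_out  # frozen pre-trained weight W
    lora = rank * (d_in + d_out)  # A: (d_in, rank) plus B: (rank, d_out)
    return full, lora


# Hypothetical layer dimensions, for illustration only.
full, lora = lora_param_count(d_in=2560, d_out=10240, rank=4)
print(full, lora, f"{lora / full:.2%}")  # 26214400 51200 0.20%
```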

Restore the pre-trained params. peft.split_params and peft.merge_params let us replace the randomly initialized base-model params with the pre-trained ones while keeping the newly added LoRA weights.

When using gm.ckpts.load_params, make sure to pass the params=original kwarg. This ensures that:

  • The memory from the old params is released (so only a single copy of the weights stays in memory)

  • The restored params reuse the same sharding as the input (here there is no sharding, so this is not strictly required)

# Splits the params into non-LoRA and LoRA weights
original, lora = peft.split_params(params)

# Load the params from the checkpoint
original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT, params=original)

# Merge the pretrained params back with LoRA
params = peft.merge_params(original, lora)
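The split/merge round trip can be pictured with plain nested dicts. This is a hypothetical sketch, not the implementation of peft.split_params or peft.merge_params: it assumes LoRA weights can be recognized by a key containing "lora", a naming rule invented for this illustration. String placeholders stand in for arrays.

```python
from typing import Any


def is_lora_key(key: str) -> bool:
    # Hypothetical rule for this sketch: keys containing "lora" hold adapter weights.
    return "lora" in key.lower()


def split_params(tree: dict[str, Any]) -> tuple[dict, dict]:
    """Split a nested param dict into (original, lora) trees with the same nesting."""
    original, lora = {}, {}
    for key, value in tree.items():
        if is_lora_key(key):
            lora[key] = value  # the whole leaf or subtree belongs to the adapter
        elif isinstance(value, dict):
            sub_original, sub_lora = split_params(value)
            if sub_original:
                original[key] = sub_original
            if sub_lora:
                lora[key] = sub_lora
        else:
            original[key] = value
    return original, lora


def merge_params(original: dict, lora: dict) -> dict:
    """Recursively merge two disjoint nested dicts back into a single tree."""
    merged = dict(original)
    for key, value in lora.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = merge_params(merged[key], value)
        else:
            merged[key] = value
    return merged


params = {
    "layer_0": {
        "attn": {"w": "W0", "lora": {"a": "A0", "b": "B0"}},
        "norm": {"scale": "S0"},
    },
}
original, lora = split_params(params)
# In the real flow, `original` would now be replaced by the checkpoint weights.
print(merge_params(original, lora) == params)  # True
```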

Fine-tuning

See our finetuning guide for more info.

For an end-to-end finetuning example, see our lora.py config.

Inference

Here’s an example of running a single model call:

tokenizer = gm.text.Gemma3Tokenizer()

prompt = tokenizer.encode('The capital of France is')
prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)


# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)


# Show the token distribution
tokenizer.plot_logits(out.logits)
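The logits are unnormalized scores over the vocabulary; tokenizer.plot_logits visualizes them as a distribution. The sketch below shows, on made-up numbers for a hypothetical 5-token vocabulary, how logits become probabilities via a softmax and how a greedy decoder would pick the most likely next token. For the real model output you could apply jax.nn.softmax to out.logits instead of this NumPy helper.

```python
import numpy as np


def next_token_distribution(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


# Hypothetical logits for a toy vocabulary (not real Gemma output).
toy_vocab = ["Paris", "Lyon", "London", "the", "a"]
logits = np.array([4.0, 1.0, 0.5, 2.0, 1.5])

probs = next_token_distribution(logits)
top = int(np.argmax(probs))  # greedy choice: highest-probability token
print(toy_vocab[top], round(float(probs[top]), 3))
```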

To sample an entire sentence:

sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

sampler.chat('The capital of France is?')
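ChatSampler runs the decoding loop for you. Conceptually, sampling a sentence repeats the single-step call above: feed the tokens so far, pick the next token, append it, and stop at an end-of-sequence token or a length limit. The sketch below shows plain greedy decoding with a hypothetical toy model function; it does not reflect ChatSampler's internals and leaves out details such as its sampling strategy, caching, and prompt formatting.

```python
from typing import Callable, Sequence


def greedy_decode(
    next_token_logits: Callable[[Sequence[int]], Sequence[float]],
    prompt: list[int],
    eos_id: int,
    max_new_tokens: int = 16,
) -> list[int]:
    """Append the highest-scoring token until EOS or the length limit is reached."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        logits = next_token_logits(tokens)
        next_id = max(range(len(logits)), key=lambda i: logits[i])
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens[len(prompt):]  # only the newly generated tokens


# Hypothetical toy "model": the next token depends only on the last one.
# Token ids: 0=<eos>, 1=<bos>, 2="France", 3="Paris", 4="."
TABLE = {1: 2, 2: 3, 3: 4, 4: 0}


def toy_logits(tokens: Sequence[int]) -> list[float]:
    scores = [0.0] * 5
    scores[TABLE[tokens[-1]]] = 1.0
    return scores


print(greedy_decode(toy_logits, prompt=[1, 2], eos_id=0))  # [3, 4, 0]
```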