LoRA (Sampling)
This example shows how to use LoRA with Gemma for inference. For an example of fine-tuning with LoRA, see the LoRA finetuning example.
```
!pip install -q gemma
```
```python
# Common imports
import os

import jax
import jax.numpy as jnp
import treescope

# Gemma imports
from gemma import gm
from gemma import peft  # Parameter-efficient fine-tuning (PEFT) module
```
By default, JAX preallocates only a fraction of the GPU memory, but this can be overridden. See GPU memory allocation. The variable must be set before the first JAX computation runs:
```python
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
```
Initializing the model
To use Gemma with LoRA, simply wrap any Gemma model in gm.nn.LoRA:
```python
model = gm.nn.LoRA(
    rank=4,  # Rank of the low-rank LoRA update matrices
    model=gm.nn.Gemma3_4B(text_only=True),  # Base model to wrap
)
```
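For intuition, the sketch below shows the standard low-rank update from the LoRA paper on a single dense layer. The frozen weight `w` is left untouched, and a trainable product `a @ b` of rank `rank` is added to it. It assumes the common initialization (`a` random, `b` zero), so the wrapped layer initially behaves exactly like the original. The exact parameterization and scaling used inside `gm.nn.LoRA` may differ. `lora_dense` is a hypothetical function written in NumPy (`jax.numpy` works the same way here).

```python
import numpy as np


def lora_dense(x, w, a, b, scale=1.0):
    """Dense layer with a LoRA update: x @ (w + scale * a @ b).

    Computed as x @ w + scale * (x @ a) @ b, so the full
    (d_in, d_out) update matrix is never materialized.
    """
    return x @ w + scale * (x @ a) @ b


d_in, d_out, rank = 512, 512, 4
rng = np.random.default_rng(0)

w = rng.normal(size=(d_in, d_out))                  # frozen pre-trained weight
a = rng.normal(size=(d_in, rank)) / np.sqrt(d_in)   # trainable, random init
b = np.zeros((rank, d_out))                         # trainable, zero init

x = rng.normal(size=(2, d_in))                      # batch of 2 inputs

# With b == 0 the LoRA layer reproduces the frozen layer exactly.
y = lora_dense(x, w, a, b)

full_params = w.size              # 512 * 512
lora_params = a.size + b.size     # 512 * 4 + 4 * 512
print(full_params, lora_params)
```

With `rank=4`, a 512×512 matrix gains 4,096 trainable parameters, about 1.6% of its 262,144 frozen ones. This is why LoRA fine-tuning needs far less optimizer state than full fine-tuning.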
Initialize the weights:
```python
# Dummy input of shape (batch_size, seq_length); only the shape matters for init.
token_ids = jnp.zeros((1, 256), dtype=jnp.int32)
params = model.init(
    jax.random.key(0),
    token_ids,
)
params = params['params']
```
Inspect the structure and shapes of the params. The LoRA weights now appear alongside the original ones.
```python
treescope.show(params)
```
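To see how small the LoRA addition is, you can count the parameters in the tree. With JAX, `sum(x.size for x in jax.tree.leaves(params))` gives the total. The sketch below is a hypothetical helper that also separates LoRA leaves from the rest. It assumes LoRA leaves can be recognized by a key containing `'lora'` somewhere on their path, so check the actual names in the treescope output above before relying on that. It runs on a toy nested dict rather than the real `params`.

```python
import numpy as np


def count_params(tree, path=()):
    """Return (total, lora) parameter counts for a nested dict of arrays.

    Hypothetical helper: a leaf counts as a LoRA weight if any key on its
    path contains 'lora' (an assumption about the naming scheme).
    """
    if isinstance(tree, dict):
        total = lora = 0
        for key, subtree in tree.items():
            t, l = count_params(subtree, path + (str(key),))
            total += t
            lora += l
        return total, lora
    size = int(np.size(tree))
    is_lora = any("lora" in key.lower() for key in path)
    return size, (size if is_lora else 0)


# Toy tree standing in for `params`; the real key names will differ.
toy_params = {
    "layer_0": {
        "attn": {
            "w": np.zeros((64, 64)),
            "lora": {"a": np.zeros((64, 4)), "b": np.zeros((4, 64))},
        },
    },
}
n_total, n_lora = count_params(toy_params)
print(f"{n_lora} / {n_total} params are LoRA ({n_lora / n_total:.1%})")
```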
Next, restore the pre-trained params. peft.split_params and peft.merge_params let us replace the randomly initialized non-LoRA params with the pre-trained ones, while keeping the freshly initialized LoRA params.
When using gm.ckpts.load_params, make sure to pass the params=original kwarg. This ensures that:
- The memory of the old params is released, so only a single copy of the weights stays in memory.
- The restored params reuse the same sharding as the input (there is no sharding here, so this point does not apply).
```python
# Splits the params into non-LoRA and LoRA weights
original, lora = peft.split_params(params)

# Load the params from the checkpoint
original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT, params=original)

# Merge the pretrained params back with LoRA
params = peft.merge_params(original, lora)
```
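Conceptually, the split/merge round trip partitions one nested dict of weights into two trees with the same nesting. Only the non-LoRA half is overwritten by the checkpoint, and then the two halves are recombined. The sketch below illustrates that idea on a toy tree, with string placeholders as leaves. `toy_split_params` and `toy_merge_params` are hypothetical and are not how `peft.split_params` and `peft.merge_params` are implemented. The sketch assumes LoRA weights live under a key containing `'lora'`.

```python
def toy_split_params(tree):
    """Split a nested dict into (non-LoRA, LoRA) trees (hypothetical)."""
    original, lora = {}, {}
    for key, value in tree.items():
        if "lora" in str(key).lower():
            lora[key] = value                  # keep the whole LoRA subtree
        elif isinstance(value, dict):
            orig_sub, lora_sub = toy_split_params(value)
            if orig_sub:
                original[key] = orig_sub
            if lora_sub:
                lora[key] = lora_sub
        else:
            original[key] = value
    return original, lora


def toy_merge_params(original, lora):
    """Recursively merge two disjoint nested dicts back into one tree."""
    merged = dict(original)
    for key, value in lora.items():
        if key in merged and isinstance(value, dict):
            merged[key] = toy_merge_params(merged[key], value)
        else:
            merged[key] = value
    return merged


toy_params = {"attn": {"w": "random_w", "lora": {"a": "a0", "b": "b0"}}}
toy_original, toy_lora = toy_split_params(toy_params)
# Pretend these came from the checkpoint: only non-LoRA leaves are replaced.
toy_original = {"attn": {"w": "pretrained_w"}}
toy_params = toy_merge_params(toy_original, toy_lora)
print(toy_params)  # {'attn': {'w': 'pretrained_w', 'lora': {'a': 'a0', 'b': 'b0'}}}
```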
Fine-tuning
See our finetuning guide for more info.
For an end-to-end fine-tuning example, see our lora.py config.
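As a self-contained illustration of what LoRA fine-tuning optimizes, the toy below fits a single linear layer with plain gradient descent. Only the LoRA factors `a` and `b` are updated, and the frozen weight `w` never changes. It is a sketch under simplifying assumptions (NumPy, hand-derived gradients of a mean-squared-error loss, synthetic data whose target differs from `w` by a rank-4 update), not the gemma training loop. In JAX you would typically take `jax.grad` of the loss with respect to the LoRA params only and feed the result to an optimizer.

```python
import numpy as np

rng = np.random.default_rng(0)
d, rank, n = 8, 4, 256

w = rng.normal(size=(d, d))                          # frozen "pre-trained" weight
delta = 0.1 * rng.normal(size=(d, rank)) @ rng.normal(size=(rank, d))
x = rng.normal(size=(n, d))
y = x @ (w + delta)                                  # targets need a rank-4 update

a = rng.normal(size=(d, rank)) / np.sqrt(d)          # trainable LoRA factor
b = np.zeros((rank, d))                              # trainable, zero init
w_before = w.copy()


def loss_and_grads(a, b):
    """MSE of x @ (w + a @ b) against y, with gradients for a and b only."""
    xa = x @ a
    r = x @ w + xa @ b - y                           # residuals, shape (n, d)
    loss = 0.5 * np.mean(np.sum(r**2, axis=1))
    grad_b = xa.T @ r / n
    grad_a = x.T @ (r @ b.T) / n
    return loss, grad_a, grad_b


lr = 0.05
initial_loss, _, _ = loss_and_grads(a, b)
for _ in range(300):
    _, grad_a, grad_b = loss_and_grads(a, b)
    a -= lr * grad_a                                 # only the LoRA factors move;
    b -= lr * grad_b                                 # w is never updated
final_loss, _, _ = loss_and_grads(a, b)
print(f"loss {initial_loss:.4f} -> {final_loss:.4f}")
```

Because `b` starts at zero, the first step only moves `b` (the gradient for `a` is proportional to `b`). This zero initialization is also why the wrapped model initially matches the base model.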
Inference
Here’s an example of running a single model call:
```python
tokenizer = gm.text.Gemma3Tokenizer()

prompt = tokenizer.encode('The capital of France is')
prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)

# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only return the logits for the last position
)

# Show the token distribution
tokenizer.plot_logits(out.logits)
```
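Because `return_last_only=True`, `out.logits` holds unnormalized scores over the vocabulary for the last position only. The sketch below shows how such a vector becomes a probability distribution (a numerically stable softmax) and how to read off the most likely next tokens. It uses a toy 5-token logits vector instead of real model output. Applying the hypothetical `top_k_probs` helper to `out.logits` assumes that array is 1-D for this unbatched prompt. The resulting ids could then be mapped back to text with the tokenizer.

```python
import numpy as np


def top_k_probs(logits, k=3):
    """Softmax over the vocabulary, then return the k most likely token ids."""
    logits = np.asarray(logits, dtype=np.float64)
    z = logits - logits.max()               # subtract max for numerical stability
    probs = np.exp(z) / np.exp(z).sum()
    top_ids = np.argsort(probs)[::-1][:k]   # highest probability first
    return top_ids, probs[top_ids]


# Toy 5-token vocabulary standing in for the last-position logits.
toy_logits = [2.0, 0.5, 4.0, -1.0, 3.0]
ids, p = top_k_probs(toy_logits, k=2)
print(ids, p.round(3))  # [2 4] [0.649 0.239]
```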
To sample an entire sentence:
```python
sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

sampler.chat('The capital of France is?')
```
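`ChatSampler` hides the decoding loop. Conceptually, generating a sentence means calling the model repeatedly and appending one chosen token at a time until an end-of-sequence token or a length limit is reached. The sketch below shows greedy decoding, with a hypothetical toy bigram table standing in for the model. The real sampler conditions on the full context, also takes care of the chat formatting, and may use a different sampling strategy.

```python
import numpy as np

# Hypothetical toy "model": a bigram table of next-token logits.
# Row i holds the logits for the token that follows token i.
VOCAB = ["<bos>", "Paris", "is", "the", "capital", "<eos>"]
BIGRAM_LOGITS = np.full((len(VOCAB), len(VOCAB)), -10.0)
for prev, nxt in [("<bos>", "Paris"), ("Paris", "is"), ("is", "the"),
                  ("the", "capital"), ("capital", "<eos>")]:
    BIGRAM_LOGITS[VOCAB.index(prev), VOCAB.index(nxt)] = 5.0


def greedy_decode(prompt_ids, eos_id, max_new_tokens=10):
    """Append the argmax token one step at a time until EOS or the length cap."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = BIGRAM_LOGITS[ids[-1]]  # a real LM conditions on all of `ids`
        next_id = int(np.argmax(logits))
        if next_id == eos_id:
            break
        ids.append(next_id)
    return ids


out_ids = greedy_decode([VOCAB.index("<bos>")], eos_id=VOCAB.index("<eos>"))
print(" ".join(VOCAB[i] for i in out_ids[1:]))  # Paris is the capital
```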