Quantization Aware Training (QAT)
This example shows how to obtain and run quantized versions of Gemma models. It's best to read the finetuning colab first, since this one builds on it.
!pip install -q gemma
# Common imports
import os
import treescope
import optax
# Gemma imports
from kauldron import kd
from gemma import gm
from gemma import peft
By default, JAX does not use all of the GPU memory, but this can be overridden (see GPU memory allocation in the JAX documentation):
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
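If you prefer to validate the value before setting it, a small helper works too. This is a hypothetical convenience function, not part of JAX or Gemma. According to the JAX documentation, JAX preallocates 75% of the total GPU memory when the first JAX operation runs, so the variable must be set before that first operation; setting it before importing JAX is the safest choice.

```python
import os


def set_jax_gpu_mem_fraction(fraction: float) -> str:
    """Sets the fraction of GPU memory XLA may preallocate (hypothetical helper).

    JAX reads XLA_PYTHON_CLIENT_MEM_FRACTION when it first allocates GPU memory,
    so call this before running any JAX computation.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    value = f"{fraction:.2f}"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = value
    return value


# Same effect as the one-liner above.
set_jax_gpu_mem_fraction(1.0)
```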
Config updates
If you're familiar with the finetuning tutorial, switching to QAT requires only one change to the trainer.
This is slightly different from LoRA (we discuss the difference below).
1. Model
Wrap the model in gm.nn.QuantizationAwareWrapper. This applies model surgery, replacing all linear and compatible layers with layers that simulate quantization. You can choose among several quantization methods:
SFP8: switched floating point in 8 bits (very efficient with gemma.cpp)
Q4_0: per-block integer quantization (equivalent to 4.5 bits per weight), very popular with llama.cpp
INT4: per-channel weight quantization (almost exactly 4 bits per weight)
INT8: 8-bit integer weight quantization (used in this example)
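The bit counts in the list above come from amortizing the quantization scales over the weights they cover. The sketch below redoes that arithmetic, assuming the llama.cpp Q4_0 layout (blocks of 32 four-bit weights sharing one 16-bit scale) and, for INT4, one 32-bit scale per output channel over an illustrative 2560-wide input dimension. The helper name and the INT4 scale format are assumptions, not Gemma internals.

```python
def bits_per_weight(weight_bits: int, scale_bits: int, weights_per_scale: int) -> float:
    """Average storage per weight when one scale is shared by a group of weights."""
    return weight_bits + scale_bits / weights_per_scale


# Q4_0 (llama.cpp layout): 32 four-bit weights share one 16-bit (fp16) scale.
q4_0_bits = bits_per_weight(weight_bits=4, scale_bits=16, weights_per_scale=32)

# INT4 per channel (assumed): one 32-bit scale per output channel, where each
# channel holds 2560 weights (illustrative input dimension).
int4_bits = bits_per_weight(weight_bits=4, scale_bits=32, weights_per_scale=2560)

print(f"Q4_0: {q4_0_bits} bits/weight")       # 4 + 16/32
print(f"INT4: {int4_bits:.4f} bits/weight")   # 4 + 32/2560
```

The wider the group sharing a scale, the closer the cost gets to the raw 4 bits, which is why per-channel INT4 is "almost exactly" 4 bits while per-block Q4_0 pays an extra half bit.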
model = gm.nn.QuantizationAwareWrapper(
    method=peft.QuantizationMethod.INT8,
    model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True),
)
Internally, this uses the gemma.peft mini-library to perform model surgery.
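To make "layers that simulate quantization" concrete, here is a minimal pure-Python sketch of the general fake-quantization idea, assuming symmetric per-channel INT8 with a max-abs scale. The function names are hypothetical and this is not the gemma.peft implementation. Weights are rounded to the integer grid and immediately mapped back to floats, so the forward pass sees the quantization error while the stored parameters stay in full precision.

```python
def fake_quantize_int8(row: list[float]) -> list[float]:
    """Quantize-dequantize one output channel with a symmetric int8 scale."""
    max_abs = max(abs(w) for w in row)
    # One scale per channel so that the largest |w| maps to 127 (assumed scheme).
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest level (Python rounds halves to even), clip to
    # [-127, 127], then map back to float: the result is still a float.
    return [max(-127, min(127, round(w / scale))) * scale for w in row]


def simulated_quant_matvec(x: list[float], weight: list[list[float]]) -> list[float]:
    """y = W_q @ x, where W_q is the fake-quantized weight (rows = output channels)."""
    w_q = [fake_quantize_int8(row) for row in weight]
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in w_q]


weight = [[1.0, -0.5, 0.25], [0.0, 0.0, 0.0]]
print(fake_quantize_int8(weight[0]))  # close to, but not exactly, the original row
print(simulated_quant_matvec([1.0, 1.0, 1.0], weight))
```

During training, QAT recipes commonly treat the rounding as the identity in the backward pass (the straight-through estimator), so gradients update the underlying float weights; we have not verified the exact gradient rule used by gm.nn.QuantizationAwareWrapper.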
Load the pretrained Gemma 3 4B IT checkpoint as the initial parameters:
init_transform = gm.ckpts.LoadCheckpoint(
    path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
)
Training
Data pipeline
As in the finetuning example, we recreate the tokenizer:
tokenizer = gm.text.Gemma3Tokenizer()
tokenizer.encode('This is an example sentence', add_bos=True)
[<_Gemma3SpecialTokens.BOS: 2>, 2094, 563, 614, 2591, 13315]
And the data pipeline:
ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://www.tensorflow.org/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)
ex = ds[0]
treescope.show(ex)
We can decode an example from the batch to inspect the model input and check it is properly formatted:
text = tokenizer.decode(ex['input'][0])
print(text)
<start_of_turn>user
Is this a good place to ask about the ethnicity and intelligence debate?<end_of_turn>
<start_of_turn>model
Est-ce un bon endroit pour poser des questions sur le débat à propos de l'ethnicité et le renseignement ?<end_of_turn>
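Conceptually, a sequence-to-sequence task turns a (prompt, response) pair into next-token-prediction arrays. The sketch below is a simplified, hypothetical stand-in for gm.data.Seq2SeqTask that works on already tokenized ids: it shifts the sequence by one and builds a loss mask so that only response tokens contribute to the loss (our reading of the loss_mask output). The real transform also applies the chat template shown above and may differ in details such as the padding id and array shapes; pad_id=0 is an assumption.

```python
def make_seq2seq_example(prompt_ids, response_ids, max_length, pad_id=0):
    """Builds input/target/loss_mask lists for next-token prediction (hypothetical)."""
    tokens = list(prompt_ids) + list(response_ids)
    # 1 where a token belongs to the response, 0 for the prompt.
    is_response = [0] * len(prompt_ids) + [1] * len(response_ids)
    # Shift by one: the model reads tokens[:-1] and predicts tokens[1:].
    inputs, targets, mask = tokens[:-1], tokens[1:], is_response[1:]
    # Truncate, then right-pad to a fixed length (mirrors max_length/truncate).
    inputs, targets, mask = inputs[:max_length], targets[:max_length], mask[:max_length]
    pad = max_length - len(inputs)
    return {
        "input": inputs + [pad_id] * pad,
        "target": targets + [pad_id] * pad,
        "loss_mask": mask + [0] * pad,  # padding never contributes to the loss
    }


print(make_seq2seq_example([2, 10, 11], [20, 21], max_length=6))
```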
Trainer
We then create the trainer, reusing the model and init_transform defined above:
trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # Where checkpoints and metrics are written
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=init_transform,
    # Training parameters
    num_train_steps=500,
    train_losses={
        "loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=optax.adafactor(learning_rate=0.005),
)
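The loss configured above compares preds.logits with batch.target and ignores positions where batch.loss_mask is 0. The pure-Python sketch below shows a masked softmax cross-entropy, assuming the loss is averaged over unmasked tokens; we have not checked how kd.losses.SoftmaxCrossEntropyWithIntLabels normalizes, and the function name is hypothetical.

```python
import math


def masked_softmax_cross_entropy(logits, labels, mask):
    """Mean cross-entropy over positions where mask is 1 (hypothetical helper).

    logits: one list of vocabulary scores per position.
    labels: the integer target token id per position.
    mask:   1 to include the position in the loss, 0 to ignore it.
    """
    total, count = 0.0, 0
    for row, label, m in zip(logits, labels, mask):
        if not m:
            continue  # prompt and padding positions are skipped entirely
        # log-softmax via log-sum-exp, shifted by the max for numerical stability
        mx = max(row)
        log_z = mx + math.log(sum(math.exp(v - mx) for v in row))
        total += log_z - row[label]
        count += 1
    return total / max(count, 1)


# A confident wrong prediction at a masked position does not affect the loss.
print(masked_softmax_cross_entropy([[0.0, 0.0], [100.0, -100.0]], [0, 1], [1, 0]))
```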
Training can be launched with the .train() method.
Note that, like the model, the trainer is immutable: it stores neither the state nor the parameters. Instead, .train() returns the state containing the trained parameters.
state, aux = trainer.train()
Configuring ...
Initializing ...
Disabling pygrain multi-processing (unsupported in colab).
Starting training loop at step 0
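The immutability note above can be illustrated with a toy, framework-free sketch of the same functional pattern. ToyTrainState and ToyTrainer are hypothetical and are not kauldron classes (kd.train.Trainer.train(), as called above, takes no arguments and returns (state, aux)); the point is only that training leaves the trainer and the input state untouched and hands back a new state.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyTrainState:
    step: int
    params: tuple[float, ...]


@dataclasses.dataclass(frozen=True)
class ToyTrainer:
    learning_rate: float
    num_train_steps: int

    def train(self, init_state: ToyTrainState) -> ToyTrainState:
        state = init_state
        for _ in range(self.num_train_steps):
            # Gradient of the toy loss 0.5 * sum(p**2) is p itself.
            grads = state.params
            new_params = tuple(p - self.learning_rate * g for p, g in zip(state.params, grads))
            # Build a new state instead of mutating the old one.
            state = dataclasses.replace(state, step=state.step + 1, params=new_params)
        return state


trainer_toy = ToyTrainer(learning_rate=0.5, num_train_steps=2)
init = ToyTrainState(step=0, params=(4.0, -2.0))
final = trainer_toy.train(init)
print(final, init)  # init is unchanged; only the returned state holds the result
```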
Inference
To run inference with the trained model, you have two options:
Simply evaluate the QuantizationAwareWrapper model: this does not reduce the memory footprint.
Use gm.nn.IntWrapper, as follows (only available for INT8 quantization):
quantized_model = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True))
quantized_params = peft.quantize(state.params, method=peft.QuantizationMethod.INT8)
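Unlike the QAT wrapper, an integer wrapper can save memory because the parameters are stored as integers rather than floats. The sketch below shows the general idea with symmetric per-channel INT8 quantization (integer codes plus one float scale per output channel), followed by rough memory arithmetic. It is an assumption-based illustration, not necessarily the scheme peft.quantize implements; the function names and the 4-billion parameter count are hypothetical, and the bf16 baseline assumes the unquantized weights would be stored in bfloat16.

```python
def quantize_int8_per_channel(weight: list[list[float]]) -> tuple[list[list[int]], list[float]]:
    """Stores each output channel (row) as int8 codes plus one float scale."""
    codes, scales = [], []
    for row in weight:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127 if max_abs > 0 else 1.0
        # Integer codes in [-127, 127]; only these and the scale are kept.
        codes.append([max(-127, min(127, round(w / scale))) for w in row])
        scales.append(scale)
    return codes, scales


def dequantize_int8(codes: list[list[int]], scales: list[float]) -> list[list[float]]:
    """Recovers approximate float weights: w ~= code * scale."""
    return [[c * s for c in row] for row, s in zip(codes, scales)]


# Rough weight-memory arithmetic for a hypothetical 4-billion-parameter model,
# ignoring the (small) per-channel scales and any layers left unquantized.
num_params = 4_000_000_000
bf16_bytes = num_params * 2  # 16-bit floats
int8_bytes = num_params * 1  # 8-bit integer codes
print(f"bf16: {bf16_bytes / 1e9:.0f} GB, int8: {int8_bytes / 1e9:.0f} GB")
```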
Then create a sampler with the quantized model and parameters, and generate text:
sampler = gm.text.Sampler(
model=quantized_model,
params=quantized_params,
tokenizer=tokenizer,
)
prompt = """\
<start_of_turn>user
I'm feeling happy!<end_of_turn>
<start_of_turn>model
"""
sampler.sample(prompt, max_new_tokens=30)
'Je me sens bien !<end_of_turn>'
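The prompt above follows the Gemma turn format. A small hypothetical helper (not a Gemma API) can build it from a plain user message:

```python
def format_user_turn(message: str) -> str:
    """Wraps a user message in Gemma turn markers, then opens the model turn."""
    return (
        "<start_of_turn>user\n"
        f"{message}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


print(format_user_turn("I'm feeling happy!"))
```

Because the model was finetuned on English-to-French MTNT pairs, it translates the message rather than replying to it, as the sample output shows.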