Quantization Aware Training (QAT)


This example shows how to obtain and run quantized versions of Gemma models. It is best to read the finetuning colab first, since this one builds on it.

!pip install -q gemma
# Common imports
import os
import treescope
import optax

# Gemma imports
from kauldron import kd
from gemma import gm
from gemma import peft

By default, JAX does not use the full GPU memory, but this can be overridden. See GPU memory allocation:

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
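As a minimal sketch of the same setting with a sanity check, the helper below validates the fraction before writing it. The name set_gpu_memory_fraction is hypothetical and not part of JAX or Gemma; the only real interface used is the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable, which should be set before the first JAX operation runs.

```python
import os


def set_gpu_memory_fraction(fraction):
    """Hypothetical helper: store the preallocation fraction for JAX/XLA."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    # JAX reads this variable when its GPU backend is initialized.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"


set_gpu_memory_fraction(1.0)
```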

Config updates#

If you’re familiar with the finetuning tutorial, switching to QAT requires only one change to the trainer.

This is slightly different from LoRA (we discuss the difference below).

1. Model#

Wrap the model in gm.nn.QuantizationAwareWrapper. This applies model surgery to replace all linear and other compatible layers with layers that simulate quantization. You can choose among several quantization methods:

  • SFP8: switched floating point in 8 bits (very efficient with gemma.cpp)

  • Q4_0: per-block integer quantization (equivalent to 4.5 bits per weight), very popular with llama.cpp

  • INT4: per-channel weight quantization (almost exactly 4 bits per weight)

model = gm.nn.QuantizationAwareWrapper(
    method = peft.QuantizationMethod.INT8,
    model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True),
)
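The bits-per-weight figures quoted in the list above follow from simple arithmetic: each group of weights shares one scale, and the cost of that scale is spread over the group. The sketch below assumes a 16-bit scale, a block of 32 weights for Q4_0 (the llama.cpp layout), and a hypothetical channel length of 2560 for the per-channel case; the function name is illustrative only.

```python
def bits_per_weight(weight_bits, group_size, scale_bits=16):
    """Average storage cost per weight when each group shares one scale."""
    return (weight_bits * group_size + scale_bits) / group_size


# Per-block: 32 four-bit weights plus one 16-bit scale.
q4_0_cost = bits_per_weight(weight_bits=4, group_size=32)

# Per-channel: one scale for a whole channel (2560 is an assumed length).
int4_cost = bits_per_weight(weight_bits=4, group_size=2560)

print(q4_0_cost, int4_cost)
```

With these assumptions the per-block scheme costs 4.5 bits per weight, while the per-channel scheme stays within about 0.01 bit of 4.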

Internally, the wrapper uses the gemma.peft mini-library to perform model surgery.

The pretrained checkpoint is loaded through the init_transform:

init_transform = gm.ckpts.LoadCheckpoint(
    path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
)
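To make the idea of simulated quantization concrete, here is a minimal pure-Python sketch of a quantize-then-dequantize pass. It assumes symmetric rounding with one scale per channel (row); the function name is hypothetical, and the scheme actually used by gemma.peft may differ. During QAT the forward pass sees weights like these, so the model learns to tolerate the rounding error.

```python
def fake_quantize_per_channel(weights, bits=8):
    """Quantize and immediately dequantize each row with its own scale."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    simulated = []
    for row in weights:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / qmax if max_abs > 0 else 1.0
        # Round to the nearest integer level, clamp, then map back to float.
        levels = [max(-qmax, min(qmax, round(w / scale))) for w in row]
        simulated.append([level * scale for level in levels])
    return simulated


weights = [[0.10, -0.52, 0.33], [2.0, -1.0, 0.25]]
simulated = fake_quantize_per_channel(weights, bits=8)
max_error = max(
    abs(w - s)
    for w_row, s_row in zip(weights, simulated)
    for w, s in zip(w_row, s_row)
)
print(max_error)
```

The rounding error of each weight is bounded by half the scale of its channel, which is why channels with a large dynamic range lose more precision.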

Training#

Data pipeline#

As in the finetuning example, we recreate the tokenizer:

tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)
[<_Gemma3SpecialTokens.BOS: 2>, 2094, 563, 614, 2591, 13315]

And the data pipeline:

ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://www.tensorflow.org/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]

treescope.show(ex)
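The three output fields can be illustrated with a toy sketch. It assumes the usual next-token setup: the target is the input shifted by one position, and the loss mask is 1 only where the target is a response token. The function name, the token ids, and pad_id=0 are all hypothetical; the real gm.data.Seq2SeqTask may differ in details.

```python
def make_seq2seq_example(prompt_ids, response_ids, max_length, pad_id=0):
    """Toy version of an input/target/loss_mask triple (assumed semantics)."""
    sequence = prompt_ids + response_ids
    inputs = sequence[:-1]   # the model sees everything but the last token
    targets = sequence[1:]   # and predicts the next token at each position
    # Only positions whose target belongs to the response contribute to the loss.
    mask = [0] * (len(prompt_ids) - 1) + [1] * len(response_ids)

    def fit(values):
        values = values[:max_length]  # truncate
        return values + [pad_id] * (max_length - len(values))  # then pad

    return {"input": fit(inputs), "target": fit(targets), "loss_mask": fit(mask)}


example = make_seq2seq_example(
    prompt_ids=[2, 10, 11], response_ids=[20, 21, 1], max_length=8
)
print(example)
```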

We can decode an example from the batch to inspect the model input and check it is properly formatted:

text = tokenizer.decode(ex['input'][0])

print(text)
<start_of_turn>user
Is this a good place to ask about the ethnicity and intelligence debate?<end_of_turn>
<start_of_turn>model
Est-ce un bon endroit pour poser des questions sur le débat à propos de l'ethnicité et le renseignement ?<end_of_turn>
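The turn format visible in the decoded example can be reproduced with a small string helper. This is only a sketch: format_turns is a hypothetical name, and the template is copied from the decoded text above rather than from the library. Leaving out the response yields a prompt that ends at the start of the model turn, which is the shape used for sampling later in this page.

```python
def format_turns(user_text, model_text=None):
    """Build the user/model turn layout seen in the decoded example."""
    prompt = f"<start_of_turn>user\n{user_text}<end_of_turn>\n<start_of_turn>model\n"
    if model_text is None:
        return prompt  # inference: the model continues from here
    return f"{prompt}{model_text}<end_of_turn>"  # training: full exchange


print(format_turns("I'm feeling happy!"))
```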

Trainer#

We then create the trainer, reusing the model and init_transform created above; the optimizer is defined inline:

trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=init_transform,
    # Training parameters
    num_train_steps=500,
    train_losses={
        "loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=optax.adafactor(learning_rate=0.005),
)
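The loss configured above combines a softmax cross-entropy over integer labels with a mask. A minimal pure-Python sketch of that computation follows. It assumes the loss is averaged over unmasked positions only; how Kauldron normalizes internally is not shown in this page, and the function name is hypothetical.

```python
import math


def masked_softmax_cross_entropy(logits, labels, mask):
    """Mean cross-entropy over positions where mask is non-zero."""
    total, count = 0.0, 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue  # prompt or padding position: no contribution
        row_max = max(row)  # subtract the max for numerical stability
        log_z = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        total += log_z - row[label]
        count += 1
    return total / max(count, 1)


loss = masked_softmax_cross_entropy(
    logits=[[0.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]],
    labels=[1, 3],
    mask=[1, 0],
)
print(loss)
```

Because the second position is masked out, the result here depends only on the first row, where uniform logits over four classes give a loss of log(4).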

Training can be launched with the .train() method.

Note that the trainer, like the model, is immutable, so it does not store the state or the params. Instead, the state containing the trained parameters is returned.

state, aux = trainer.train()
Configuring ...
Initializing ...
Disabling pygrain multi-processing (unsupported in colab).
Starting training loop at step 0
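The pattern of an immutable trainer that returns its state can be sketched with a frozen dataclass. ToyTrainer and its fields are hypothetical and unrelated to the Kauldron implementation; the toy objective (w - 3)^2 is chosen only so the result is easy to check.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyTrainer:
    """Immutable configuration; training returns new state instead of storing it."""

    num_train_steps: int
    learning_rate: float

    def train(self, initial_params):
        params = dict(initial_params)  # never modify the caller's dict
        for _ in range(self.num_train_steps):
            grad = 2.0 * (params["w"] - 3.0)  # gradient of (w - 3)^2
            params = {"w": params["w"] - self.learning_rate * grad}
        return params


toy = ToyTrainer(num_train_steps=100, learning_rate=0.1)
initial = {"w": 0.0}
final_state = toy.train(initial)
print(initial, final_state)
```

The trainer object and the initial parameters are unchanged after the call; everything learned lives in the returned value, which is why the result of trainer.train() has to be captured.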

Inference#

To run inference with the model, you have two options:

  1. Simply evaluate the QuantizationAwareWrapper model. This does not reduce the memory footprint.

  2. Use the IntWrapper as follows (only available for INT8 quantization):

quantized_model = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True))
quantized_params = peft.quantize(state.params, method=peft.QuantizationMethod.INT8)

Then evaluate:

sampler = gm.text.Sampler(
    model=quantized_model,
    params=quantized_params,
    tokenizer=tokenizer,
)

prompt = """\
<start_of_turn>user
I'm feeling happy!<end_of_turn>
<start_of_turn>model
"""

sampler.sample(prompt, max_new_tokens=30)
'Je me sens bien !<end_of_turn>'
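To put the memory argument for the second option in perspective, the back-of-the-envelope estimate below compares weight storage at 16 and 8 bits. It assumes a round figure of 4 billion parameters and a 16-bit baseline; both are assumptions rather than measured values, and the estimate ignores scales, activations, and the KV cache.

```python
def weight_memory_gib(num_params, bits_per_weight):
    """Approximate weight storage in GiB, ignoring scales and activations."""
    return num_params * bits_per_weight / 8 / 2**30


NUM_PARAMS = 4_000_000_000  # assumed round figure for a 4B model

baseline_gib = weight_memory_gib(NUM_PARAMS, bits_per_weight=16)
int8_gib = weight_memory_gib(NUM_PARAMS, bits_per_weight=8)
print(f"16-bit: {baseline_gib:.2f} GiB, 8-bit: {int8_gib:.2f} GiB")
```

Under these assumptions, storing the weights as 8-bit integers halves the weight memory, which is the saving the simulated-quantization wrapper alone does not deliver.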