Multi-modal#
Example of how to use Gemma models with multi-modal (image and text) inputs.
!pip install -q gemma
# Common imports
import os
import jax.numpy as jnp
import tensorflow_datasets as tfds
# Gemma imports
from gemma import gm
By default, JAX does not use the full GPU memory, but this can be overridden. See GPU memory allocation:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
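XLA reads this variable when the GPU backend is initialized, so it has to be set before the first JAX computation runs. A minimal sketch of a guard around the same setting; `set_gpu_mem_fraction` is a hypothetical helper introduced here for illustration, not part of the Gemma or JAX API:

```python
import os


def set_gpu_mem_fraction(fraction: float) -> str:
    """Sets the fraction of GPU memory that XLA preallocates for JAX.

    Hypothetical helper: it only validates the value and writes the
    environment variable. Call it before the first JAX computation.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    value = f"{fraction:.2f}"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = value
    return value


# Same setting as above: let JAX preallocate all of the GPU memory.
value = set_gpu_mem_fraction(1.0)
print(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"])
```

If other processes share the GPU, a lower fraction leaves memory for them.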
First, let’s load an image:
ds = tfds.data_source('oxford_flowers102', split='train')
image = ds[0]['image']
image
ndarray (500, 667, 3)
array([[[1, 1, 0],
[1, 1, 0],
[1, 1, 0],
...,
[1, 1, 1],
[1, 1, 1],
[1, 1, 1]],
[[1, 1, 0],
[1, 1, 0],
[1, 1, 0],
...,
[1, 1, 1],
[1, 1, 1],
[1, 1, 1]],
[[1, 1, 0],
[1, 1, 0],
[1, 1, 0],
...,
[1, 1, 1],
[1, 1, 1],
[1, 1, 1]],
...,
[[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
...,
[7, 7, 5],
[6, 6, 4],
[5, 5, 3]],
[[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
...,
[7, 7, 5],
[6, 6, 4],
[5, 5, 3]],
[[1, 1, 1],
[1, 1, 1],
[1, 1, 1],
...,
[7, 7, 5],
[6, 6, 4],
[5, 5, 3]]], shape=(500, 667, 3), dtype=uint8)
Load the model and params.
model = gm.nn.Gemma3_4B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)
Sampling full prompt#
To use the multi-modal capabilities, simply:
In the prompt: Add the <start_of_image> special tokens where the images should be inserted.
Pass the image(s) to the images= argument of the sampler.
sampler = gm.text.ChatSampler(
    model=model,
    params=params,
)
out = sampler.chat(
    'What can you say about this image: <start_of_image>',
    images=image,
)
print(out)
Here's a breakdown of what I can say about the image:
**Overall Impression:**
The image is a stunning, close-up photograph of a water lily in full bloom. It’s dramatically lit, creating a strong contrast between light and shadow, which really emphasizes the flower's form and texture.
**Specific Details:**
* **Flower Type:** It appears to be a Nymphaea (water lily). The shape of the petals and the prominent stamens are characteristic of this type of flower.
* **Color:** The petals are primarily white with a subtle pinkish hue at the base. The stamens are a bright, vibrant yellow.
* **Lighting:** The lighting is key. There's a strong light source coming from the upper left, casting dramatic shadows and highlighting the edges of the petals. This creates a sense of depth and makes the flower appear almost sculptural.
* **Texture:** You can see the delicate texture of the petals – they appear smooth but with subtle ridges and folds.
* **Composition:** The flower is centered in the frame, drawing the viewer's eye directly to it. The dark background isolates the flower and makes it the focal point.
* **Water Droplets:** There are a few water droplets on the petals, adding a touch of freshness and realism.
**Mood/Feeling:**
The image evokes a feeling of tranquility, beauty, and perhaps a touch of mystery due to the dramatic lighting. It feels serene and peaceful.
**Do you want me to focus on a specific aspect of the image, such as:**
* The lighting technique?
* The flower's anatomy?
* The overall mood it creates?<end_of_turn>
Notes:
The model was trained on JPEG images. If you have PNG images, encode them to JPEG and decode them again before passing them to the model, to avoid bias.
You can pass multiple images. Just add <start_of_image> everywhere an image should be inserted. All images should be resized to the same shape. The input shape is then batch, num_images, h, w, c (instead of batch, h, w, c).
If prompts within a batch have a different number of images, pad the tensor with 0 (or any other) values for the unused images.
Calling model directly#
Adding images to the model only requires two steps:
In the prompt: Add <start_of_image> special tokens where the images should be inserted.
tokenizer = gm.text.Gemma3Tokenizer()
prompt = """<start_of_turn>user
Describe this image in a single word.
<start_of_image>
<end_of_turn>
<start_of_turn>model
"""
prompt = jnp.asarray(tokenizer.encode(prompt, add_bos=True))
In the model: Pass the images to model.apply through the images= argument.
# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    images=image,
    return_last_only=True,  # Only predict the last token
)
# Plot the probability distribution
tokenizer.plot_logits(out.logits)
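To inspect the distribution numerically rather than as a plot, the logits can be turned into probabilities with a softmax. The sketch below runs on a small dummy vector; it assumes the last-token logits form a 1-D array over the vocabulary, which should be checked against the actual shape of `out.logits`. `top_k_probs` is a hypothetical helper, not part of the Gemma API.

```python
import numpy as np


def top_k_probs(logits, k: int = 5):
    """Returns the k most likely token ids and their softmax probabilities."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max()  # Subtract the max for numerical stability.
    exp = np.exp(shifted)
    probs = exp / exp.sum()
    top_ids = np.argsort(probs)[::-1][:k]  # Indices sorted by descending probability.
    return top_ids, probs[top_ids]


# Dummy logits standing in for the model output (assumed 1-D over the vocabulary).
dummy_logits = np.array([0.1, 2.0, -1.0, 3.5, 0.0])
ids, probs = top_k_probs(dummy_logits, k=2)
print(ids.tolist())  # [3, 1]
```

The returned ids index into the vocabulary, so mapping them back to text would go through the tokenizer.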
Finetuning#
Finetuning with multi-modal inputs is also simple. Starting from the original finetuning setup, switching to multi-modal only requires 2 changes:
Have a dataset which also returns an image (b h w c), or multiple images (b n h w c).
Specify in the model inputs which field in the batch corresponds to the images:
model = gm.nn.Gemma3_4B(
    tokens='batch.tokens',
    images='batch.image',
)
See the multimodal.py example.
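For orientation, a batch matching the `batch.tokens` and `batch.image` key paths could look like the dictionary below. This is a sketch with random placeholder data: `make_dummy_batch` is a hypothetical helper, the sequence length, vocabulary range and image size are arbitrary, and a real training batch will likely carry further fields (such as targets or a loss mask) that are not shown here.

```python
import numpy as np


def make_dummy_batch(batch_size=2, seq_len=16, num_images=None, image_size=(64, 64)):
    """Builds a batch dict with a `tokens` and an `image` field.

    `image` has shape (b, h, w, c) for a single image per example, or
    (b, n, h, w, c) when `num_images` is given.
    """
    rng = np.random.default_rng(0)
    tokens = rng.integers(0, 1000, size=(batch_size, seq_len), dtype=np.int32)
    if num_images is None:
        image_shape = (batch_size, *image_size, 3)
    else:
        image_shape = (batch_size, num_images, *image_size, 3)
    image = rng.integers(0, 256, size=image_shape, dtype=np.uint8)
    return {"tokens": tokens, "image": image}


single = make_dummy_batch()
multi = make_dummy_batch(num_images=3)
print(single["image"].shape)  # (2, 64, 64, 3)
print(multi["image"].shape)  # (2, 3, 64, 64, 3)
```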