gm.text.ChatSampler

class gemma.gm.text.ChatSampler(*, model: gemma.gm.nn._transformer_like.TransformerLike, params: collections.abc.Mapping[str, typing.Any], multi_turn: bool = False, print_stream: bool | dialog._src.conversation.Stream = False, tokenizer: gemma.gm.text._tokenizer.Tokenizer = None, sampling: gemma.gm.text._sampling.SamplingMethod = <factory>, forbidden_tokens: collections.abc.Sequence[str | int] | None = None, stop_tokens: collections.abc.Sequence[str | int] | None = None, cache_length: int | None = 4096, max_out_length: int = 2048, pad_length: None | int | tuple[int, ...] = (256, 512, 1024), patch_size: int = 16, max_soft_tokens: int = 1120, pooling_kernel_size: int = 3, audio_sample_rate: int = 16000, audio_seq_length: int = 750, last_state: gemma.gm.text._sampler_loop.SamplingState = None, turns: list[gemma.gm.text._template.Turn] = <factory>)

Bases: object

Chat sampler.

A unified chat sampler that works with all Gemma model versions (2, 3, 3n, 4). It automatically selects the correct underlying sampler and prompt format based on the model’s tokenizer version.

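from gemma import gm

# model and params: a loaded Gemma model and its checkpoint, for example:
# model = gm.nn.Gemma3_4B()
# params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)
sampler = gm.text.ChatSampler(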
    model=model,
    params=params,
    multi_turn=True,
)

output0 = sampler.chat('Write a poem about cats.')
output1 = sampler.chat('And about dogs.')
output2 = sampler.chat('Which one do you prefer?')

For Gemma 4 models with multimodal inputs:

# Requires a Gemma 4 model and params; img1 and img2 are numpy arrays or PIL Images.
sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    multi_turn=True,
)

out0 = sampler.chat('Describe this image <|image|>.', images=[img1])
out1 = sampler.chat('What about this one <|image|>?', images=[img2])
out2 = sampler.chat('Summarize your observations.')

This sampler:

  • Is stateful (the KV-cache state is handled automatically).

  • Automatically formats the prompt with turn tags, adds the BOS (beginning-of-sequence) token, and filters the end-of-turn tokens from the output.

  • For Gemma 4 models: supports per-turn images (variable aspect ratio) and audio via the images and audio arguments.
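
The prompt bookkeeping in the list above can be pictured with a toy, string-based version. This is a hypothetical sketch, not the library's implementation: ToyChat, build_prompt and record are invented names, and the turn tags are the Gemma 2/3 ones (<start_of_turn>, <end_of_turn>); other versions may use different tags. The real sampler keeps the KV cache between turns rather than re-encoding the history; the sketch rebuilds the full prompt string only for readability.

```python
from dataclasses import dataclass, field

END_OF_TURN = "<end_of_turn>"


@dataclass
class ToyChat:
    """Hypothetical illustration of multi-turn prompt bookkeeping."""

    multi_turn: bool = False
    turns: list[tuple[str, str]] = field(default_factory=list)

    def build_prompt(self, user_text: str) -> str:
        # Previous turns are only reused as context in multi-turn mode.
        history = self.turns if self.multi_turn else []
        parts = ["<bos>"]  # BOS is added once, at the start of the sequence.
        for role, text in history:
            parts.append(f"<start_of_turn>{role}\n{text}{END_OF_TURN}\n")
        parts.append(f"<start_of_turn>user\n{user_text}{END_OF_TURN}\n")
        parts.append("<start_of_turn>model\n")  # The model answers from here.
        return "".join(parts)

    def record(self, user_text: str, raw_output: str) -> str:
        # Drop the end-of-turn tag (and anything after it) from the output.
        answer = raw_output.split(END_OF_TURN, 1)[0]
        self.turns.append(("user", user_text))
        self.turns.append(("model", answer))
        return answer


chat = ToyChat(multi_turn=True)
print(chat.build_prompt("Write a poem about cats."))
chat.record("Write a poem about cats.", "Soft paws at dawn.<end_of_turn>")
print(chat.build_prompt("And about dogs."))  # Includes the first exchange.
```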

model

Gemma transformer model.

Type:

gemma.gm.nn._transformer_like.TransformerLike

params

Model parameters.

Type:

collections.abc.Mapping[str, Any]

multi_turn

If True, reuse the previous turns as context.

Type:

bool

print_stream

If True, prints the sampled output as it is generated.

Type:

bool | dialog._src.conversation.Stream

tokenizer

Tokenizer.

Type:

gemma.gm.text._tokenizer.Tokenizer

sampling

Sampling method to use. Defaults to greedy sampling.

Type:

gemma.gm.text._sampling.SamplingMethod
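
Conceptually, a sampling method turns the model's next-token logits into a token ID. The sketch below contrasts greedy decoding (the default) with temperature sampling; greedy and sample_with_temperature are hypothetical helpers written with NumPy to illustrate the idea, not the SamplingMethod API.

```python
import numpy as np


def greedy(logits: np.ndarray) -> int:
    # Deterministic: always picks the single most likely token.
    return int(np.argmax(logits))


def sample_with_temperature(logits: np.ndarray, temperature: float,
                            rng: np.random.Generator) -> int:
    # Lower temperatures sharpen the distribution; higher ones flatten it.
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())  # Subtract the max for numerical stability.
    probs /= probs.sum()
    return int(rng.choice(len(logits), p=probs))


logits = np.array([1.0, 3.0, 0.5, 2.0])
print(greedy(logits))  # → 1
rng = np.random.default_rng(0)
print(sample_with_temperature(logits, temperature=0.7, rng=rng))  # Random draw.
```

As the temperature approaches zero, temperature sampling converges to the greedy choice.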

forbidden_tokens

Tokens that the model must not generate. A str entry must map to a single token ID in the vocabulary.

Type:

collections.abc.Sequence[str | int] | None

stop_tokens

Tokens that stop generation when generated. A str entry must map to a single token ID in the vocabulary.

Type:

collections.abc.Sequence[str | int] | None

cache_length

Cache length to use. This is the maximum number of tokens the whole conversation can contain (prompts, answers and images across all turns). Setting this to a fixed value avoids recompilation between turns.

Type:

int | None
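
Because cache_length bounds the whole conversation rather than a single turn, it can help to track how much of the budget each turn consumes. A minimal sketch with made-up token counts; CacheBudget is a hypothetical helper, not part of the library.

```python
from dataclasses import dataclass


@dataclass
class CacheBudget:
    """Hypothetical tracker for tokens used across all turns of a chat."""

    cache_length: int = 4096
    used: int = 0

    @property
    def remaining(self) -> int:
        return self.cache_length - self.used

    def add_turn(self, prompt_tokens: int, output_tokens: int) -> None:
        # Prompts, answers and image tokens all occupy cache slots.
        needed = prompt_tokens + output_tokens
        if needed > self.remaining:
            raise ValueError(f"Turn needs {needed} tokens but only {self.remaining} remain.")
        self.used += needed


budget = CacheBudget(cache_length=4096)
budget.add_turn(prompt_tokens=300, output_tokens=700)  # Made-up text turn.
budget.add_turn(prompt_tokens=280, output_tokens=500)  # Made-up turn with an image.
print(budget.remaining)  # → 2316
```

Keeping cache_length fixed means the cache keeps the same shape on every turn, which is what lets the compiled computation be reused; the trade-off is that the entire conversation must fit in that many tokens.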

max_out_length

Length of the output buffer for a single turn. This is a static value, used to avoid triggering a jit recompilation. It should not be changed unless your task requires the model to generate very long outputs.

Type:

int
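
The interplay of max_out_length, stop_tokens, forbidden_tokens and the per-call max_new_tokens can be illustrated with a toy greedy loop. Everything here is hypothetical: scripted stands in for the model, toy_decode is an invented helper, the assumption that max_new_tokens can only shorten the run is ours, and the real sampler presumably works on a fixed-size, jit-compiled buffer rather than a Python while loop.

```python
import numpy as np


def toy_decode(next_logits, *, max_out_length=2048, max_new_tokens=None,
               stop_ids=(), forbidden_ids=()):
    """Hypothetical greedy loop illustrating when sampling stops."""
    # Assumed: max_new_tokens only shortens the run; the buffer is a hard cap.
    limit = max_out_length if max_new_tokens is None else min(max_new_tokens, max_out_length)
    out = []
    while len(out) < limit:
        logits = np.array(next_logits(out), dtype=np.float64)
        if forbidden_ids:
            logits[list(forbidden_ids)] = -np.inf  # Forbidden tokens can never win.
        token = int(np.argmax(logits))  # Greedy choice.
        if token in stop_ids:
            break  # A stop token ends the turn and is not returned.
        out.append(token)
    return out


def scripted(seq):
    # Hypothetical "model": favors token 5 for three steps, then token 2.
    logits = np.zeros(8)
    logits[2 if len(seq) >= 3 else 5] = 1.0
    return logits


print(toy_decode(scripted, stop_ids={2}))                    # → [5, 5, 5]
print(toy_decode(scripted, stop_ids={2}, max_new_tokens=2))  # → [5, 5]
print(toy_decode(scripted, max_out_length=5))                # → [5, 5, 5, 2, 2]
```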

pad_length

Pad lengths for static shapes (Gemma 4 only).

Type:

None | int | tuple[int, …]

patch_size

Patch size for vision encoder (Gemma 4 only).

Type:

int

max_soft_tokens

Maximum soft tokens per image (Gemma 4 only).

Type:

int

pooling_kernel_size

Pooling kernel size (Gemma 4 only).

Type:

int
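
To get a feel for how patch_size, pooling_kernel_size and max_soft_tokens interact, here is back-of-the-envelope arithmetic under stated assumptions: the image is cut into patch_size × patch_size pixel patches, each pooling_kernel_size × pooling_kernel_size block of patches is pooled into one soft token, and oversized images are shrunk (keeping the aspect ratio) until they fit the budget. These assumptions and both helpers are illustrative, not the Gemma 4 preprocessing code.

```python
import math


def soft_tokens(height: int, width: int, patch_size: int = 16,
                pooling_kernel_size: int = 3) -> int:
    # Whole patches along each axis (leftover border pixels are ignored).
    patches_h, patches_w = height // patch_size, width // patch_size
    # Each k x k block of patches is pooled into a single soft token.
    return (patches_h // pooling_kernel_size) * (patches_w // pooling_kernel_size)


def fit_to_budget(height: int, width: int, max_soft_tokens: int = 1120,
                  patch_size: int = 16, pooling_kernel_size: int = 3) -> tuple[int, int]:
    """Hypothetical: shrinks an image, keeping its aspect ratio, until it fits."""
    n = soft_tokens(height, width, patch_size, pooling_kernel_size)
    if n <= max_soft_tokens:
        return height, width
    # The token count scales with area, hence the square root.
    scale = math.sqrt(max_soft_tokens / n)
    h, w = int(height * scale), int(width * scale)
    while soft_tokens(h, w, patch_size, pooling_kernel_size) > max_soft_tokens:
        h, w = int(h * 0.99), int(w * 0.99)  # Nudge down until under budget.
    return h, w


print(soft_tokens(768, 768))  # → 256 (48 x 48 patches, pooled to 16 x 16)
h, w = fit_to_budget(2048, 2048)
print(h, w, soft_tokens(h, w))  # Stays within max_soft_tokens.
```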

audio_sample_rate

Audio sample rate in Hz (Gemma 4 only).

Type:

int

audio_seq_length

Maximum audio sequence length (Gemma 4 only).

Type:

int
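
Audio is passed as raw waveforms, so the sample rate determines how many samples a clip contains: at the default audio_sample_rate of 16000 Hz, a 2-second clip has 32000 samples. The sketch synthesizes such clips with NumPy. Treating each clip as a mono float32 waveform in [-1, 1] and audio_lengths as per-clip sample counts are assumptions, not documented behavior; sine_clip is a hypothetical helper.

```python
import numpy as np

SAMPLE_RATE = 16000  # Matches the default audio_sample_rate.


def sine_clip(freq_hz: float, seconds: float, sample_rate: int = SAMPLE_RATE) -> np.ndarray:
    # Mono waveform (assumed format: float32 in [-1, 1]).
    t = np.arange(int(seconds * sample_rate)) / sample_rate
    return (0.5 * np.sin(2 * np.pi * freq_hz * t)).astype(np.float32)


clips = [sine_clip(440.0, 2.0), sine_clip(220.0, 0.5)]
lengths = [len(c) for c in clips]
print(lengths)  # → [32000, 8000]

# Assumed usage with a loaded Gemma 4 model (audio_lengths semantics unverified):
# sampler.chat('Transcribe this audio <|audio|>.', audio=[clips[0]], audio_lengths=[lengths[0]])
```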

last_state

Last state of the sampler. It is managed automatically, but exposed so that power users can access the logits, cache, etc., or initialize the sampler.

Type:

gemma.gm.text._sampler_loop.SamplingState

turns

Tracks the conversation turns.

Type:

list[gemma.gm.text._template.Turn]

model: gemma.gm.nn._transformer_like.TransformerLike
params: collections.abc.Mapping[str, Any]
multi_turn: bool = False
print_stream: bool | dialog._src.conversation.Stream = False
tokenizer: gemma.gm.text._tokenizer.Tokenizer = None
sampling: gemma.gm.text._sampling.SamplingMethod
forbidden_tokens: collections.abc.Sequence[str | int] | None = None
stop_tokens: collections.abc.Sequence[str | int] | None = None
cache_length: int | None = 4096
max_out_length: int = 2048
pad_length: None | int | tuple[int, ...] = (256, 512, 1024)
patch_size: int = 16
max_soft_tokens: int = 1120
pooling_kernel_size: int = 3
audio_sample_rate: int = 16000
audio_seq_length: int = 750
last_state: gemma.gm.text._sampler_loop.SamplingState = None
turns: list[gemma.gm.text._template.Turn]
property sampler: gemma.gm.text._sampler.Sampler

Returns the underlying sampler (for backwards compatibility).

property gemma4_sampler: gemma.gm.text._gemma4_sampler.Gemma4Sampler

Returns the underlying Gemma4Sampler (Gemma 4 models only).

chat(
prompt: str | dialog._src.conversation.Conversation,
*,
images: list[numpy.ndarray | PIL.Image.Image] | kauldron.ktyping.array_type_meta.UInt8['N? H W C'] | None = None,
audio: list[numpy.ndarray] | None = None,
audio_lengths: list[int] | None = None,
sampling: gemma.gm.text._sampling.SamplingMethod | None = None,
rng: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
max_new_tokens: int | None = None,
multi_turn: bool | None = None,
print_stream: bool | dialog._src.conversation.Stream | None = None,
is_legacy_tool_answer: bool = False,
sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
) → str

Samples a string from the model.

The API always expects the new Gemma-format tokens (<|image|>, <|audio|>, etc.). The dialog library automatically converts them to the format expected by the underlying model.
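
As an illustration of the kind of rewriting involved (not the dialog library's actual API), a prompt written with the new-style <|image|> placeholder could be mapped to the Gemma 3 <start_of_image> placeholder, with a sanity check that the number of placeholders matches the number of images. to_gemma3_prompt, the one-entry mapping and the count check are hypothetical; the real conversion may handle more tokens and edge cases.

```python
# Hypothetical, partial mapping from new-style tokens to Gemma 3 tokens.
NEW_TO_GEMMA3 = {"<|image|>": "<start_of_image>"}


def to_gemma3_prompt(prompt: str, num_images: int) -> str:
    # Hypothetical sanity check: one placeholder per image.
    count = prompt.count("<|image|>")
    if count != num_images:
        raise ValueError(f"Prompt has {count} image placeholders but {num_images} images were given.")
    for new, old in NEW_TO_GEMMA3.items():
        prompt = prompt.replace(new, old)
    return prompt


print(to_gemma3_prompt("Describe this image <|image|>.", num_images=1))
# → Describe this image <start_of_image>.
```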

Example:

# Text-only (all Gemma versions):
output = sampler.chat('Write a poem about cats.')

# With images (Gemma 4 or Gemma 3):
output = sampler.chat(
    'Describe this image <|image|>.',
    images=[image1],
)

# With audio (Gemma 4):
output = sampler.chat(
    'Transcribe this audio <|audio|>.',
    audio=[audio_array],
)
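
The two image layouts described for the images parameter can be built with NumPy: Gemma 4 takes a list of individual images that may differ in size and aspect ratio, while Gemma 3 takes a single batched uint8 array of shape (N, H, W, C), which requires all images to share one size. The random arrays and their sizes below are placeholders for real pictures.

```python
import numpy as np

rng = np.random.default_rng(0)

# Gemma 4: a list of images, each (H, W, C) uint8; aspect ratios may differ.
img_wide = rng.integers(0, 256, size=(64, 96, 3), dtype=np.uint8)
img_tall = rng.integers(0, 256, size=(96, 64, 3), dtype=np.uint8)
gemma4_images = [img_wide, img_tall]

# Gemma 3: one batched (N, H, W, C) uint8 array; all images share H and W.
img_a = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
img_b = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
gemma3_images = np.stack([img_a, img_b])
print(gemma3_images.shape, gemma3_images.dtype)  # → (2, 64, 64, 3) uint8

# With a loaded Gemma 4 model, for example:
# sampler.chat('Compare <|image|> with <|image|>.', images=gemma4_images)
```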

Parameters:
  • prompt – Prompt to sample from. Can be a single string or a dialog.Conversation object.

  • images – Images for the prompt. For Gemma 4: list of raw numpy arrays or PIL Images (variable aspect ratio). For Gemma 2/3: a batched uint8 array.

  • audio – List of audio arrays (Gemma 4 only).

  • audio_lengths – List of audio lengths (Gemma 4 only).

  • sampling – Sampling method to use. If given, will override the default sampling method.

  • rng – Seed for the sampling method. If None, a random seed is used. Can be an int seed or a jax.random.PRNGKey.

  • max_new_tokens – If given, sampling stops after this many tokens, which is useful for quicker iterations when debugging. By default, sampling continues until the end-of-turn token is found or the max_out_length buffer is full.

  • multi_turn – If True, reuse the previous turns as context. Overrides the multi_turn attribute.

  • print_stream – If True, prints the sampled output as it is generated. Overrides the print_stream attribute.

  • is_legacy_tool_answer – When True, indicates that the model emitted <eos> rather than <|tool_response>, which needs to be corrected. This is an internal argument and should never be set explicitly.

  • sharding – Sharding tree (Gemma 4 only).

Returns:

The sampled output.
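
Several chat() arguments (sampling, multi_turn, print_stream) default to None and, when given, override the attribute of the same name, presumably for that call only. The usual way to express this is a None-coalescing lookup, sketched below with a hypothetical Settings class and resolve helper; it mirrors the documented behavior, not the library's source.

```python
from dataclasses import dataclass


@dataclass
class Settings:
    multi_turn: bool = False
    print_stream: bool = False


def resolve(settings: Settings, **overrides):
    """Hypothetical: per-call values win unless they are None."""
    return {
        name: getattr(settings, name) if value is None else value
        for name, value in overrides.items()
    }


s = Settings(multi_turn=True)
print(resolve(s, multi_turn=None, print_stream=True))
# → {'multi_turn': True, 'print_stream': True}
print(resolve(s, multi_turn=False, print_stream=None))
# → {'multi_turn': False, 'print_stream': False}
print(s.multi_turn)  # → True (the attribute itself is unchanged)
```

Testing "is None" rather than truthiness is what lets an explicit False override a True attribute.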

initialize_stream(
stream: dialog._src.conversation.Stream | bool | None,
) → dialog._src.conversation.Stream | None

Initializes a stream for the sampler.

property conversation: dialog._src.conversation.Conversation

Returns the conversation.