gm.text.Gemma4Sampler
- class gemma.gm.text.Gemma4Sampler(*, model: gemma.gm.nn._transformer_like.TransformerLike, params: collections.abc.Mapping[str, typing.Any], tokenizer: gemma.gm.text._tokenizer.Tokenizer = None, sampling: gemma.gm.text._sampling.SamplingMethod = <factory>, forbidden_tokens: collections.abc.Sequence[str | int] | None = None, stop_tokens: collections.abc.Sequence[str | int] | None = None, cache_length: int = 4096, max_out_length: int = 2048, pad_length: None | int | tuple[int, ...] = (256, 512, 1024), patch_size: int = 16, max_soft_tokens: int = 1120, pooling_kernel_size: int = 3, audio_sample_rate: int = 16000, audio_seq_length: int = 750)
Bases: object
Stateless sampler for Gemma4 with variable-aspect-ratio image support.
This sampler handles preprocessing of variable-size images internally, expanding each <|image|> token to the correct number of placeholder tokens based on each image’s actual dimensions.
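The number of placeholder tokens per image depends on the image's shape and on patch_size, pooling_kernel_size and max_soft_tokens. The sketch below is a hypothetical illustration of how such a count could be derived, not the library's preprocessing. It assumes one soft token per pooled cell of patch_size * pooling_kernel_size pixels on each side, and a resize that preserves the aspect ratio while filling, but not exceeding, the max_soft_tokens budget. The names soft_token_grid and expand_image_placeholders and the "<soft>" string are made up for this example, and the expansion works on strings rather than token ids for readability.

```python
import math

IMAGE_TOKEN = "<|image|>"
SOFT_TOKEN = "<soft>"  # hypothetical placeholder string, not a real Gemma token


def soft_token_grid(height, width, patch_size=16, pooling_kernel_size=3,
                    max_soft_tokens=1120):
    """Hypothetical: (rows, cols) of soft tokens for an image of this size.

    Assumes one soft token per pooled cell covering
    patch_size * pooling_kernel_size pixels on each side.
    """
    if height <= 0 or width <= 0:
        raise ValueError("image dimensions must be positive")
    cell = patch_size * pooling_kernel_size  # pixels per soft token, per side
    # Uniform scale that makes rows * cols == max_soft_tokens before rounding,
    # so the aspect ratio is preserved.
    scale = math.sqrt(max_soft_tokens * cell * cell / (height * width))
    rows = max(1, math.floor(height * scale / cell))
    cols = max(1, math.floor(width * scale / cell))
    # Clamp so that extreme aspect ratios still respect the budget.
    rows = min(rows, max_soft_tokens)
    cols = min(cols, max_soft_tokens // rows)
    return rows, cols


def expand_image_placeholders(prompt, image_sizes, **grid_kwargs):
    """Replace each IMAGE_TOKEN with one SOFT_TOKEN per soft token."""
    parts = prompt.split(IMAGE_TOKEN)
    if len(parts) - 1 != len(image_sizes):
        raise ValueError(
            f"prompt has {len(parts) - 1} {IMAGE_TOKEN} markers but "
            f"{len(image_sizes)} images were given")
    pieces = [parts[0]]
    for (height, width), text in zip(image_sizes, parts[1:]):
        rows, cols = soft_token_grid(height, width, **grid_kwargs)
        pieces.append(SOFT_TOKEN * (rows * cols))
        pieces.append(text)
    return "".join(pieces)


# A square and a 1:2 landscape image get different grids under the same budget.
print(soft_token_grid(768, 768))  # (33, 33)
print(soft_token_grid(480, 960))  # (23, 47)
```

Under these assumptions a wider image trades rows for columns while the total stays close to the budget; the real sampler may use a different rounding rule or add extra image delimiter tokens.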
- model
Gemma transformer model.
- Type:
gemma.gm.nn._transformer_like.TransformerLike
- params
Model parameters.
- Type:
collections.abc.Mapping[str, Any]
- tokenizer
Tokenizer.
- Type:
gemma.gm.text._tokenizer.Tokenizer
- sampling
Sampling method. Defaults to greedy.
- Type:
gemma.gm.text._sampling.SamplingMethod
- forbidden_tokens
List of tokens forbidden from generation.
- Type:
collections.abc.Sequence[str | int] | None
- stop_tokens
List of tokens that stop generation.
- Type:
collections.abc.Sequence[str | int] | None
- cache_length
Maximum cache length.
- Type:
int
- max_out_length
Maximum output buffer length.
- Type:
int
- pad_length
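Pad lengths for static shapes: prompts are padded to a fixed length so that prompts of different lengths can share the same array shapes.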
- Type:
None | int | tuple[int, ...]
- patch_size
Patch size for the vision encoder.
- Type:
int
- max_soft_tokens
Maximum soft tokens per image.
- Type:
int
- pooling_kernel_size
Pooling kernel size.
- Type:
int
- audio_sample_rate
Audio sample rate in Hz.
- Type:
int
- audio_seq_length
Maximum audio sequence length.
- Type:
int
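forbidden_tokens, stop_tokens and the default greedy sampling interact during decoding roughly as follows. This is a conceptual sketch in plain Python, not the library's sampling loop: next_logits is a hypothetical stand-in for a model forward pass, tokens are already integer ids (the sampler also accepts strings, which the tokenizer would have to map to ids), and dropping the stop token from the output is an assumption of this sketch.

```python
import math


def greedy_decode(next_logits, prompt_ids, *, forbidden_ids=(), stop_ids=(),
                  max_new_tokens=16):
    """Toy greedy decoder showing forbidden and stop tokens.

    next_logits: hypothetical callable mapping the token list so far to a
    list of logits, one per vocabulary id.
    """
    tokens = list(prompt_ids)
    generated = []
    forbidden = set(forbidden_ids)
    stops = set(stop_ids)
    for _ in range(max_new_tokens):
        logits = list(next_logits(tokens))
        for token_id in forbidden:
            logits[token_id] = -math.inf  # a forbidden id can never win argmax
        next_id = max(range(len(logits)), key=logits.__getitem__)  # greedy
        if next_id in stops:
            break  # in this sketch the stop token ends decoding and is dropped
        tokens.append(next_id)
        generated.append(next_id)
    return generated


def toy_logits(tokens, vocab_size=5):
    """Prefers last + 1, with last + 2 as the runner-up."""
    logits = [0.0] * vocab_size
    logits[(tokens[-1] + 1) % vocab_size] = 10.0
    logits[(tokens[-1] + 2) % vocab_size] = 5.0
    return logits


print(greedy_decode(toy_logits, [0], max_new_tokens=4))                     # [1, 2, 3, 4]
print(greedy_decode(toy_logits, [0], stop_ids=[3]))                         # [1, 2]
print(greedy_decode(toy_logits, [0], forbidden_ids=[1], max_new_tokens=4))  # [2, 3, 4, 0]
```

Masking with negative infinity before the argmax is the usual way to exclude tokens, and it composes with any sampling method that reads the logits.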
- model: gemma.gm.nn._transformer_like.TransformerLike
- params: collections.abc.Mapping[str, Any]
- tokenizer: gemma.gm.text._tokenizer.Tokenizer = None
- sampling: gemma.gm.text._sampling.SamplingMethod
- forbidden_tokens: collections.abc.Sequence[str | int] | None = None
- stop_tokens: collections.abc.Sequence[str | int] | None = None
- cache_length: int = 4096
- max_out_length: int = 2048
- pad_length: None | int | tuple[int, ...] = (256, 512, 1024)
- patch_size: int = 16
- max_soft_tokens: int = 1120
- pooling_kernel_size: int = 3
- audio_sample_rate: int = 16000
- audio_seq_length: int = 750
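pad_length, cache_length and max_out_length together bound the sizes the sampler works with. The sketch below is one hypothetical reading of these defaults, not the library's code. It assumes the prompt is padded up to the smallest entry of pad_length that fits it (the usual motivation for such bucketing is that jit-compiled code can be reused across prompts of similar length), that a prompt longer than every entry stays unpadded, and that the padded prompt plus max_new_tokens must fit in cache_length. The function names select_pad_length and check_budget are made up.

```python
def select_pad_length(prompt_len, pad_length=(256, 512, 1024)):
    """Hypothetical: length the prompt is padded to before sampling."""
    if pad_length is None:
        return prompt_len  # no padding: the shape follows the prompt
    if isinstance(pad_length, int):
        pad_length = (pad_length,)
    for bucket in sorted(pad_length):
        if prompt_len <= bucket:
            return bucket  # smallest bucket that fits the prompt
    return prompt_len  # assumption: prompts longer than every bucket stay unpadded


def check_budget(prompt_len, max_new_tokens, *, pad_length=(256, 512, 1024),
                 cache_length=4096, max_out_length=2048):
    """Hypothetical: validate a request against the sampler's size limits."""
    if max_new_tokens > max_out_length:
        raise ValueError(
            f"max_new_tokens={max_new_tokens} exceeds max_out_length={max_out_length}")
    padded = select_pad_length(prompt_len, pad_length)
    if padded + max_new_tokens > cache_length:
        raise ValueError(
            f"padded prompt ({padded}) + {max_new_tokens} new tokens "
            f"exceeds cache_length={cache_length}")
    return padded


print(select_pad_length(100))    # 256
print(select_pad_length(300))    # 512
print(check_budget(1000, 2048))  # 1024
```

With three buckets, at most three padded prompt shapes occur for prompts up to 1024 tokens under these assumptions.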
- sample(prompt: str | collections.abc.Sequence[str], *, images: list[numpy.ndarray | PIL.Image.Image] | None = None, audio: list[numpy.ndarray] | None = None, audio_lengths: list[int] | None = None, max_new_tokens: int | None = None, sampling: gemma.gm.text._sampling.SamplingMethod | None = None, rng: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None, return_state: bool = False, last_state: gemma.gm.text._sampler_loop.SamplingState | None = None, sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None)
Samples from the model with variable-aspect-ratio image support.
- Parameters:
prompt – Text prompt(s). Use <|image|> as the image placeholder.
images – List of raw images (numpy arrays or PIL Images) of any size.
audio – List of audio arrays or None.
audio_lengths – List of audio lengths or None.
max_new_tokens – Maximum new tokens to generate.
sampling – Sampling method override.
rng – Random seed or PRNGKey.
return_state – If True, return SamplerOutput with state.
last_state – State returned by a previous call (with return_state=True), used to continue a multi-turn conversation.
sharding – Sharding tree.
- Returns:
The sampled output; when return_state is True, a SamplerOutput that also carries the sampling state.
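sample() takes audio as a list of numpy arrays plus optional audio_lengths, and the sampler is configured with audio_sample_rate=16000. A hypothetical way to prepare such inputs is sketched below. It assumes the model expects mono float32 waveforms at audio_sample_rate and that audio_lengths counts valid samples per clip; neither is confirmed by this reference. Resampling uses plain linear interpolation (numpy.interp), which is crude and not anti-aliased, so a dedicated resampler is preferable for real audio. to_model_audio and prepare_audio are made-up names.

```python
import numpy as np


def to_model_audio(clip, source_rate, target_rate=16000):
    """Hypothetical: convert one clip to mono float32 at target_rate."""
    clip = np.asarray(clip, dtype=np.float32)
    if clip.ndim == 2:
        clip = clip.mean(axis=1)  # downmix (num_samples, channels) to mono
    if source_rate != target_rate:
        n_out = int(round(len(clip) * target_rate / source_rate))
        t_in = np.arange(len(clip)) / source_rate  # input sample times (s)
        t_out = np.arange(n_out) / target_rate     # output sample times (s)
        clip = np.interp(t_out, t_in, clip).astype(np.float32)  # linear resample
    return clip


def prepare_audio(clips_with_rates, target_rate=16000):
    """Hypothetical: build the audio and audio_lengths arguments."""
    audio = [to_model_audio(clip, rate, target_rate) for clip, rate in clips_with_rates]
    audio_lengths = [len(a) for a in audio]  # assumption: lengths in samples
    return audio, audio_lengths


one_second_44k = np.zeros(44100, dtype=np.float32)
half_second_stereo = np.zeros((8000, 2), dtype=np.float32)
audio, audio_lengths = prepare_audio(
    [(one_second_44k, 44100), (half_second_stereo, 16000)])
print(audio_lengths)  # [16000, 8000]
```

Under these assumptions the results could be passed as sample(prompt, audio=audio, audio_lengths=audio_lengths); how audio is referenced from the prompt text is not described in this section.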