gm.text.Gemma4Sampler

class gemma.gm.text.Gemma4Sampler(
*,
model: gemma.gm.nn._transformer_like.TransformerLike,
params: collections.abc.Mapping[str, typing.Any],
tokenizer: gemma.gm.text._tokenizer.Tokenizer = None,
sampling: gemma.gm.text._sampling.SamplingMethod = <factory>,
forbidden_tokens: collections.abc.Sequence[str | int] | None = None,
stop_tokens: collections.abc.Sequence[str | int] | None = None,
cache_length: int = 4096,
max_out_length: int = 2048,
pad_length: None | int | tuple[int, ...] = (256, 512, 1024),
patch_size: int = 16,
max_soft_tokens: int = 1120,
pooling_kernel_size: int = 3,
audio_sample_rate: int = 16000,
audio_seq_length: int = 750,
)

Bases: object

Stateless sampler for Gemma4 with variable-aspect-ratio image support.

This sampler handles preprocessing of variable-size images internally, expanding each <|image|> token to the correct number of placeholder tokens based on each image’s actual dimensions.
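The placeholder expansion described above can be sketched in plain Python. This is an illustration only, not the library's implementation: the token-count rule (patchify, pool, then cap at max_soft_tokens), the `<soft>` stand-in token, and both helper names are assumptions made for this example. A real implementation would more likely resize large images than simply cap the count.

```python
import math

IMAGE_PLACEHOLDER = "<|image|>"  # placeholder named in the description above
SOFT_TOKEN = "<soft>"  # hypothetical stand-in for the real per-patch placeholder


def soft_token_count(height, width, *, patch_size=16, pooling_kernel_size=3,
                     max_soft_tokens=1120):
    """Assumed rule: split into patches, pool k x k patches per token, cap."""
    rows = math.ceil(math.ceil(height / patch_size) / pooling_kernel_size)
    cols = math.ceil(math.ceil(width / patch_size) / pooling_kernel_size)
    return min(rows * cols, max_soft_tokens)


def expand_image_placeholders(prompt, image_sizes, **kwargs):
    """Replaces the i-th <|image|> with a run of tokens sized for image i."""
    parts = prompt.split(IMAGE_PLACEHOLDER)
    if len(parts) - 1 != len(image_sizes):
        raise ValueError("number of <|image|> placeholders must match images")
    out = [parts[0]]
    for (height, width), tail in zip(image_sizes, parts[1:]):
        out.append(SOFT_TOKEN * soft_token_count(height, width, **kwargs))
        out.append(tail)
    return "".join(out)
```

Under these assumptions, images with different aspect ratios produce runs of different lengths, which is why the expansion has to be done per image rather than with one fixed count.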

model

Gemma transformer model.

Type:

gemma.gm.nn._transformer_like.TransformerLike

params

Model parameters.

Type:

collections.abc.Mapping[str, Any]

tokenizer

Tokenizer.

Type:

gemma.gm.text._tokenizer.Tokenizer

sampling

Sampling method. Defaults to greedy decoding.

Type:

gemma.gm.text._sampling.SamplingMethod

forbidden_tokens

Tokens, given as strings or token ids, that must never be generated.

Type:

collections.abc.Sequence[str | int] | None

stop_tokens

Tokens, given as strings or token ids, that end generation when produced.

Type:

collections.abc.Sequence[str | int] | None

cache_length

Maximum cache length.

Type:

int

max_out_length

Maximum output buffer length.

Type:

int

pad_length

Pad lengths for static shapes.

Type:

None | int | tuple[int, ...]

patch_size

Patch size for vision encoder.

Type:

int

max_soft_tokens

Maximum soft tokens per image.

Type:

int

pooling_kernel_size

Pooling kernel size.

Type:

int

audio_sample_rate

Audio sample rate in Hz.

Type:

int

audio_seq_length

Maximum audio sequence length.

Type:

int
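How forbidden_tokens and stop_tokens could interact with greedy decoding is sketched below. It is a toy illustration, not the sampler's actual loop: `step_fn`, `mask_forbidden`, and `greedy_decode` are hypothetical names, token ids are plain ints, and leaving the stop token out of the returned sequence is an assumption.

```python
import math


def mask_forbidden(logits, forbidden_ids):
    """Sets forbidden token logits to -inf so argmax can never select them."""
    forbidden = set(forbidden_ids)
    return [-math.inf if i in forbidden else x for i, x in enumerate(logits)]


def greedy_decode(step_fn, *, forbidden_ids=(), stop_ids=(), max_new_tokens=2048):
    """Greedy loop over a hypothetical step_fn(tokens_so_far) -> logits."""
    stop = set(stop_ids)
    out = []
    for _ in range(max_new_tokens):
        logits = mask_forbidden(step_fn(out), forbidden_ids)
        token = max(range(len(logits)), key=logits.__getitem__)
        if token in stop:
            break  # assumption: the stop token itself is not emitted
        out.append(token)
    return out
```

Masking happens before the argmax, so a forbidden token is skipped in favour of the next-best candidate, while a stop token ends the loop as soon as it wins.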

model: gemma.gm.nn._transformer_like.TransformerLike
params: collections.abc.Mapping[str, Any]
tokenizer: gemma.gm.text._tokenizer.Tokenizer = None
sampling: gemma.gm.text._sampling.SamplingMethod
forbidden_tokens: collections.abc.Sequence[str | int] | None = None
stop_tokens: collections.abc.Sequence[str | int] | None = None
cache_length: int = 4096
max_out_length: int = 2048
pad_length: None | int | tuple[int, ...] = (256, 512, 1024)
patch_size: int = 16
max_soft_tokens: int = 1120
pooling_kernel_size: int = 3
audio_sample_rate: int = 16000
audio_seq_length: int = 750
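The pad_length default of (256, 512, 1024) suggests bucketed padding: a prompt is padded up to one of a few fixed lengths so that compiled shapes can be reused. The sketch below assumes the smallest bucket that fits is chosen and that prompts longer than every bucket are left unpadded. Both rules, and the helper names `pick_pad_length` and `pad_tokens`, are assumptions for illustration.

```python
def pick_pad_length(prompt_len, pad_length=(256, 512, 1024)):
    """Returns the padded length for a prompt of prompt_len tokens."""
    if pad_length is None:
        return prompt_len  # no padding requested
    buckets = (pad_length,) if isinstance(pad_length, int) else sorted(pad_length)
    for bucket in buckets:
        if prompt_len <= bucket:
            return bucket
    return prompt_len  # assumption: longer prompts keep their exact length


def pad_tokens(tokens, pad_length=(256, 512, 1024), pad_id=0):
    """Right-pads a list of token ids with pad_id up to the chosen length."""
    target = pick_pad_length(len(tokens), pad_length)
    return list(tokens) + [pad_id] * (target - len(tokens))
```

With the defaults, a 10-token prompt and a 200-token prompt would both be padded to 256, so they would share one compiled shape under this scheme.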
sample(
prompt: str | collections.abc.Sequence[str],
*,
images: list[numpy.ndarray | PIL.Image.Image] | None = None,
audio: list[numpy.ndarray] | None = None,
audio_lengths: list[int] | None = None,
max_new_tokens: int | None = None,
sampling: gemma.gm.text._sampling.SamplingMethod | None = None,
rng: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
return_state: bool = False,
last_state: gemma.gm.text._sampler_loop.SamplingState | None = None,
sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
) -> str | list[str] | gemma.gm.text._sampler.SamplerOutput

Samples from the model with variable-aspect-ratio image support.

Parameters:
  • prompt – Text prompt(s). Use <|image|> as the placeholder for each image.

  • images – List of raw images (numpy arrays or PIL Images) of any size.

  • audio – List of audio arrays or None.

  • audio_lengths – List of audio lengths or None.

  • max_new_tokens – Maximum new tokens to generate.

  • sampling – Sampling method override.

  • rng – Random seed or PRNGKey.

  • return_state – If True, return a SamplerOutput that includes the sampling state instead of plain text.

  • last_state – State from a previous call, used to continue a multi-turn conversation.

  • sharding – Sharding tree.

Returns:

The sampled output: a str or list[str], or a SamplerOutput when return_state is True.
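A two-turn exchange using return_state and last_state might look like the sketch below. It relies only on keyword arguments listed in the signature above, but the `.text` and `.state` attribute names on SamplerOutput are assumptions, and `two_turn_chat` is a hypothetical helper. The sampler argument is duck-typed, so any object with a compatible sample method works.

```python
def two_turn_chat(sampler, first_prompt, second_prompt, *, images=None,
                  max_new_tokens=128):
    """Runs two turns, feeding the state of the first call into the second."""
    first = sampler.sample(
        first_prompt,
        images=images,
        max_new_tokens=max_new_tokens,
        return_state=True,  # request a SamplerOutput rather than plain text
    )
    second = sampler.sample(
        second_prompt,
        max_new_tokens=max_new_tokens,
        return_state=True,
        last_state=first.state,  # assumed attribute holding the sampling state
    )
    return first.text, second.text  # assumed attribute holding the decoded text
```

Because the class is described as stateless, the conversation history has to travel through the returned state. This is why the first call sets return_state=True even though only its text is shown to the user.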