gm.text.ToolSampler


class gemma.gm.text.ToolSampler(*, model: gemma.gm.nn._transformer_like.TransformerLike, params: collections.abc.Mapping[str, typing.Any], multi_turn: bool = False, print_stream: bool | dialog._src.conversation.Stream = False, tokenizer: gemma.gm.text._tokenizer.Tokenizer = None, sampling: gemma.gm.text._sampling.SamplingMethod = <factory>, forbidden_tokens: collections.abc.Sequence[str | int] | None = None, stop_tokens: collections.abc.Sequence[str | int] | None = None, cache_length: int | None = 4096, max_out_length: int = 2048, pad_length: None | int | tuple[int, ...] = (256, 512, 1024), patch_size: int = 16, max_soft_tokens: int = 1120, pooling_kernel_size: int = 3, audio_sample_rate: int = 16000, audio_seq_length: int = 750, last_state: gemma.gm.text._sampler_loop.SamplingState = None, turns: list[gemma.gm.text._template.Turn] = <factory>, tool_handler: gemma.gm.tools._manager.ToolHandlerBase)

Bases: gemma.gm.text._chat_sampler.ChatSampler

Sampler with tool support.

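Example:

from gemma import gm
import fastmcp

# `model` and `params` are a Gemma model and its loaded checkpoint;
# `server` is an MCP server (e.g. a `fastmcp.FastMCP` instance).
sampler = gm.text.ToolSampler(
    model=model,
    params=params,
    tool_handler=fastmcp.Client(server),
)
sampler.chat('Do you see an issue with my ~/.bashrc ?')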
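Conceptually, a tool-using chat alternates between the model and the tool handler: when the model's reply contains a tool call, the call is executed and its result is fed back as a new turn, and sampling continues until the model answers in plain text. The sketch below illustrates that loop in pure Python. `fake_model`, `run_tool`, `max_tool_calls`, and the `TOOL_CALL:` text format are hypothetical stand-ins, not the gemma API and not Gemma's actual tool-call syntax.

```python
import json


def fake_model(history: list[dict]) -> str:
    """Hypothetical stand-in for the model: asks for a tool once, then answers."""
    last = history[-1]
    if last["role"] == "user":
        # Request a tool call using a made-up text format.
        return 'TOOL_CALL: {"name": "read_file", "args": {"path": "~/.bashrc"}}'
    # After receiving a tool result, produce a final plain-text answer.
    return f"The file contains {len(last['content'])} characters."


def run_tool(name: str, args: dict) -> str:
    """Hypothetical tool dispatcher (the role played by `tool_handler`)."""
    if name == "read_file":
        return "export PATH=$PATH:~/bin\n"  # Fake file content.
    raise KeyError(f"unknown tool: {name}")


def chat_with_tools(prompt: str, max_tool_calls: int = 5) -> str:
    history = [{"role": "user", "content": prompt}]
    for _ in range(max_tool_calls + 1):
        reply = fake_model(history)
        history.append({"role": "model", "content": reply})
        if not reply.startswith("TOOL_CALL:"):
            return reply  # Plain answer: the loop ends here.
        call = json.loads(reply.removeprefix("TOOL_CALL:"))
        result = run_tool(call["name"], call["args"])
        # Feed the tool output back to the model as a new turn.
        history.append({"role": "tool", "content": result})
    raise RuntimeError("too many tool calls")


print(chat_with_tools("Do you see an issue with my ~/.bashrc ?"))
# prints: The file contains 24 characters.
```

Capping the number of tool calls is a design choice in this sketch that guards against a model that never stops requesting tools.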
tool_handler

Allows customizing how the system prompt and tools are handled.

Type:

gemma.gm.tools._manager.ToolHandlerBase
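A tool handler is responsible both for describing the available tools to the model (typically through the system prompt) and for executing the calls the model makes. The sketch below shows one plausible shape for such an object. `SimpleToolHandler`, `system_prompt`, and `call` are hypothetical names chosen for illustration; they are not the actual `ToolHandlerBase` interface.

```python
import inspect
from typing import Callable


class SimpleToolHandler:
    """Hypothetical handler that registers plain Python functions as tools."""

    def __init__(self, tools: list[Callable[..., object]]):
        self.tools = {fn.__name__: fn for fn in tools}

    def system_prompt(self) -> str:
        # Describe each tool by its name, signature and docstring.
        lines = ["You can call the following tools:"]
        for name, fn in self.tools.items():
            doc = inspect.getdoc(fn) or ""
            lines.append(f"- {name}{inspect.signature(fn)}: {doc}")
        return "\n".join(lines)

    def call(self, name: str, args: dict) -> str:
        # Execute the requested tool and return its output as text.
        return str(self.tools[name](**args))


def add(a: int, b: int) -> int:
    """Return the sum of two integers."""
    return a + b


handler = SimpleToolHandler([add])
print(handler.system_prompt())
print(handler.call("add", {"a": 2, "b": 3}))  # prints: 5
```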

chat(
prompt: str | dialog._src.conversation.Conversation,
*,
images: list[numpy.ndarray | PIL.Image.Image] | kauldron.ktyping.array_type_meta.UInt8['N? H W C'] | None = None,
audio: list[numpy.ndarray] | None = None,
audio_lengths: list[int] | None = None,
sampling: gemma.gm.text._sampling.SamplingMethod | None = None,
rng: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
max_new_tokens: int | None = None,
multi_turn: bool | None = None,
print_stream: bool | dialog._src.conversation.Stream | None = None,
is_legacy_tool_answer: bool = False,
sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
) -> str

Chats with the model, with support for tool use.

Parameters:
  • prompt – Prompt to sample from. Can be a single string or a dialog Conversation.

  • images – Images for the prompt. For Gemma 4: list of raw numpy arrays or PIL Images (variable aspect ratio). For Gemma 2/3: a batched uint8 array.

  • audio – List of audio arrays (Gemma 4 only).

  • audio_lengths – List of audio lengths (Gemma 4 only).

  • sampling – Sampling method to use. If given, overrides the default sampling method.

  • rng – Seed to use for the sampling method. If None, a random seed is used. Can be a seed int or a jax.random.PRNGKey object.

  • max_new_tokens – If given, sampling stops after this many tokens. Useful for quicker iterations when debugging. By default, sampling continues until the <end_of_turn> token is found or the max_out_length buffer is filled.

  • multi_turn – If True, reuse the previous turns as context. Overrides the multi_turn attribute.

  • print_stream – If True, print the sampled output as it is generated. Overrides the print_stream attribute.

  • is_legacy_tool_answer – When True, indicates that the model emitted <eos> rather than <|tool_response>, which must be corrected. This is an internal variable and should never be set explicitly.

  • sharding – Sharding tree (Gemma 4 only).
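Several of the parameters above (sampling, multi_turn, print_stream) default to None, meaning "use the value configured on the sampler", while an explicit value overrides the attribute for that call only. A minimal sketch of this resolution pattern, using a hypothetical `SamplerConfig` dataclass rather than the actual gemma classes (sampling methods are represented by plain strings here):

```python
import dataclasses


@dataclasses.dataclass
class SamplerConfig:
    """Hypothetical stand-in for the sampler's dataclass attributes."""

    multi_turn: bool = False
    print_stream: bool = False
    sampling: str = "greedy"

    def resolve(self, *, multi_turn=None, print_stream=None, sampling=None) -> dict:
        # A per-call argument wins when given; None falls back to the attribute.
        return {
            "multi_turn": self.multi_turn if multi_turn is None else multi_turn,
            "print_stream": self.print_stream if print_stream is None else print_stream,
            "sampling": self.sampling if sampling is None else sampling,
        }


cfg = SamplerConfig(multi_turn=True)
print(cfg.resolve())                  # uses the attributes
print(cfg.resolve(multi_turn=False))  # per-call override; cfg itself is unchanged
```

Testing with `is None` rather than `or` matters here: an explicit `False` must still override an attribute that is `True`.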

Returns:

The sampled output.
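Following the max_new_tokens description, the output is sampled until the first <end_of_turn> token, until max_new_tokens tokens have been produced (if given), or until the max_out_length buffer is full. The toy loop below illustrates those three stopping rules. `sample` and `next_token` are hypothetical stand-ins for the real decoding loop and a single decoding step, and tokens are plain strings for readability.

```python
from typing import Callable, Optional

END_OF_TURN = "<end_of_turn>"


def sample(
    next_token: Callable[[list[str]], str],
    max_out_length: int = 2048,
    max_new_tokens: Optional[int] = None,
) -> str:
    """Toy decoding loop showing the three stopping conditions."""
    limit = max_out_length  # Size of the output buffer.
    if max_new_tokens is not None:
        limit = min(limit, max_new_tokens)
    out: list[str] = []
    while len(out) < limit:
        tok = next_token(out)
        if tok == END_OF_TURN:
            break  # The model ended its turn.
        out.append(tok)
    return "".join(out)


REPLY = list("Hello!")


def step(out: list[str]) -> str:
    # Hypothetical "model": spells out a fixed reply, then ends the turn.
    return REPLY[len(out)] if len(out) < len(REPLY) else END_OF_TURN


print(sample(step))                    # prints: Hello!
print(sample(step, max_new_tokens=3))  # prints: Hel
```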