gm.text.Sampler


class gemma.gm.text.Sampler(*, model: gemma.gm.nn._transformer_like.TransformerLike, params: collections.abc.Mapping[str, typing.Any], tokenizer: gemma.gm.text._tokenizer.Tokenizer = None, sampling: gemma.gm.text._sampling.SamplingMethod = <factory>, forbidden_tokens: collections.abc.Sequence[str | int] | None = None, stop_tokens: collections.abc.Sequence[str | int] | None = None, cache_length: int = 4096, max_out_length: int = 2048, pad_length: None | int | tuple[int, ...] = (256, 512, 1024))

Bases: object

Sampler.

This is a low-level API. For most use cases, prefer gm.text.ChatSampler instead.

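from gemma import gm

# `model` and `params` are assumed to be loaded already (e.g. a gm.nn model
# and its checkpoint parameters); `prompt` is a formatted prompt string.
sampler = gm.text.Sampler(
    model=model,
    params=params,
)

output = sampler.sample(prompt)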

This sampler:

  • Is stateless: any state has to be forwarded manually between calls.

  • Requires the user to format the prompt manually using <start_of_turn>,…

  • Adds the BOS (beginning of sequence) token automatically.
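Because the sampler expects an already formatted prompt, the caller has to assemble the turn markers. The helper below is a hypothetical sketch (format_dialog is not part of gm.text), assuming the usual Gemma chat layout in which each turn is written as <start_of_turn>{role}\n{text}<end_of_turn>\n and the prompt ends with an open model turn. It deliberately omits the BOS token, since the sampler adds it.

```python
# Hypothetical helper, not part of gm.text: builds a Gemma-style prompt string.
# Assumes the common Gemma chat layout. The BOS token is NOT added here
# because Sampler adds it automatically.


def format_dialog(turns: list[tuple[str, str]]) -> str:
    """Formats (role, text) pairs and opens a final model turn."""
    parts = []
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"Unexpected role: {role!r}")
        parts.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")
    # Leave the model turn open so the sampler generates the answer.
    parts.append("<start_of_turn>model\n")
    return "".join(parts)


prompt = format_dialog([("user", "Write a haiku about JAX.")])
print(prompt)
```

For multi-turn use, earlier turns (including the model's previous answers) are passed in again, since the sampler itself keeps no history between calls.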

model

Gemma transformer model.

Type:

gemma.gm.nn._transformer_like.TransformerLike

params

Model parameters.

Type:

collections.abc.Mapping[str, Any]

tokenizer

Tokenizer.

Type:

gemma.gm.text._tokenizer.Tokenizer

sampling

Sampling method to use. Defaults to greedy sampling.

Type:

gemma.gm.text._sampling.SamplingMethod
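Greedy sampling always picks the highest-scoring token. The sketch below contrasts it with temperature sampling over a toy list of logits. It is a generic plain-Python illustration of the two techniques, not the implementation of the SamplingMethod classes in gm.text.

```python
import math
import random


def greedy(logits: list[float]) -> int:
    """Returns the index of the largest logit (ties go to the lowest index)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sample_with_temperature(
    logits: list[float], temperature: float, rng: random.Random
) -> int:
    """Samples an index from softmax(logits / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in scaled]
    # random.choices normalizes the weights itself.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [1.0, 3.0, 0.5]
print(greedy(logits))  # 1
print(sample_with_temperature(logits, temperature=1.0, rng=random.Random(0)))
```

Greedy decoding is deterministic for a given prompt, which is why it makes a convenient default; temperature sampling trades that reproducibility for more varied outputs, controlled here by the seeded rng.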

forbidden_tokens

List of tokens that must never be generated. Each str entry must map to a single token id in the vocabulary.

Type:

collections.abc.Sequence[str | int] | None

stop_tokens

List of tokens that stop generation when generated. Each str entry must map to a single token id in the vocabulary.

Type:

collections.abc.Sequence[str | int] | None

cache_length

Cache length to use. This is the maximum number of tokens the conversation can have (prompts, answers, images for all turns). Setting this to a fixed value avoids re-compilation between turns.

Type:

int
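Because cache_length bounds the whole conversation, the prompt and answer tokens of every turn accumulate in the same cache. The arithmetic can be sketched as below. remaining_cache and the per-turn token counts are hypothetical, and the sketch ignores any extra slots the library may reserve internally (for example for images or padding).

```python
def remaining_cache(cache_length: int, turns: list[tuple[int, int]]) -> int:
    """Returns the cache slots left after (prompt_tokens, answer_tokens) turns.

    Raises if the conversation no longer fits in the cache.
    """
    used = 0
    for prompt_tokens, answer_tokens in turns:
        used += prompt_tokens + answer_tokens
        if used > cache_length:
            raise ValueError(f"conversation needs {used} tokens > {cache_length}")
    return cache_length - used


# Two turns in a 4096-token cache (the default cache_length).
print(remaining_cache(4096, [(300, 500), (120, 800)]))  # 2376
```

Here the two turns use 800 + 920 = 1720 slots, leaving 2376 for later turns.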

max_out_length

Length of the output buffer for a single turn. This static value avoids triggering a jit recompilation. It shouldn’t be changed unless your task requires the model to generate really long outputs.

Type:

int

pad_length

If provided, pad the prompt to this length. This keeps the prompt length fixed, to avoid jit re-compilation.

Type:

None | int | tuple[int, ...]

model: gemma.gm.nn._transformer_like.TransformerLike
params: collections.abc.Mapping[str, Any]
tokenizer: gemma.gm.text._tokenizer.Tokenizer = None
sampling: gemma.gm.text._sampling.SamplingMethod
forbidden_tokens: collections.abc.Sequence[str | int] | None = None
stop_tokens: collections.abc.Sequence[str | int] | None = None
cache_length: int = 4096
max_out_length: int = 2048
pad_length: None | int | tuple[int, ...] = (256, 512, 1024)
sample(
prompt: str | dialog._src.conversation.Conversation,
*,
images: kauldron.ktyping.array_type_meta.UInt8['N? H W C'] | None = None,
max_new_tokens: int | None = None,
stream: Literal[False] = False,
sampling: gemma.gm.text._sampling.SamplingMethod = None,
rng: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
return_state: Literal[False] = False,
last_state: gemma.gm.text._sampler_loop.SamplingState | None = None,
sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
) -> str
sample(
prompt: collections.abc.Sequence[str | dialog._src.conversation.Conversation],
*,
images: collections.abc.Sequence[kauldron.ktyping.array_type_meta.UInt8['N H W C']] | None = None,
max_new_tokens: int | None = None,
stream: Literal[False] = False,
sampling: gemma.gm.text._sampling.SamplingMethod = None,
rng: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
return_state: Literal[False] = False,
last_state: gemma.gm.text._sampler_loop.SamplingState | None = None,
sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
) -> list[str]
sample(
prompt: gemma.gm.text._sampler._Prompt,
*,
images: kauldron.ktyping.array_type_meta.UInt8['B? N? H W C'] | None = None,
max_new_tokens: int | None = None,
stream: Literal[False] = False,
sampling: gemma.gm.text._sampling.SamplingMethod = None,
rng: int | collections.abc.Sequence[int] | numpy.ndarray | jaxtyping.UInt32[Array, '2'] | jaxtyping.UInt32[ndarray, '2'] | jax.Array | None = None,
return_state: Literal[True] = False,
last_state: gemma.gm.text._sampler_loop.SamplingState | None = None,
sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
) -> gemma.gm.text._sampler.SamplerOutput

Samples a string from the model.

Example:

prompt = """<start_of_turn>user
I'm hesitating between those two options:

Option 1: <start_of_image>
Option 2: <start_of_image>

Any thoughts?
<end_of_turn>
<start_of_turn>model
"""
# `images`: uint8 array of shape (N, H, W, C), one image per <start_of_image>.
sampler.sample(prompt, images=images)

Parameters:
  • prompt – Prompt(s) to sample from. Can be a single string, a dialog.Conversation, or a list of those.

  • images – Images for the prompt. Each image is inserted at the position of a <start_of_image> token in the prompt.

  • max_new_tokens – Maximum number of new tokens to generate. The transformer will process up to input_length + max_new_tokens tokens.

  • stream – If True, yields tokens as they get predicted.

  • sampling – Sampling method to use. If given, overrides the sampling method provided in __init__ (default: greedy).

  • rng – Seed to use for the sampling method. If None, a random seed is used. Can be a seed int or a jax.random.PRNGKey object.

  • return_state – If True, returns a SamplerOutput object with additional values of the output (logits, cache,…).

  • last_state – When return_state=True, the state can be propagated across calls to the sampler, for multi-turn conversations. Use gm.text.ChatSampler for a simpler API which handles the state for you.

  • sharding – If provided, shard the tokens according to the specified sharding. Users are responsible for ensuring the tokenized prompt is compatible with the sharding. For example, if sharding=kd.sharding.FIRST_DIM, the number of prompts must be divisible by the number of devices.
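Two of the constraints above can be checked before calling sample: each image needs a matching <start_of_image> token, and with a first-dimension sharding such as kd.sharding.FIRST_DIM the number of prompts must be divisible by the number of devices (in practice jax.device_count()). The helpers below are a hypothetical plain-Python sketch of those checks. They assume exactly one image per <start_of_image> placeholder and do not call into gemma or kauldron.

```python
IMAGE_TOKEN = "<start_of_image>"


def check_images(prompt: str, num_images: int) -> None:
    """Assumes one image per <start_of_image> placeholder in the prompt."""
    num_slots = prompt.count(IMAGE_TOKEN)
    if num_slots != num_images:
        raise ValueError(
            f"prompt has {num_slots} {IMAGE_TOKEN} tokens but {num_images} images"
        )


def check_batch_sharding(num_prompts: int, num_devices: int) -> None:
    """A batch sharded on its first dimension must split evenly across devices."""
    if num_prompts % num_devices != 0:
        raise ValueError(
            f"{num_prompts} prompts cannot be split evenly over {num_devices} devices"
        )


prompt = "Option 1: <start_of_image>\nOption 2: <start_of_image>\n"
check_images(prompt, num_images=2)  # passes: 2 placeholders, 2 images
check_batch_sharding(num_prompts=8, num_devices=4)  # passes: 8 % 4 == 0
```

Failing fast on these conditions gives a clearer error than a shape mismatch raised deep inside a jitted function.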

Returns:

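The sampled output: a str for a single prompt, a list[str] for a list of prompts, or a SamplerOutput when return_state=True.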