gm.text.SamplingMethod


class gemma.gm.text.SamplingMethod

Bases: abc.ABC

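Abstract base class for sampling methods, which choose the next token to generate from the model's output logits. Subclasses implement get_next_tokens.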
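The following minimal sketch illustrates the contract. `SamplingMethodSketch` and `GreedySketch` are hypothetical names and are not part of gemma. The stand-in mirrors the abstract `get_next_tokens` method documented below, but it uses plain `jax.Array` annotations instead of the kauldron types so that it runs without `gemma` installed. A concrete subclass only has to override that one method. The greedy variant shown here ignores the random key.

```python
import abc

import jax
import jax.numpy as jnp


class SamplingMethodSketch(abc.ABC):
  """Hypothetical stand-in mirroring the gm.text.SamplingMethod interface."""

  @abc.abstractmethod
  def get_next_tokens(self, logits: jax.Array, rng: jax.Array) -> jax.Array:
    """Maps logits of shape (*B, V) to token IDs of shape (*B,)."""


class GreedySketch(SamplingMethodSketch):
  """Hypothetical greedy sampler: always picks the highest-scoring token."""

  def get_next_tokens(self, logits: jax.Array, rng: jax.Array) -> jax.Array:
    del rng  # Greedy decoding is deterministic, so the key is unused.
    # Reduce over the last (vocabulary) axis, keeping all batch axes.
    return jnp.argmax(logits, axis=-1)


logits = jnp.array([[0.1, 2.0, -1.0],
                    [3.0, 0.0, 0.5]])  # Shape (B=2, V=3).
tokens = GreedySketch().get_next_tokens(logits, jax.random.PRNGKey(0))
print(tokens)  # → [1 0]
```

Because the base class derives from `abc.ABC`, instantiating it directly raises `TypeError`. Only subclasses that implement `get_next_tokens` can be constructed.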

abstractmethod get_next_tokens(
logits: kauldron.ktyping.array_type_meta.Float['*B V'],
rng: kauldron.ktyping.array_type_meta.UInt32['2'] | kauldron.ktyping.array_type_meta.Fry[''] | kauldron.ktyping.array_type_meta.KdPRNGKey,
) → kauldron.ktyping.array_type_meta.Int['*B']

Returns the next token to generate for each sequence in the batch.

Parameters:
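  • logits – Logits, as returned by the model (i.e. before softmax), with shape (*B, V): any number of batch dimensions *B followed by the vocabulary axis V.

  • rng – A random key, e.g. as returned by jax.random.PRNGKey. Deterministic sampling methods may ignore it.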

Returns:

The next token IDs, as an integer array of shape *B (one token per batch element).
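The `rng` argument matters for stochastic methods. The sketch below uses a hypothetical `TemperatureSamplingSketch` class, which is not the gemma API, with the same `get_next_tokens(logits, rng)` signature. It draws one token per batch element from softmax(logits / temperature) using `jax.random.categorical`. That function accepts unnormalized logits, which matches the "before softmax" input described above.

```python
import jax
import jax.numpy as jnp


class TemperatureSamplingSketch:
  """Hypothetical sampler with the same get_next_tokens(logits, rng) signature.

  Samples from softmax(logits / temperature). No explicit softmax is needed
  because jax.random.categorical takes unnormalized logits.
  """

  def __init__(self, temperature: float = 1.0):
    if temperature <= 0:
      raise ValueError("temperature must be positive")
    self.temperature = temperature

  def get_next_tokens(self, logits: jax.Array, rng: jax.Array) -> jax.Array:
    # logits: (*B, V) float array; the vocabulary axis is the last one.
    # Returns: (*B,) integer array with one token ID per batch element.
    return jax.random.categorical(rng, logits / self.temperature, axis=-1)


sampler = TemperatureSamplingSketch(temperature=0.7)
logits = jnp.zeros((4, 8))  # Batch of 4, vocabulary of 8: uniform distribution.
key = jax.random.PRNGKey(42)
tokens = sampler.get_next_tokens(logits, key)
print(tokens.shape, tokens.dtype)
```

Passing the same key twice returns the same tokens. A caller that generates several steps would therefore typically split the key (for example, with `jax.random.split`) before each call.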