gm.text.TopkSampling

class gemma.gm.text.TopkSampling(*, temperature: float = 1.0, k: int = 1)

Bases: gemma.gm.text._sampling.SamplingMethod

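Top-k sampling: the next token is sampled from the k tokens with the highest logits. With the default k = 1, only the single highest-logit token can be selected.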

temperature: float = 1.0
k: int = 1
get_next_tokens(
    logits: kauldron.ktyping.array_type_meta.Float['*B V'],
    rng: kauldron.ktyping.array_type_meta.UInt32['2'] | kauldron.ktyping.array_type_meta.Fry[''] | kauldron.ktyping.array_type_meta.KdPRNGKey,
) → kauldron.ktyping.array_type_meta.Int['*B']

Returns the next tokens to generate.

Parameters:
  • logits – Logits, as returned by the model (i.e. before softmax).

  • rng – A random key.

Returns:

The next tokens to generate.
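
The shape contract above (logits of shape `*B V` in, token ids of shape `*B` out) can be illustrated with a minimal, self-contained JAX sketch of top-k sampling. This is not the gemma implementation: the function name `topk_sample_sketch` is hypothetical, and dividing the kept logits by `temperature` before sampling is an assumption about how the parameter is applied.

```python
import jax
import jax.numpy as jnp


def topk_sample_sketch(logits, rng, *, temperature=1.0, k=1):
  """Hypothetical stand-in for top-k sampling; not the gemma implementation.

  Args:
    logits: Float array of shape (*B, V), unnormalized scores.
    rng: A JAX PRNG key.
    temperature: Divides the kept logits before sampling (assumed behavior).
    k: Number of highest-logit tokens to keep.

  Returns:
    Int array of shape (*B,) with the sampled token ids.
  """
  # Keep the k largest logits along the vocabulary axis, with their token ids.
  topk_logits, topk_ids = jax.lax.top_k(logits, k)  # both (*B, k)
  # Sample a position in [0, k) from softmax(topk_logits / temperature).
  choice = jax.random.categorical(rng, topk_logits / temperature, axis=-1)
  # Map the sampled position back to a vocabulary id.
  return jnp.take_along_axis(topk_ids, choice[..., None], axis=-1)[..., 0]


rng = jax.random.PRNGKey(0)
logits = jnp.array([[0.1, 2.0, -1.0, 0.5],
                    [3.0, 0.0, 0.2, 2.5]])

# k=1 keeps only the best token per row, so the result matches argmax here.
greedy = topk_sample_sketch(logits, rng, k=1)
# k=2 returns one of the two highest-logit tokens per row.
sampled = topk_sample_sketch(logits, rng, temperature=1.0, k=2)
```

In this sketch, raising `k` widens the candidate set, while a lower `temperature` concentrates probability on the highest-logit candidates within that set.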