gm.nn.Output
- class gemma.gm.nn.Output(
    logits: kauldron.ktyping.array_type_meta.Float['*B L V'] | kauldron.ktyping.array_type_meta.Float['*B V'],
    cache: dict[str, dict[str, jax.Array]] | None,
    hidden_states: kauldron.ktyping.array_type_meta.Float['*B L D'] | kauldron.ktyping.array_type_meta.Float['*B D'] | None,
  )
Bases: object
Output of the Gemma model.
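- logits
Predicted logits of the model, with shape '*B L V' (batch dimensions, sequence length, vocabulary size), or '*B V' when the sequence dimension is dropped and only one position's logits are returned.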
- cache
The updated cache if an input cache was passed to the model; None otherwise.
- Type:
dict[str, dict[str, jax.Array]] | None
- hidden_states
The hidden states of the model, with shape '*B L D' or '*B D' (D is the model's hidden dimension), or None.
- logits: kauldron.ktyping.array_type_meta.Float['*B L V'] | kauldron.ktyping.array_type_meta.Float['*B V']
- cache: dict[str, dict[str, jax.Array]] | None
- hidden_states: kauldron.ktyping.array_type_meta.Float['*B L D'] | kauldron.ktyping.array_type_meta.Float['*B D'] | None
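To illustrate how these fields are typically consumed, here is a minimal sketch that picks the greedy next token from logits in either documented layout: for '*B L V' it keeps the last position, and for '*B V' it uses the row as-is. It is an assumption-laden stand-in, not Gemma code: the class OutputSketch and the function greedy_next_token are hypothetical, nested Python lists replace jax.Array so it runs without Gemma or JAX, and it assumes exactly one leading batch dimension.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class OutputSketch:
    """Hypothetical stand-in for gm.nn.Output (not the library class)."""
    logits: list                   # [B][L][V] or [B][V], lists instead of jax.Array
    cache: Optional[dict]          # None when no input cache was given
    hidden_states: Optional[list]  # [B][L][D] or [B][D], or None


def _argmax(row: list) -> int:
    # Index of the largest value in a 1-D list.
    return max(range(len(row)), key=row.__getitem__)


def greedy_next_token(out: OutputSketch) -> list:
    """Return the highest-scoring token id for each batch element (assumes one batch dim)."""
    next_ids = []
    for per_example in out.logits:
        # '*B L V' layout: each example holds one row per position,
        # so keep only the final position. '*B V' layout: already a single row.
        if per_example and isinstance(per_example[0], list):
            per_example = per_example[-1]
        next_ids.append(_argmax(per_example))
    return next_ids


# One example, two positions, vocabulary of three tokens.
out_full = OutputSketch(logits=[[[0.1, 2.0, 0.3], [0.5, 0.2, 1.7]]], cache=None, hidden_states=None)
# Same example with the sequence dimension already dropped.
out_last = OutputSketch(logits=[[0.5, 0.2, 1.7]], cache=None, hidden_states=None)
print(greedy_next_token(out_full), greedy_next_token(out_last))
```

With real jax.Array logits of shape '*B L V', the equivalent one-liner is jnp.argmax(out.logits[..., -1, :], axis=-1); the list-based version above exists only so the sketch is self-contained.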
- replace(**updates)
Returns a new object replacing the specified fields with new values.
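replace follows the usual immutable-dataclass pattern: the original object is left untouched and a modified copy is returned. The sketch below reproduces that behavior with the standard library's dataclasses.replace on a hypothetical frozen stand-in, OutputSketch. It assumes, rather than confirms, that gm.nn.Output.replace behaves the same way, for example when dropping hidden_states before passing an output along.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class OutputSketch:
    """Hypothetical stand-in for gm.nn.Output; fields mirror the documented ones."""
    logits: list
    cache: Optional[dict]
    hidden_states: Optional[list]

    def replace(self, **updates) -> "OutputSketch":
        # Build a new instance with the given fields overridden; `self` is unchanged.
        return dataclasses.replace(self, **updates)


out = OutputSketch(logits=[[0.1, 0.9]], cache=None, hidden_states=[[1.0, 2.0]])
slim = out.replace(hidden_states=None)  # drop hidden states, keep logits and cache

print(slim.hidden_states)  # None
print(out.hidden_states)   # [[1.0, 2.0]]
```

Because the instance is frozen, assigning to a field directly (out.hidden_states = None) raises an error, so a replace-style copy is the intended way to produce a modified output.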