gm.nn.Output


class gemma.gm.nn.Output(
logits: kauldron.ktyping.array_type_meta.Float['*B L V'] | kauldron.ktyping.array_type_meta.Float['*B V'],
cache: dict[str, dict[str, jax.Array]] | None,
hidden_states: kauldron.ktyping.array_type_meta.Float['*B L D'] | kauldron.ktyping.array_type_meta.Float['*B D'] | None,
)

Bases: object

Output of the Gemma model.

logits

Predicted logits of the model.

Type:

kauldron.ktyping.array_type_meta.Float['*B L V'] | kauldron.ktyping.array_type_meta.Float['*B V']
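The two annotated forms of `logits` differ only by the `L` axis. The following is a minimal sketch of turning either form into greedy next-token ids, using NumPy arrays as stand-ins for the model output. It assumes, as the axis names suggest, that the last axis indexes the vocabulary (`V`) and that in the `*B L V` form the second-to-last axis indexes sequence positions (`L`). `next_token_ids` and its `has_seq_axis` flag are hypothetical and not part of the Gemma API; the flag is needed because the variadic `*B` prefix makes the two forms indistinguishable from the shape alone.

```python
import numpy as np


def next_token_ids(logits: np.ndarray, has_seq_axis: bool) -> np.ndarray:
    """Return greedy next-token ids from `logits` (hypothetical helper).

    has_seq_axis=True  -> logits shaped (*B, L, V)
    has_seq_axis=False -> logits shaped (*B, V)
    """
    if has_seq_axis:
        # Keep only the final sequence position: (*B, L, V) -> (*B, V).
        logits = logits[..., -1, :]
    # Pick the highest-scoring vocabulary entry: (*B, V) -> (*B,).
    return np.argmax(logits, axis=-1)


# Dummy data: batch of 2, sequence length 3, vocabulary of 5.
full = np.zeros((2, 3, 5), dtype=np.float32)
full[0, -1, 4] = 1.0  # example 0 prefers token 4 at the last position
full[1, -1, 2] = 1.0  # example 1 prefers token 2 at the last position

print(next_token_ids(full, has_seq_axis=True))          # [4 2]
print(next_token_ids(full[:, -1], has_seq_axis=False))  # [4 2]
```

Both calls agree because slicing the last position out of the `*B L V` array yields exactly the `*B V` form.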

cache

Updated cache if the input cache is not None, None otherwise.

Type:

dict[str, dict[str, jax.Array]] | None
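The cache is a two-level mapping from string keys to arrays, or None when no input cache was given. A minimal sketch for inspecting such a structure, assuming NumPy arrays in place of `jax.Array`: `describe_cache` is a hypothetical helper, and the key names (`"layer_0"`, `"k"`, `"v"`) and array shapes in the usage lines are illustrative placeholders, not the keys Gemma actually uses.

```python
from typing import Optional

import numpy as np

# Same nesting as the annotated type: outer key -> inner key -> array.
Cache = dict[str, dict[str, np.ndarray]]


def describe_cache(cache: Optional[Cache]) -> dict[str, dict[str, tuple[int, ...]]]:
    """Map every (outer key, inner key) pair to the shape of its array."""
    if cache is None:
        # Per the docs, `cache` is None when no input cache was provided.
        return {}
    return {
        outer: {inner: tuple(arr.shape) for inner, arr in entries.items()}
        for outer, entries in cache.items()
    }


# Hypothetical cache with one layer holding two arrays.
cache = {
    "layer_0": {"k": np.zeros((1, 16, 2, 8)), "v": np.zeros((1, 16, 2, 8))},
}
print(describe_cache(cache))  # {'layer_0': {'k': (1, 16, 2, 8), 'v': (1, 16, 2, 8)}}
print(describe_cache(None))   # {}
```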

hidden_states

The hidden states of the model.

Type:

kauldron.ktyping.array_type_meta.Float['*B L D'] | kauldron.ktyping.array_type_meta.Float['*B D'] | None
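The annotations of `logits` and `hidden_states` reuse the dimension names `*B` and `L`, which suggests the two arrays agree on every axis except the last (`V` for logits, `D` for hidden states). The sketch below checks that relationship with NumPy stand-ins. It is a hypothetical helper, not part of the Gemma API, and it assumes both fields come in the same form: either `(*B, L, V)` with `(*B, L, D)`, or `(*B, V)` with `(*B, D)`.

```python
from typing import Optional

import numpy as np


def check_output_shapes(logits: np.ndarray, hidden_states: Optional[np.ndarray]) -> None:
    """Raise ValueError if logits and hidden states disagree on leading axes."""
    if hidden_states is None:
        return  # hidden states were not returned; nothing to compare
    # All axes except the last must match: (*B, L) or (*B,).
    if logits.shape[:-1] != hidden_states.shape[:-1]:
        raise ValueError(
            f"leading axes differ: {logits.shape[:-1]} vs {hidden_states.shape[:-1]}"
        )


check_output_shapes(np.zeros((2, 3, 10)), np.zeros((2, 3, 4)))  # OK: both share (2, 3)
check_output_shapes(np.zeros((2, 10)), None)                    # OK: nothing to compare
try:
    check_output_shapes(np.zeros((2, 3, 10)), np.zeros((2, 4, 4)))
except ValueError as err:
    print(err)  # leading axes differ: (2, 3) vs (2, 4)
```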

logits: kauldron.ktyping.array_type_meta.Float['*B L V'] | kauldron.ktyping.array_type_meta.Float['*B V']
cache: dict[str, dict[str, jax.Array]] | None
hidden_states: kauldron.ktyping.array_type_meta.Float['*B L D'] | kauldron.ktyping.array_type_meta.Float['*B D'] | None
replace(**updates)

Returns a new object replacing the specified fields with new values.
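`replace` follows the usual immutable-dataclass pattern: it builds a new object and leaves the original untouched. A minimal sketch of that pattern, assuming `Output.replace` behaves like the standard library's `dataclasses.replace`. `OutputStandIn` is a hypothetical frozen dataclass that only mirrors the three fields documented above, with NumPy arrays in place of JAX arrays; it is not the real `gemma.gm.nn.Output`.

```python
import dataclasses
from typing import Any, Optional

import numpy as np


@dataclasses.dataclass(frozen=True)
class OutputStandIn:
    """Hypothetical stand-in mirroring the fields of gemma.gm.nn.Output."""

    logits: np.ndarray
    cache: Optional[dict[str, dict[str, np.ndarray]]]
    hidden_states: Optional[np.ndarray]

    def replace(self, **updates: Any) -> "OutputStandIn":
        # Return a new instance; fields not named in `updates` are carried over.
        return dataclasses.replace(self, **updates)


out = OutputStandIn(
    logits=np.ones((1, 4, 8)), cache=None, hidden_states=np.zeros((1, 4, 16))
)
# Drop the hidden states without modifying `out`.
slim = out.replace(hidden_states=None)
print(slim.hidden_states is None, out.hidden_states.shape)  # True (1, 4, 16)
print(slim.logits is out.logits)  # True: unchanged fields are shared, not copied
```

Because the instance is frozen, assigning to a field directly raises `dataclasses.FrozenInstanceError`, so `replace` is the way to derive a modified copy.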