gm.nn.TransformerLike

class gemma.gm.nn.TransformerLike(*args, **kwargs)

Bases: Protocol

Protocol for a transformer model to be used with a Sampler.

A model passed to a Sampler must implement apply and init_cache.
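
Because TransformerLike is a Protocol, conformance is structural: a class does not need to inherit from it, only to provide the required members. The sketch below illustrates the idea with a local stand-in protocol (SamplerModelSketch, a hypothetical name that is not part of gemma) and a toy model whose methods return placeholder values. The real protocol also declares config, INFO, and init, and works with arrays rather than plain Python values.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SamplerModelSketch(Protocol):
    """Hypothetical stand-in that mirrors only the two abstract methods."""

    def apply(self, variables: Any, tokens: Any, *, images: Any = None,
              cache: Any = None, positions: Any = None,
              attention_mask: Any = None) -> Any: ...

    def init_cache(self, *, batch_size: int, dtype: Any, cache_length: int,
                   sharding: Any = None) -> dict[str, dict[str, Any]]: ...


class ToyModel:
    """Toy class: no inheritance, but it has both required methods."""

    def apply(self, variables, tokens, *, images=None, cache=None,
              positions=None, attention_mask=None):
        # Placeholder output: echo the tokens and pass the cache through.
        return {"tokens": tokens, "cache": cache}

    def init_cache(self, *, batch_size, dtype, cache_length, sharding=None):
        # Placeholder cache with the documented nesting: dict[str, dict[str, ...]].
        return {"layer_0": {"shape": (batch_size, cache_length), "dtype": dtype}}


class NotAModel:
    """Lacks init_cache, so it does not satisfy the sketch protocol."""

    def apply(self, variables, tokens, **kwargs):
        return tokens


toy_ok = isinstance(ToyModel(), SamplerModelSketch)
other_ok = isinstance(NotAModel(), SamplerModelSketch)
```

A runtime isinstance check of this kind only verifies that the methods exist, not that their signatures match; a static type checker is needed for the latter.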

config: gemma.gm.nn._transformer_like.TransformerConfig
INFO: ClassVar[gemma.gm.nn._transformer_like.ModelInfo]
init(
rngs: jax.Array | dict[str, jax.Array],
*args,
method: collections.abc.Callable[[...], Any] | str | None = None,
mutable: bool | str | Collection[str] | flax.core.scope.DenyList = DenyList(deny='intermediates'),
capture_intermediates: bool | collections.abc.Callable[[flax.linen.module.Module, str], bool] = False,
**kwargs,
) → flax.core.frozen_dict.FrozenDict[str, collections.abc.Mapping[str, Any]] | dict[str, Any]

Initializes a module method with variables and returns modified variables.

init takes as its first argument either a single PRNGKey or a dictionary mapping variable collection names to their PRNGKeys. It calls method (the module’s __call__ function by default) with *args and **kwargs, and returns a dictionary of initialized variables.

Parameters:
  • rngs – The PRNGKey or dictionary of PRNGKeys.

  • *args – Positional arguments to pass to the method.

  • method – The module method to initialize. Defaults to __call__.

  • mutable – A filter for which variable collections are mutable.

  • capture_intermediates – Whether to capture intermediate values.

  • **kwargs – Keyword arguments to pass to the method.
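
The two accepted forms of rngs can be pictured as follows. To my understanding, Flax Linen treats a single key as the key for the "params" collection. The helper below (normalize_rngs, a hypothetical name used only for illustration) mimics that normalization, with plain strings standing in for PRNG keys so that the sketch runs without JAX.

```python
from typing import Any


def normalize_rngs(rngs: Any) -> dict[str, Any]:
    """Hypothetical helper: returns a {collection_name: key} mapping."""
    if isinstance(rngs, dict):
        # Already a mapping from collection names to keys: copy it unchanged.
        return dict(rngs)
    # A single key is assumed to seed the "params" collection.
    return {"params": rngs}


# Stand-ins for PRNG keys; real code would pass jax.random key arrays.
single_key = "key-0"
per_collection = {"params": "key-1", "dropout": "key-2"}

from_single = normalize_rngs(single_key)
from_mapping = normalize_rngs(per_collection)
```

Passing a dictionary is only needed when the method being initialized draws randomness from more than one collection.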

abstractmethod apply(
variables: collections.abc.Mapping[str, collections.abc.Mapping[str, Any]],
tokens: kauldron.ktyping.array_type_meta.Int['*B L'],
*,
images: kauldron.ktyping.array_type_meta.UInt8['*B N H W C'] | kauldron.ktyping.array_type_meta.UInt8['*B H W C'] | None = None,
cache: dict[str, dict[str, jax.Array]] | None = None,
positions: kauldron.ktyping.array_type_meta.Int['*B L_with_mm'] | None = None,
attention_mask: kauldron.ktyping.array_type_meta.Bool['*B L_with_mm cache_length'] | None = None,
) → Any | tuple[Any, flax.core.frozen_dict.FrozenDict[str, collections.abc.Mapping[str, Any]] | dict[str, Any]]

Applies a module method to variables and returns output and modified variables.

abstractmethod init_cache(
*,
batch_size: int,
dtype: numpy.dtype[Any],
cache_length: int,
sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
) → dict[str, dict[str, jax.Array]]

Initializes the KV cache for efficient generation.
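
The return type is a two-level dictionary of arrays. As a rough illustration of how batch_size and cache_length could determine that structure, the sketch below builds a plan of array shapes rather than real arrays. The layer names, the "k", "v", and "end_index" entries, and the num_layers, num_kv_heads, and head_dim parameters are all assumptions made for this example; the actual layout is defined by each model's implementation.

```python
def plan_cache_shapes(*, batch_size: int, cache_length: int, num_layers: int,
                      num_kv_heads: int, head_dim: int) -> dict[str, dict[str, tuple[int, ...]]]:
    """Hypothetical helper: maps layer name -> entry name -> array shape."""
    kv_shape = (batch_size, cache_length, num_kv_heads, head_dim)
    return {
        f"layer_{i}": {
            "k": kv_shape,               # cached keys (assumed layout)
            "v": kv_shape,               # cached values (assumed layout)
            "end_index": (batch_size,),  # per-sequence write position (assumed)
        }
        for i in range(num_layers)
    }


plan = plan_cache_shapes(batch_size=2, cache_length=16, num_layers=3,
                         num_kv_heads=4, head_dim=8)

# Number of key and value elements preallocated across all layers.
kv_elements = sum(
    2 * b * t * h * d
    for layer in plan.values()
    for (b, t, h, d) in [layer["k"]]
)
```

Under these assumptions the cache size grows linearly with both batch_size and cache_length, which is worth keeping in mind when choosing cache_length for long generations.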