gm.nn.TransformerLike
- class gemma.gm.nn.TransformerLike(*args, **kwargs)
Bases: Protocol
Protocol for a transformer model to be used with a Sampler.
A model passed to a Sampler must implement apply and init_cache.
- config: gemma.gm.nn._transformer_like.TransformerConfig
- INFO: ClassVar[gemma.gm.nn._transformer_like.ModelInfo]
- init(
- rngs: jax.Array | dict[str, jax.Array],
- *args,
- method: collections.abc.Callable[[...], Any] | str | None = None,
- mutable: bool | str | Collection[str] | flax.core.scope.DenyList = DenyList(deny='intermediates'),
- capture_intermediates: bool | collections.abc.Callable[[flax.linen.module.Module, str], bool] = False,
- **kwargs,
Initializes a module method with variables and returns modified variables.
init takes as its first argument either a single PRNGKey or a dictionary mapping variable collection names to their PRNGKeys. It calls method (the module's __call__ function by default), passing *args and **kwargs, and returns a dictionary of initialized variables.
- Parameters:
rngs – The PRNGKey or dictionary of PRNGKeys.
*args – Positional arguments to pass to the method.
method – The module method to initialize. Defaults to __call__.
mutable – A filter for which variable collections are mutable.
capture_intermediates – Whether to capture intermediate values.
**kwargs – Keyword arguments to pass to the method.
- abstractmethod apply(
- variables: collections.abc.Mapping[str, collections.abc.Mapping[str, Any]],
- tokens: kauldron.ktyping.array_type_meta.Int['*B L'],
- *,
- images: kauldron.ktyping.array_type_meta.UInt8['*B N H W C'] | kauldron.ktyping.array_type_meta.UInt8['*B H W C'] | None = None,
- cache: dict[str, dict[str, jax.Array]] | None = None,
- positions: kauldron.ktyping.array_type_meta.Int['*B L_with_mm'] | None = None,
- attention_mask: kauldron.ktyping.array_type_meta.Bool['*B L_with_mm cache_length'] | None = None,
Applies a module method to variables and returns output and modified variables.
- abstractmethod init_cache(
- *,
- batch_size: int,
- dtype: numpy.dtype[Any],
- cache_length: int,
- sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
Initializes the KV cache for efficient generation.
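Because TransformerLike is a Protocol, a model conforms structurally by providing the required methods rather than by inheriting from a base class. The sketch below illustrates that idea with hypothetical names: SupportsSampling is a simplified stand-in for the real protocol, ToyModel is a toy class, and the cache layout (per-layer "k" and "v" buffers) and the return value of apply are assumptions, not the layout or contract gemma uses.

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np

# Nested mapping matching the `cache` annotation: layer name -> buffer name -> array.
Cache = dict[str, dict[str, np.ndarray]]


@runtime_checkable
class SupportsSampling(Protocol):
    """Hypothetical, simplified stand-in for the two abstract methods."""

    def apply(self, variables, tokens, *, cache=None): ...

    def init_cache(self, *, batch_size: int, dtype: Any, cache_length: int) -> Cache: ...


class ToyModel:
    """Toy model; sizes and cache layout are assumptions for illustration."""

    num_layers = 2
    num_heads = 1
    head_dim = 4
    vocab_size = 8

    def init_cache(self, *, batch_size, dtype, cache_length):
        # One zero-filled key buffer and value buffer per layer.
        shape = (batch_size, cache_length, self.num_heads, self.head_dim)
        return {
            f"layer_{i}": {"k": np.zeros(shape, dtype), "v": np.zeros(shape, dtype)}
            for i in range(self.num_layers)
        }

    def apply(self, variables, tokens, *, cache=None):
        # Placeholder forward pass: all-zero logits, cache returned unchanged.
        logits = np.zeros(tokens.shape + (self.vocab_size,), dtype=np.float32)
        return logits, cache


model = ToyModel()
cache = model.init_cache(batch_size=2, dtype=np.dtype("float32"), cache_length=16)
logits, cache = model.apply({}, np.zeros((2, 5), dtype=np.int32), cache=cache)
conforms = isinstance(model, SupportsSampling)  # checks method presence only
```

Note that a runtime isinstance check against a Protocol verifies only that the methods exist, not that their signatures or array shapes match.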