gm.nn.Transformer#
- class gemma.gm.nn.Transformer(
- *,
- return_last_only: bool | None = None,
- dtype: numpy.dtype = <class 'jax.numpy.bfloat16'>,
- tokens: Annotated[typing.Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] = '__KEY_REQUIRED__',
- images: Annotated[typing.Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None,
- positions: Annotated[typing.Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None,
- attention_mask: Annotated[typing.Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None,
- config: gemma.gm.nn._config.TransformerConfig,
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
- name: str | None = None,
- )
Bases: flax.linen.module.Module
Base transformer class.
- return_last_only
If True, only compute and return the logits for the last token. Otherwise, return all logits. Defaults to False.
- Type:
bool | None
- dtype
The parameter dtype. Defaults to jnp.bfloat16.
- Type:
numpy.dtype
- return_last_only: bool | None = None
- dtype: alias of jax.numpy.bfloat16
- tokens: Annotated[Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] = '__KEY_REQUIRED__'
- images: Annotated[Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None
- positions: Annotated[Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None
- attention_mask: Annotated[Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None
- config: gemma.gm.nn._config.TransformerConfig
- INFO: ClassVar[gemma.gm.nn._transformer.ModelInfo] = ModelInfo(tokenizer_version=None, default_ckpt=None)
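The effect of return_last_only on the output shape can be sketched in plain JAX. This is an illustration only, not the library's implementation: it assumes logits of shape [batch, length, vocab], right-padded inputs, and a boolean input mask, and the helper name select_last_logits is hypothetical.

```python
import jax.numpy as jnp


def select_last_logits(logits, input_mask):
    """Keep only the logits of the last non-padding token of each sequence.

    Hypothetical helper (not part of gemma). Assumes:
      logits:     [batch, length, vocab]
      input_mask: [batch, length], True for real tokens, right-padded.
    """
    # Index of the last real token in each row.
    last_idx = jnp.sum(input_mask, axis=-1) - 1
    batch_idx = jnp.arange(logits.shape[0])
    # Advanced indexing picks one position per row: result is [batch, vocab].
    return logits[batch_idx, last_idx]


# Toy input: batch of 2, length 3, vocab size 4.
logits = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)
input_mask = jnp.array([[True, True, False],
                        [True, True, True]])

all_logits = logits                                       # like return_last_only=False
last_logits = select_last_logits(logits, input_mask)     # like return_last_only=True
```

Under these assumptions the full output keeps the length axis, while the last-only output drops it, which is usually all that is needed when sampling the next token.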
- setup()[source]
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
- Immediately when invoking apply(), init() or init_and_output().
- Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
- Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or setup defined attribute is accessed.
- init_cache(
- *,
- batch_size: int,
- dtype: numpy.dtype[Any],
- cache_length: int,
- sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
- )
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: flax.core.scope.Scope | None = None