gm.nn.Transformer

class gemma.gm.nn.Transformer(
*,
return_last_only: bool | None = None,
dtype: numpy.dtype = <class 'jax.numpy.bfloat16'>,
tokens: Annotated[typing.Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] = '__KEY_REQUIRED__',
images: Annotated[typing.Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None,
positions: Annotated[typing.Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None,
attention_mask: Annotated[typing.Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None,
config: gemma.gm.nn._config.TransformerConfig,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)[source]

Bases: flax.linen.module.Module

Base transformer class.

return_last_only

If True, compute and return only the logits for the last token. Otherwise, return the logits for all tokens. Defaults to False.

Type:

bool | None
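
To illustrate the difference between the two output modes, here is a minimal JAX sketch. It does not call the Gemma library; the `logits` array and its (batch, sequence, vocabulary) shape are assumptions made up for the example, and it assumes the last position of each sequence is the token of interest (no padding).

```python
import jax.numpy as jnp

# Hypothetical logits: batch of 2 sequences, 4 positions, vocabulary of 5.
logits = jnp.arange(2 * 4 * 5, dtype=jnp.float32).reshape(2, 4, 5)

# "Return all logits": one row of vocabulary scores per position.
all_logits = logits

# "Return last only": keep the scores for the final position of each sequence.
last_logits = logits[:, -1, :]

print(all_logits.shape)   # (2, 4, 5)
print(last_logits.shape)  # (2, 5)
```

During sampling only the final position is needed to pick the next token, which is presumably why the flag exists; slicing after the fact, as done here, only shows the shape of the result and not any compute savings.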

dtype

The parameter dtype. Defaults to jnp.bfloat16.

Type:

numpy.dtype

return_last_only: bool | None = None
dtype

alias of jax.numpy.bfloat16

tokens: Annotated[Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] = '__KEY_REQUIRED__'
images: Annotated[Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None
positions: Annotated[Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None
attention_mask: Annotated[Any, <kauldron.kontext.annotate._KeyToken object at 0x7001c239ecf0>] | None = None
config: gemma.gm.nn._config.TransformerConfig
INFO: ClassVar[gemma.gm.nn._transformer.ModelInfo] = ModelInfo(tokenizer_version=None, default_ckpt=None)
setup()[source]

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    >>> class MyModule(nn.Module):
    ...   def setup(self):
    ...     submodule = nn.Conv(...)
    ...
    ...     # Accessing `submodule` attributes does not yet work here.
    ...
    ...     # The following line invokes `self.__setattr__`, which gives
    ...     # `submodule` the name "conv1".
    ...     self.conv1 = submodule
    ...
    ...     # Accessing `submodule` attributes or methods is now safe and
    ...     # causes setup() to be called once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

init_cache(
*,
batch_size: int,
dtype: numpy.dtype[Any],
cache_length: int,
sharding: kauldron.ktyping.pytree.PyTree[None | Sharding | Callable[list, str]] | None = None,
) -> dict[str, dict[str, jax.Array]][source]
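
The return type is a two-level dictionary of arrays. The sketch below builds a structure of that shape by hand to show what such a value can look like. It is not the library's implementation: the function `make_cache_sketch`, the key names ("layer_0", "k", "v"), and the `num_layers`, `num_kv_heads`, and `head_dim` arguments are all assumptions introduced for illustration.

```python
import jax.numpy as jnp


def make_cache_sketch(*, num_layers, batch_size, cache_length,
                      num_kv_heads, head_dim, dtype):
  """Builds a zero-filled nested dict (hypothetical layout, not Gemma's)."""
  shape = (batch_size, cache_length, num_kv_heads, head_dim)
  return {
      # Outer key: one entry per layer (hypothetical naming).
      f"layer_{i}": {
          # Inner keys: one array per cached quantity (hypothetical naming).
          "k": jnp.zeros(shape, dtype=dtype),
          "v": jnp.zeros(shape, dtype=dtype),
      }
      for i in range(num_layers)
  }


cache = make_cache_sketch(
    num_layers=2, batch_size=1, cache_length=8,
    num_kv_heads=2, head_dim=4, dtype=jnp.bfloat16,
)
print(sorted(cache))               # ['layer_0', 'layer_1']
print(cache["layer_0"]["k"].shape)  # (1, 8, 2, 4)
```

In the real method, `batch_size`, `dtype`, and `cache_length` are the caller-supplied arguments; the remaining dimensions would presumably come from `config`.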
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: flax.core.scope.Scope | None = None