gm.nn.Embedder
- class gemma.gm.nn.Embedder(
    vocab_size: int,
    embed_dim: int,
    vision_proj_dim: int | None = None,
    parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
    name: str | None = None,
  )
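Bases: flax.linen.module.Module

Embedder module. It maps integer token IDs to embedding vectors (encode), maps embedding vectors back to per-token scores over the vocabulary (decode), and, when vision_proj_dim is set, projects SigLIP vision embeddings into the text embedding space (encode_vision).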
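- vocab_size: int
  Number of entries in the vocabulary; token IDs passed to encode must lie in [0, vocab_size).
- embed_dim: int
  Size of each embedding vector: the last axis of the output of encode and of the input to decode.
- vision_proj_dim: int | None = None
  Size of the vision (SigLIP) embeddings accepted by encode_vision. Leave as None for text-only use.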
- setup()
Initializes a Module lazily (similar to a lazy __init__). setup is called once, lazily, on a module instance when the module is bound: immediately before any other method such as __call__ is invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:
1. Immediately when invoking apply(), init() or init_and_output().
2. Once the module is given a name by being assigned to an attribute of another module inside the other module's setup method (see __setattr__()):

   >>> import flax.linen as nn
   >>> class MyModule(nn.Module):
   ...   def setup(self):
   ...     submodule = nn.Conv(features=16, kernel_size=(3, 3))
   ...     # Accessing `submodule` attributes does not yet work here.
   ...     # The following line invokes `self.__setattr__`, which gives
   ...     # `submodule` the name "conv1".
   ...     self.conv1 = submodule
   ...     # Accessing `submodule` attributes or methods is now safe and
   ...     # causes the submodule's setup() to be called (only once).

3. Once a module is constructed inside a method wrapped with compact(): immediately before another method is called or a setup-defined attribute is accessed.
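For the Embedder, setup is presumably where the module declares its parameters. The NumPy sketch below shows the parameter shapes implied by the documented signatures: an embedding table of shape [vocab_size, embed_dim] (encode looks rows up in it, and decode is assumed to reuse it), plus, only when vision_proj_dim is set, a projection matrix of shape [vision_proj_dim, embed_dim]. The function init_embedder_params, the dictionary keys and the normal(0, 1) initializer are hypothetical; the real module declares its parameters through Flax and may name and initialize them differently.

```python
import numpy as np


def init_embedder_params(vocab_size, embed_dim, vision_proj_dim=None, seed=0):
    """Hypothetical stand-in for what Embedder.setup() declares.

    Shapes are inferred from the encode/decode/encode_vision docs; the names
    and the normal(0, 1) initializer are assumptions, not the gemma code.
    """
    rng = np.random.default_rng(seed)
    params = {
        # One row per token ID. encode looks rows up; decode is assumed to
        # reuse the same table (tied weights).
        "input_embedding": rng.normal(size=(vocab_size, embed_dim)).astype(np.float32),
    }
    if vision_proj_dim is not None:
        # Only multimodal configurations need the vision projection.
        params["vision_projection"] = rng.normal(
            size=(vision_proj_dim, embed_dim)
        ).astype(np.float32)
    return params


text_only = init_embedder_params(vocab_size=32, embed_dim=8)
multimodal = init_embedder_params(vocab_size=32, embed_dim=8, vision_proj_dim=16)
print({k: v.shape for k, v in text_only.items()})   # {'input_embedding': (32, 8)}
print({k: v.shape for k, v in multimodal.items()})  # {'input_embedding': (32, 8), 'vision_projection': (16, 8)}
```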
- encode(x: jax.Array) -> jax.Array
Encodes the input tokens.
- Parameters:
x – Input tokens of shape [seq_len] or [batch_size, seq_len], where each token is an integer in [0, vocab_size).
- Returns:
Encoded tokens of shape [seq_len, embed_dim] or [batch_size, seq_len, embed_dim].
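Encoding is an embedding-table lookup: each integer token ID selects one row of a [vocab_size, embed_dim] table, so the leading [seq_len] or [batch_size, seq_len] axes are preserved and a trailing embed_dim axis is added. The sketch below uses NumPy in place of jax.numpy (integer-array indexing behaves the same for in-range IDs). The helper encode_sketch is hypothetical, and the scaling by sqrt(embed_dim) follows a convention used in Gemma models; it is an assumption here, not something this page states. The explicit range check is there because neither NumPy (which wraps negative indices) nor jax.numpy (which clamps out-of-bounds gathers) rejects every invalid ID on its own.

```python
import numpy as np


def encode_sketch(table, tokens, scale_by_sqrt_dim=True):
    """Hypothetical sketch of Embedder.encode.

    table:  float array of shape [vocab_size, embed_dim].
    tokens: int array of shape [seq_len] or [batch_size, seq_len],
            each value in [0, vocab_size).
    """
    tokens = np.asarray(tokens)
    vocab_size, embed_dim = table.shape
    if tokens.size and (tokens.min() < 0 or tokens.max() >= vocab_size):
        raise ValueError("token IDs must lie in [0, vocab_size)")
    # Row lookup: the output shape is tokens.shape + (embed_dim,).
    x = table[tokens]
    if scale_by_sqrt_dim:
        # Assumption: Gemma-style scaling of embeddings by sqrt(embed_dim).
        x = x * np.sqrt(embed_dim).astype(x.dtype)
    return x


table = np.arange(5 * 4, dtype=np.float32).reshape(5, 4)  # vocab_size=5, embed_dim=4
print(encode_sketch(table, [1, 3]).shape)                  # (2, 4)
print(encode_sketch(table, [[0, 1, 2], [4, 4, 0]]).shape)  # (2, 3, 4)
```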
- decode(x: jax.Array) -> jax.Array
Decodes embedding vectors into scores (logits) over the vocabulary.
- Parameters:
x – Array of shape [seq_len, embed_dim] or [batch_size, seq_len, embed_dim].
- Returns:
Array of shape [seq_len, vocab_size] or [batch_size, seq_len, vocab_size].
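Decoding maps each embed_dim vector to one score per vocabulary entry. A common way to do this, and the assumption made in this sketch, is to reuse the embedding table (tied weights) and take the dot product with every row, x @ table.T. This page does not state that gm.nn.Embedder.decode works exactly this way, and decode_sketch is a hypothetical name; NumPy stands in for jax.numpy.

```python
import numpy as np


def decode_sketch(table, x):
    """Hypothetical sketch of Embedder.decode, assuming tied weights.

    table: [vocab_size, embed_dim]; x: [..., embed_dim].
    Returns scores of shape [..., vocab_size].
    """
    # Dot product of each input vector with every embedding row.
    return x @ table.T


rng = np.random.default_rng(0)
table = rng.normal(size=(10, 6)).astype(np.float32)        # vocab_size=10, embed_dim=6
x_batched = rng.normal(size=(2, 3, 6)).astype(np.float32)  # [batch_size, seq_len, embed_dim]
logits = decode_sketch(table, x_batched)
print(logits.shape)                  # (2, 3, 10)
next_token = logits.argmax(axis=-1)  # greedy pick per position, shape (2, 3)
```

With tied weights, decoding the unscaled embedding of token t gives t a score equal to that row's squared norm, but another row can still score higher; sharing the table saves parameters rather than making decode an exact inverse of encode.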
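- encode_vision(x: jax.Array) -> jax.Array
Projects SigLIP vision embeddings into the embedding space of the text encoder. Requires the module to be constructed with vision_proj_dim set; the last axis of x should have size vision_proj_dim, and the last axis of the result has size embed_dim.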
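The projection lets image tokens produced by the vision encoder sit alongside text token embeddings. The NumPy sketch below assumes a learned [vision_proj_dim, embed_dim] matrix applied to the last axis after an RMS normalization. The normalization step, its epsilon, and the names rms_norm, encode_vision_sketch and soft_tokens are assumptions for illustration; in particular, the simplified RMS norm here has no learned scale, which the real module may differ on.

```python
import numpy as np


def rms_norm(x, eps=1e-6):
    # Simplified RMS normalization over the last axis (no learned scale).
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


def encode_vision_sketch(projection, soft_tokens):
    """Hypothetical sketch of Embedder.encode_vision.

    projection:  [vision_proj_dim, embed_dim] matrix (assumed parameter).
    soft_tokens: [..., num_tokens, vision_proj_dim] SigLIP embeddings.
    Returns:     [..., num_tokens, embed_dim].
    """
    if soft_tokens.shape[-1] != projection.shape[0]:
        raise ValueError("last axis of soft_tokens must equal vision_proj_dim")
    # Assumption: normalize first, then project into the text embedding space.
    return np.einsum("...tm,md->...td", rms_norm(soft_tokens), projection)


rng = np.random.default_rng(0)
projection = rng.normal(size=(16, 8)).astype(np.float32)        # vision_proj_dim=16, embed_dim=8
image_tokens = rng.normal(size=(2, 4, 16)).astype(np.float32)  # [batch_size, num_tokens, vision_proj_dim]
print(encode_vision_sketch(projection, image_tokens).shape)    # (2, 4, 8)
```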
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: flax.core.scope.Scope | None = None