gm.testing.DummyGemma

gm.testing.DummyGemma

class gemma.gm.testing.DummyGemma(
config: gemma.gm.nn._config.TransformerConfig = TransformerConfig(num_embed=13,
embed_dim=32,
hidden_dim=128,
num_heads=2,
head_dim=128,
num_kv_heads=2,
final_logit_softcap=None,
use_post_attn_norm=None,
use_post_ffw_norm=None,
attention_types=(<AttentionType.GLOBAL: 1>,),
query_pre_attn_norm=<QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM: 1>,
attn_logits_soft_cap=None,
sliding_window_size=None,
transpose_gating_einsum=False,
use_qk_norm=False,
local_base_frequency=10000,
global_base_frequency=10000,
local_scale_factor=1.0,
global_scale_factor=1.0,
vision_encoder=None),
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
*,
return_last_only: bool | None = None,
dtype: jnp.dtype = <class 'jax.numpy.bfloat16'>,
tokens: kontext.Key = '__KEY_REQUIRED__',
images: kontext.Key | None = None,
positions: kontext.Key | None = None,
attention_mask: kontext.Key | None = None,
)[source]

Bases: gemma.gm.nn._transformer.Transformer

Dummy transformer architecture for testing.

config: gemma.gm.nn._config.TransformerConfig = TransformerConfig(num_embed=13, embed_dim=32, hidden_dim=128, num_heads=2, head_dim=128, num_kv_heads=2, final_logit_softcap=None, use_post_attn_norm=None, use_post_ffw_norm=None, attention_types=(<AttentionType.GLOBAL: 1>,), query_pre_attn_norm=<QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM: 1>, attn_logits_soft_cap=None, sliding_window_size=None, transpose_gating_einsum=False, use_qk_norm=False, local_base_frequency=10000, global_base_frequency=10000, local_scale_factor=1.0, global_scale_factor=1.0, vision_encoder=None)
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: flax.core.scope.Scope | None = None