gm.nn.Gemma4_26B_A4B

gm.nn.Gemma4_26B_A4B

class gemma.gm.nn.Gemma4_26B_A4B(
config: _config.TransformerConfig = TransformerConfig(num_embed=262144,
embed_dim=2816,
hidden_dim=2112,
num_heads=16,
head_dim=256,
num_kv_heads=8,
final_logit_softcap=30.0,
use_post_attn_norm=True,
use_post_ffw_norm=True,
attention_types=(<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.GLOBAL: 1>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.GLOBAL: 1>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.GLOBAL: 1>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.GLOBAL: 1>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.LOCAL_SLIDING: 2>,
<AttentionType.GLOBAL: 1>),
attn_logits_soft_cap=None,
sliding_window_size=1024,
qk_norm_with_scale=True,
num_global_kv_heads=2,
global_key_size=512,
k_eq_v_global=True,
global_rope_proportion=0.25,
local_rope_proportion=1.0,
local_base_frequency=10000,
global_base_frequency=1000000,
local_scale_factor=1.0,
global_scale_factor=1.0,
per_layer_input_dim=0,
kv_cache_sharing_config=None,
override_kv_shared_ffw_hidden=None,
vision_encoder=VisionEncoder(     # attributes     d_model = 1152     num_layers = 27     num_heads = 16     ffw_hidden = 4304     patch_size = 16     output_length = 280     pos_emb_shape_yx = (10240, 2)     pooling_kernel_size = 3     use_clipped_linears = False     standardize_embeddings = True ),
audio_encoder=None,
enable_moe=True,
num_experts=128,
expert_dim=704,
top_k_experts=8,
moe_dense_hidden_dim=2112,
use_bidirectional_attention='vision'),
text_only: bool = True,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
*,
return_last_only: bool | None = None,
dtype: jnp.dtype = <class 'jax.numpy.bfloat16'>,
tokens: kontext.Key = '__KEY_REQUIRED__',
images: kontext.Key | None = None,
)[source]

Bases: gemma.gm.nn.gemma4._gemma4._Gemma4Base

Gemma 4 26B_A4B MoE model.

A Mixture-of-Experts model with 128 experts per layer. Each layer has:

  • A MoE branch (num_experts=128, expert_dim=704, top_k_experts=8)

  • A dense shared MLP branch (moe_dense_hidden_dim=2112)

attention_pattern = (AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL)
global_local_pattern = (AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL)
config: gemma.gm.nn.gemma4._config.TransformerConfig = TransformerConfig(num_embed=262144, embed_dim=2816, hidden_dim=2112, num_heads=16, head_dim=256, num_kv_heads=8, final_logit_softcap=30.0, use_post_attn_norm=True, use_post_ffw_norm=True, attention_types=(<AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>), attn_logits_soft_cap=None, sliding_window_size=1024, qk_norm_with_scale=True, num_global_kv_heads=2, global_key_size=512, k_eq_v_global=True, global_rope_proportion=0.25, local_rope_proportion=1.0, local_base_frequency=10000, global_base_frequency=1000000, local_scale_factor=1.0, global_scale_factor=1.0, per_layer_input_dim=0, kv_cache_sharing_config=None, override_kv_shared_ffw_hidden=None, vision_encoder=VisionEncoder(     # attributes     d_model = 1152     num_layers = 27     num_heads = 16     ffw_hidden = 4304     patch_size = 16     output_length = 280     pos_emb_shape_yx = (10240, 2)     pooling_kernel_size = 3     use_clipped_linears = False     standardize_embeddings = True ), audio_encoder=None, enable_moe=True, num_experts=128, 
expert_dim=704, top_k_experts=8, moe_dense_hidden_dim=2112, use_bidirectional_attention='vision')
INFO: ClassVar[gemma.gm.nn.gemma4._transformer.ModelInfo] = ModelInfo(tokenizer_version=4, default_ckpt=None)
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: flax.core.scope.Scope | None = None