gm.nn.Gemma3n_E4B
- class gemma.gm.nn.Gemma3n_E4B(
- config: _config.TransformerConfig = TransformerConfig(num_embed=262144,
- embed_dim=2048,
- hidden_dim=16384,
- num_heads=8,
- head_dim=256,
- num_kv_heads=2,
- final_logit_softcap=None,
- use_post_attn_norm=True,
- use_post_ffw_norm=True,
- attention_types=(<AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.GLOBAL: 1>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.GLOBAL: 1>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.GLOBAL: 1>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.GLOBAL: 1>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.GLOBAL: 1>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.GLOBAL: 1>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.LOCAL_SLIDING: 2>,
- <AttentionType.GLOBAL: 1>),
- query_pre_attn_norm=<QueryPreAttentionNormalisation.NONE: 1>,
- attn_logits_soft_cap=None,
- sliding_window_size=512,
- transpose_gating_einsum=True,
- use_qk_norm=True,
- qk_norm_with_scale=True,
- use_value_norm=True,
- local_base_frequency=10000,
- global_base_frequency=1000000,
- local_scale_factor=1.0,
- global_scale_factor=1.0,
- vision_encoder=SigLiPFromPatches(
      # attributes
      siglip_encoder = ViTModel(
          # attributes
          patch_size = (14, 14)
          width = 1152
          depth = 27
          mlp_dim = 4304
          num_heads = 16
          posemb = 'learn'
          dropout = 0.0
          scan = False
          remat_policy = 'nothing_saveable'
          dtype_mm = 'float32'
      )
      siglip_exit = VisionExit(
          # attributes
          output_length = 256
      )
      num_mm_tokens_per_image_prepool = 4096
      num_mm_tokens_per_image = 256
      image_height = 896
      image_width = 896
      image_channels = 3
      apply_stop_gradient = True
  ),
- use_altup=True,
- num_altup_inputs=4,
- altup_coef_clip=120.0,
- activation_sparsity_pattern=(0.95,
- 0.95,
- 0.95,
- 0.95,
- 0.95,
- 0.95,
- 0.95,
- 0.95,
- 0.95,
- 0.95,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0),
- per_layer_input_dim=256,
- use_laurel=True,
- laurel_rank=64,
- kv_cache_sharing_config=KVCacheSharingConfig(frac_shared_layers=0.42857142857142855,
- share_global=True,
- share_local=True),
- scale_plus_one=False,
- guard_against_excess_precision=True),
- text_only: bool = False,
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
- name: str | None = None,
- *,
- return_last_only: bool | None = None,
- dtype: jnp.dtype = <class 'jax.numpy.bfloat16'>,
- tokens: kontext.Key = '__KEY_REQUIRED__',
- images: kontext.Key | None = None,
- positions: kontext.Key | None = None,
- attention_mask: kontext.Key | None = None,
Bases: gemma.gm.nn.gemma3n._gemma3n._Gemma3nBase

Gemma3n E4B transformer architecture.
- config: gemma.gm.nn.gemma3n._config.TransformerConfig = TransformerConfig(num_embed=262144, embed_dim=2048, hidden_dim=16384, num_heads=8, head_dim=256, num_kv_heads=2, final_logit_softcap=None, use_post_attn_norm=True, use_post_ffw_norm=True, attention_types=(<AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.LOCAL_SLIDING: 2>, <AttentionType.GLOBAL: 1>), query_pre_attn_norm=<QueryPreAttentionNormalisation.NONE: 1>, attn_logits_soft_cap=None, sliding_window_size=512, transpose_gating_einsum=True, use_qk_norm=True, qk_norm_with_scale=True, use_value_norm=True, local_base_frequency=10000, global_base_frequency=1000000, local_scale_factor=1.0, global_scale_factor=1.0, vision_encoder=SigLiPFromPatches( # attributes siglip_encoder = ViTModel( # attributes patch_size = (14, 14) width = 1152 depth = 27 mlp_dim = 4304 num_heads = 16 posemb = 'learn' dropout = 0.0 scan = False remat_policy = 'nothing_saveable' dtype_mm = 'float32' ) siglip_exit = VisionExit( # attributes output_length = 256 ) num_mm_tokens_per_image_prepool = 4096 num_mm_tokens_per_image = 256 image_height = 896 image_width = 896 image_channels = 3 apply_stop_gradient = True ), use_altup=True, num_altup_inputs=4, altup_coef_clip=120.0, activation_sparsity_pattern=(0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), per_layer_input_dim=256, use_laurel=True, laurel_rank=64, kv_cache_sharing_config=KVCacheSharingConfig(frac_shared_layers=0.42857142857142855, share_global=True, share_local=True), scale_plus_one=False, guard_against_excess_precision=True)
- INFO: ClassVar[gemma.gm.nn.gemma3n._transformer.ModelInfo] = ModelInfo(tokenizer_version='3n', default_ckpt=None)
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: flax.core.scope.Scope | None = None