gm.nn.Einsum

class gemma.gm.nn.Einsum(shape: tuple[int, ...], weight_name: str = 'w', initializer: jax.nn.initializers.Initializer | collections.abc.Callable[..., typing.Any] = <function normal.<locals>.init>, dtype: numpy.dtype | None = None, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: flax.linen.module.Module

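Einsum is a convenience module for parameterized tensor multiplication. It holds a single weight parameter of the given shape, stored under the name weight_name and created by initializer, and multiplies its input by that weight according to an einsum equation.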

shape: tuple[int, ...]
weight_name: str = 'w'
initializer(key: jax.Array, shape: collections.abc.Sequence[int | Any], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.P | None = None) → jax.Array
dtype: numpy.dtype | None = None
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: flax.core.scope.Scope | None = None
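The initializer attribute accepts a callable following the JAX initializer convention: a PRNG key, a shape and an optional dtype in, an array out. The sketch below, assuming a JAX version that provides `jax.nn.initializers.normal`, `zeros` and `glorot_normal`, builds a few stock initializers and calls them directly; per the type annotation in the signature, any of them could be passed as `initializer=` when constructing the module.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Same kind of closure as the documented default (normal.<locals>.init),
# here with an explicit standard deviation.
normal_init = jax.nn.initializers.normal(stddev=0.02)
w = normal_init(key, (8, 16), jnp.float32)

# Stock initializers share the (key, shape, dtype) calling convention;
# the dtype argument controls the parameter's storage type.
zeros_init = jax.nn.initializers.zeros
z = zeros_init(key, (8, 16), jnp.bfloat16)

# Variance scaling reads fan-in/fan-out from the shape
# (defaults: in_axis=-2, out_axis=-1).
glorot_init = jax.nn.initializers.glorot_normal()
g = glorot_init(key, (8, 16), jnp.float32)

print(w.shape, w.dtype)  # (8, 16) float32
```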