peft.LoRAEinsumAdapter
- class gemma.peft.LoRAEinsumAdapter(*, rank: int, einsum_str: str, shape: collections.abc.Sequence[int], dtype: numpy.dtype = <class 'jax.numpy.float64'>, a_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function variance_scaling.<locals>.init>, b_init: jax.nn.initializers.Initializer | collections.abc.Callable[[...], typing.Any] = <function zeros>, parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)
Bases: flax.linen.module.Module
LoRA einsum module.
This module only does the x @ A @ B computation. Use LoRAEinsum to wrap an nn.Einsum layer.
- rank
The rank of the LoRA decomposition.
- Type:
int
- einsum_str
The einsum string of the original einsum op. Should be inputs,weights->outputs (this will be internally rewritten as inputs,a,b->outputs).
- Type:
str
- shape
The shape of the original weights before the low-rank adaptation. Should match the weights shape from the einsum_str.
- Type:
collections.abc.Sequence[int]
- dtype
The dtype to use for the LoRA weights.
- Type:
numpy.dtype
- a_init
The initializer for the A matrix.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
- b_init
The initializer for the B matrix.
- Type:
jax.nn.initializers.Initializer | collections.abc.Callable[[…], Any]
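The relationship between einsum_str, shape, and rank described above can be sketched in a few lines. This is only an illustration, not the library's implementation: it uses NumPy as a stand-in (the same subscript strings are accepted by jax.numpy.einsum), and the rewritten string, the rank letter "r", the way the weight axes are split between A and B, and the helper name lora_delta are all assumptions made for the example.

```python
import numpy as np


def lora_delta(x, a, b, lora_einsum_str):
    """Low-rank update: contracts the inputs with A, then with B."""
    return np.einsum(lora_einsum_str, x, a, b)


rng = np.random.default_rng(0)
batch, seq, d_model, n_heads, head_dim, rank = 2, 3, 8, 4, 5, 2

# Original op: inputs,weights->outputs, with weights of shape (D, N, H).
einsum_str = "BTD,DNH->BTNH"
shape = (d_model, n_heads, head_dim)

# Assumed rewrite (inputs,a,b->outputs): "r" is a new rank axis, A keeps
# the contracted axis D, and B keeps the output axes N and H.
lora_einsum_str = "BTD,Dr,rNH->BTNH"
a_shape = (d_model, rank)
b_shape = (rank, n_heads, head_dim)

x = rng.normal(size=(batch, seq, d_model))
a = rng.normal(size=a_shape)  # stand-in for a_init
b_zero = np.zeros(b_shape)    # mirrors the zeros default of b_init

# With B at zero, the adapter contributes nothing to the output.
delta_at_init = lora_delta(x, a, b_zero, lora_einsum_str)

# With a non-zero B, the three-operand einsum matches applying the original
# einsum_str to the reconstructed full-shape weight.
b = rng.normal(size=b_shape)
delta = lora_delta(x, a, b, lora_einsum_str)
full_weight = np.einsum("Dr,rNH->DNH", a, b)
reference = np.einsum(einsum_str, x, full_weight)

full_params = int(np.prod(shape))  # parameters in the original weights
lora_params = a.size + b.size      # parameters in the two low-rank factors
```

In this sketch the two factors hold fewer parameters than the full weight tensor, and a zero-initialized B means the adapter output starts at zero, so a wrapped layer would initially behave like the original one.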
- rank: int
- einsum_str: str
- shape: collections.abc.Sequence[int]
- dtype
alias of
jax.numpy.float64
- a_init(shape: collections.abc.Sequence[int | Any], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.P | None = None)
- b_init(shape: collections.abc.Sequence[int | Any], dtype: Any | None = None, out_sharding: jax.sharding.NamedSharding | jax.P | None = None)
An initializer that returns a constant array full of zeros.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- setup()
Initializes a Module lazily (similar to a lazy __init__).
setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.
This can happen in three cases:
Immediately when invoking apply(), init() or init_and_output().
Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):
>>> class MyModule(nn.Module):
...   def setup(self):
...     submodule = nn.Conv(...)
...     # Accessing `submodule` attributes does not yet work here.
...     # The following line invokes `self.__setattr__`, which gives
...     # `submodule` the name "conv1".
...     self.conv1 = submodule
...     # Accessing `submodule` attributes or methods is now safe and
...     # either causes setup() to be called once.
Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: flax.core.scope.Scope | None = None