gm.nn.LoRA
class gemma.gm.nn.LoRA(
    *,
    rank: int,
    model: flax.linen.module.Module,
    dtype: numpy.dtype = jax.numpy.bfloat16,
    verbose: bool = False,
    parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
    name: str | None = None,
)
Bases: flax.linen.module.Module

Wrapper around a Gemma model to enable LoRA.
The wrapped model has all its nn.Dense, nn.Einsum, … layers replaced by their LoRA versions. See the gemma.peft documentation for more details.
Attributes:
- rank (int): The rank of the LoRA decomposition.
- model (flax.linen.module.Module): The model to wrap.
- dtype (numpy.dtype): The dtype to use for the LoRA weights.
- verbose (bool): If True, logs diagnostic strings for the LoRA layers.
- rank: int
- model: flax.linen.module.Module
- dtype: alias of jax.numpy.bfloat16
- verbose: bool = False
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: flax.core.scope.Scope | None = None