gm.nn.LoRA

class gemma.gm.nn.LoRA(
*,
rank: int,
model: flax.linen.module.Module,
dtype: numpy.dtype = jax.numpy.bfloat16,
verbose: bool = False,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)

Bases: flax.linen.module.Module

Wrapper around a Gemma model to enable LoRA.

The wrapped model has all of its nn.Dense, nn.Einsum, … layers replaced by their LoRA versions. See the gemma.peft documentation for more details.
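To make "LoRA version of a layer" concrete, the following is a minimal NumPy sketch of the general low-rank adaptation idea for a dense layer. It is not the gemma.peft implementation: the function name lora_dense, the variable names, the shapes, and the zero initialization of the up-projection are all assumptions chosen for illustration.

```python
import numpy as np

def lora_dense(x, w, a, b):
    """Dense layer with a frozen kernel `w` plus a low-rank correction `a @ b`."""
    return x @ w + (x @ a) @ b

rng = np.random.default_rng(0)
in_dim, out_dim, rank = 8, 6, 2  # hypothetical sizes

w = rng.normal(size=(in_dim, out_dim))  # pretrained kernel, kept frozen
a = rng.normal(size=(in_dim, rank))     # trainable down-projection
b = np.zeros((rank, out_dim))           # trainable up-projection (assumed zero init)

x = rng.normal(size=(3, in_dim))
y = lora_dense(x, w, a, b)

# With `b` at zero the wrapped layer reproduces the original layer exactly.
same_as_original = np.allclose(y, x @ w)

# After training, `b` is non-zero; the update `a @ b` has rank at most `rank`.
b_trained = rng.normal(size=(rank, out_dim))
update_rank = np.linalg.matrix_rank(a @ b_trained)
```

In this sketch only a and b would receive gradients, which is why the wrapped model can be fine-tuned while the original weights stay unchanged.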

rank

The rank of the LoRA decomposition.

Type:

int

model

The model to wrap.

Type:

flax.linen.module.Module

dtype

The dtype to use for the LoRA weights.

Type:

numpy.dtype

verbose

If True, logs diagnostic strings for the LoRA layers.

Type:

bool

rank: int
model: flax.linen.module.Module
dtype

alias of jax.numpy.bfloat16

verbose: bool = False
name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: flax.core.scope.Scope | None = None
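
As a rough guide to choosing rank, the sketch below compares the number of trainable parameters in a full dense kernel with the number in a rank-r factorization of the same shape. The helper lora_param_count and the 2048 x 2048 layer size are hypothetical and are not taken from any Gemma configuration.

```python
def lora_param_count(in_dim: int, out_dim: int, rank: int) -> int:
    """Parameters in the two low-rank factors: (in_dim x rank) and (rank x out_dim)."""
    return rank * (in_dim + out_dim)

in_dim, out_dim = 2048, 2048  # hypothetical layer size
full = in_dim * out_dim       # parameters in the full kernel
lora = lora_param_count(in_dim, out_dim, rank=4)

print(full, lora)  # 4194304 16384
```

The trainable parameter count grows linearly with rank, so doubling rank doubles the size of the LoRA weights for every wrapped layer.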