gm.nn.IntWrapper
- class gemma.gm.nn.IntWrapper(
    *,
    model: flax.linen.module.Module,
    dtype: numpy.dtype = jax.numpy.int4,
    parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
    name: str | None = None,
  )
Bases: flax.linen.module.Module
Wrapper around a Gemma model to enable int4 inference.
The wrapped model has all its nn.Dense, nn.Einsum, … layers replaced by their int4 versions. See the gemma.peft documentation for more details.
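To make "replaced by their int4 versions" more concrete, here is a minimal NumPy sketch of what an int4 dense layer computes, assuming symmetric per-output-channel weight quantization. This illustrates the general technique only and is not the gemma.peft implementation; the helpers `quantize_int4` and `int4_dense` are hypothetical, and the quantized values are held in `int8` because NumPy has no native int4 type.

```python
import numpy as np

INT4_MIN, INT4_MAX = -8, 7  # representable range of a signed 4-bit integer


def quantize_int4(kernel: np.ndarray):
    """Hypothetical helper: symmetric per-output-channel int4 quantization.

    `kernel` has shape (in_features, out_features), like a Dense kernel.
    Returns int4 values (stored in int8) and one float scale per column.
    """
    max_abs = np.abs(kernel).max(axis=0, keepdims=True)  # shape (1, out)
    # Map the largest magnitude in each column onto INT4_MAX; avoid div by 0.
    scale = np.where(max_abs > 0, max_abs / INT4_MAX, 1.0).astype(np.float32)
    q = np.clip(np.round(kernel / scale), INT4_MIN, INT4_MAX).astype(np.int8)
    return q, scale


def int4_dense(x: np.ndarray, q: np.ndarray, scale: np.ndarray, bias=None):
    """Hypothetical helper: Dense forward pass using the quantized kernel.

    Because each scale applies to a whole output column,
    x @ (q * scale) == (x @ q) * scale, so the matmul runs on the
    integer values and the scale is applied once afterwards.
    """
    y = (x @ q.astype(np.float32)) * scale
    if bias is not None:
        y = y + bias
    return y


# Usage: compare a float Dense layer with its int4 approximation.
rng = np.random.default_rng(0)
kernel = rng.normal(size=(16, 8)).astype(np.float32)
x = rng.normal(size=(4, 16)).astype(np.float32)

q, scale = quantize_int4(kernel)
y_ref = x @ kernel
y_q = int4_dense(x, q, scale)
print("max abs error:", np.abs(y_ref - y_q).max())
```

Per-channel scales are a common choice because one outlier weight then only coarsens the grid for its own output column rather than for the whole matrix.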
- model
  The model to wrap.
  Type: flax.linen.module.Module
- model: flax.linen.module.Module
- dtype: alias of jax.numpy.int4
- name: str | None = None
- parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
- scope: flax.core.scope.Scope | None = None
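The default `dtype`, `jax.numpy.int4`, is a signed 4-bit integer type that can represent only the 16 values from -8 to 7; the small width is where the memory saving of int4 inference comes from. The plain NumPy sketch below packs two such values into one byte to show the 2x saving over int8 storage (4x over bfloat16). The helpers `pack_int4` and `unpack_int4` are hypothetical, and the actual in-memory layout JAX/XLA uses for int4 arrays is not described here and may differ.

```python
import numpy as np


def pack_int4(values: np.ndarray) -> np.ndarray:
    """Hypothetical helper: pack pairs of int4 values (held in int8) into bytes."""
    v = np.asarray(values, dtype=np.int8)
    if v.size % 2:
        raise ValueError("need an even number of values")
    if v.min() < -8 or v.max() > 7:
        raise ValueError("values outside int4 range [-8, 7]")
    # Keep the low 4 bits of each value's two's-complement encoding.
    nibbles = v.astype(np.uint8) & 0x0F
    # Even-indexed value goes in the low nibble, odd-indexed in the high nibble.
    return (nibbles[0::2] | (nibbles[1::2] << 4)).astype(np.uint8)


def unpack_int4(packed: np.ndarray) -> np.ndarray:
    """Hypothetical helper: inverse of pack_int4."""
    p = np.asarray(packed, dtype=np.uint8)
    lo = (p & 0x0F).astype(np.int8)
    hi = (p >> 4).astype(np.int8)
    # Sign-extend: 4-bit patterns 8..15 encode the negatives -8..-1.
    lo = np.where(lo >= 8, lo - 16, lo).astype(np.int8)
    hi = np.where(hi >= 8, hi - 16, hi).astype(np.int8)
    out = np.empty(p.size * 2, dtype=np.int8)
    out[0::2] = lo
    out[1::2] = hi
    return out


vals = np.array([-8, -1, 0, 7, 3, -5], dtype=np.int8)
packed = pack_int4(vals)
print(packed.nbytes, vals.nbytes)  # 3 6
```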