gm.nn.IntWrapper

class gemma.gm.nn.IntWrapper(
*,
model: flax.linen.module.Module,
dtype: numpy.dtype = <class 'jax.numpy.int4'>,
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>,
name: str | None = None,
)

Bases: flax.linen.module.Module

Wrapper around a Gemma model to enable int4 inference.

The wrapped model will have all of its nn.Dense, nn.Einsum, … layers replaced by their int4 versions. See the gemma.peft documentation for more details.
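To give a feel for what an "int4 version" of a layer's weights involves, here is a minimal, dependency-free sketch of symmetric per-tensor int4 quantization. It is an illustration only: the helper names `quantize_int4` and `dequantize` are hypothetical, and the scheme gemma.peft actually uses (scale granularity, rounding, storage layout) may differ.

```python
# Illustrative sketch only; NOT the gemma.peft implementation.
# Signed int4 can represent integers in the range [-8, 7].
INT4_MIN, INT4_MAX = -8, 7


def quantize_int4(weights):
    """Map floats to int4 codes plus one shared scale (symmetric, per-tensor)."""
    max_abs = max(abs(w) for w in weights)
    # Choose the scale so that the largest magnitude maps to +/-7.
    scale = max_abs / INT4_MAX if max_abs > 0 else 1.0
    codes = [max(INT4_MIN, min(INT4_MAX, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate float weights from the int4 codes."""
    return [c * scale for c in codes]


weights = [0.7, -0.3, 0.12, 0.0]
codes, scale = quantize_int4(weights)
restored = dequantize(codes, scale)

# Each restored weight is within half a quantization step of the original.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(codes, max_error <= scale / 2)
```

The trade-off this shows is that each weight is stored as a 4-bit code plus a shared scale, at the cost of a bounded rounding error per weight.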

model

The model to wrap.

Type:

flax.linen.module.Module

model: flax.linen.module.Module
dtype

alias of jax.numpy.int4

name: str | None = None
parent: flax.linen.module.Module | flax.core.scope.Scope | flax.linen.module._Sentinel | None = None
scope: flax.core.scope.Scope | None = None