gm.ckpts.LoadCheckpoint


class gemma.gm.ckpts.LoadCheckpoint(
path: str | os.PathLike,
quantize: bool = False,
)

Bases: kauldron.checkpoints.partial_loader.InitTransform

Loads weights from a Gemma checkpoint.

Note: The checkpoint contains only the Gemma transformer weights, not the step, optimizer state,… To load the full state from a Kauldron checkpoint, use kd.ckpts.PartialKauldronLoader.
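The split between the weights and the rest of the training state can be pictured with a toy, pure-Python state object. The `TrainState` class and the `load_params_only` helper below are hypothetical and only illustrate the idea: the real transform operates on Kauldron state pytrees, not on this dataclass.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class TrainState:
    """Hypothetical stand-in for a Kauldron training state."""

    step: int
    params: dict[str, Any]
    opt_state: dict[str, Any]


def load_params_only(state: TrainState, ckpt_params: dict[str, Any]) -> TrainState:
    """Replace only `params`; step and optimizer state keep their initial values.

    Illustrates that a Gemma checkpoint carries transformer weights only.
    This is not the LoadCheckpoint implementation.
    """
    return dataclasses.replace(state, params=dict(ckpt_params))


fresh = TrainState(step=0, params={"w": 0.0}, opt_state={"mu": 0.0})
loaded = load_params_only(fresh, {"w": 1.5})
print(loaded)
```

Because the step and optimizer state stay at their freshly initialized values, this kind of loader suits starting a fine-tuning run from pretrained weights rather than resuming an interrupted run.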

path

The path to the Orbax checkpoint.

Type:

str | os.PathLike

quantize

If True, the params are mapped to enable quantization-aware training.

Type:

bool

transform(
state: gemma.gm.ckpts._checkpoint._StateT,
) → gemma.gm.ckpts._checkpoint._StateT

Transform the state by updating it with pre-trained values.

Notes:

  • transform functions can modify the state values but should NOT modify its structure, shapes, or dtypes.

  • transform should correctly propagate the sharding information from the given state.

Parameters:

state – The state object to transform

Returns:

The updated state
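The contract in the notes above (update values, keep structure, shapes and dtypes) can be illustrated with a pure-Python stand-in for a pytree. This is a hypothetical sketch: the `overlay` helper is not part of Gemma or Kauldron, nested dicts stand in for the state tree, `array.array` typecodes and lengths stand in for dtypes and shapes, and sharding is not modeled at all.

```python
import array
from typing import Mapping, Union

# Hypothetical stand-in for a pytree: nested dicts whose leaves are 1-D
# `array.array` buffers (typecode plays the role of dtype, length of shape).
Tree = Union[Mapping[str, "Tree"], array.array]


def overlay(state: Tree, pretrained: Tree) -> Tree:
    """Return a copy of `state` whose leaves take their values from `pretrained`.

    Raises instead of silently changing the tree structure, a leaf shape,
    or a leaf dtype.
    """
    if isinstance(state, array.array):
        if not isinstance(pretrained, array.array):
            raise TypeError("structure mismatch: expected a leaf")
        if pretrained.typecode != state.typecode:
            raise TypeError(f"dtype mismatch: {pretrained.typecode} vs {state.typecode}")
        if len(pretrained) != len(state):
            raise ValueError(f"shape mismatch: {len(pretrained)} vs {len(state)}")
        # New buffer with the original dtype, filled with pretrained values.
        return array.array(state.typecode, pretrained)
    if not isinstance(pretrained, Mapping) or set(pretrained) != set(state):
        raise KeyError("structure mismatch: keys differ")
    return {key: overlay(leaf, pretrained[key]) for key, leaf in state.items()}


init = {"embed": array.array("f", [0.0, 0.0]), "mlp": {"w": array.array("f", [0.0])}}
ckpt = {"embed": array.array("f", [0.5, -0.5]), "mlp": {"w": array.array("f", [2.0])}}
new = overlay(init, ckpt)
print(new["embed"].tolist(), new["mlp"]["w"].tolist())
```

Failing loudly on a mismatch, rather than adopting whatever the checkpoint holds, keeps the initialized state as the single source of truth for structure, shapes and dtypes.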