gm.ckpts.AnchoredPolicyLoader

class gemma.gm.ckpts.AnchoredPolicyLoader(
*,
policy: kauldron.checkpoints.partial_loader.InitTransform,
anchor: kauldron.checkpoints.partial_loader.InitTransform | None = None,
)

Bases: kauldron.checkpoints.partial_loader.InitTransform

Loader for gm.nn.AnchoredPolicy models.

Loads the policy and the anchor separately, each through its own sub-transform.

This assumes that the sub-loaders only overwrite state.params and do not modify the rest of the state.

policy: kauldron.checkpoints.partial_loader.InitTransform
anchor: kauldron.checkpoints.partial_loader.InitTransform | None = None
transform(
state: kauldron.train.train_step.TrainState,
) → kauldron.train.train_step.TrainState

Transform the state by updating it with pre-trained values.

Notes:

  • transform functions can modify the state values but should NOT modify its structure, shape or dtypes.

  • transform should correctly propagate the sharding information from the given state.

Parameters:

state – The state object to transform

Returns:

The updated state
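
The first note on transform (values may change, but structure, shape, and dtypes may not) can be checked mechanically. The sketch below is one possible way to do so, not part of the library: `Leaf`, `spec`, and `same_spec` are hypothetical names, and the trees are assumed to be nested dicts whose leaves expose `.shape` and `.dtype`, as NumPy and JAX arrays do.

```python
from typing import Any, NamedTuple


class Leaf(NamedTuple):
    """Stand-in for an array leaf; only shape and dtype are modelled."""
    shape: tuple
    dtype: str


def spec(tree: Any) -> Any:
    """Replace every leaf with its (shape, dtype), keeping the nesting intact."""
    if isinstance(tree, dict):
        return {key: spec(value) for key, value in tree.items()}
    return (tuple(tree.shape), str(tree.dtype))


def same_spec(before: Any, after: Any) -> bool:
    """True if both trees have identical structure, shapes, and dtypes."""
    return spec(before) == spec(after)


before = {
    "policy": {"w": Leaf((4, 8), "float32")},
    "anchor": {"w": Leaf((4, 8), "float32")},
}
# Same layout: what a well-behaved transform is expected to return.
after_ok = {
    "policy": {"w": Leaf((4, 8), "float32")},
    "anchor": {"w": Leaf((4, 8), "float32")},
}
# A changed dtype and a dropped sub-tree both violate the note.
after_bad_dtype = {
    "policy": {"w": Leaf((4, 8), "bfloat16")},
    "anchor": {"w": Leaf((4, 8), "float32")},
}
after_bad_structure = {"policy": {"w": Leaf((4, 8), "float32")}}

results = (
    same_spec(before, after_ok),
    same_spec(before, after_bad_dtype),
    same_spec(before, after_bad_structure),
)
```

A check of this kind could be run on `state.params` before and after calling `transform` while developing a custom sub-loader. It does not cover the second note, since sharding information is not modelled here.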