gm.ckpts.AnchoredPolicyLoader
class gemma.gm.ckpts.AnchoredPolicyLoader(
    *,
    policy: kauldron.checkpoints.partial_loader.InitTransform,
    anchor: kauldron.checkpoints.partial_loader.InitTransform | None = None,
)
Bases: kauldron.checkpoints.partial_loader.InitTransform
Loader for gm.nn.AnchoredPolicy models.
Loads the policy and the anchor separately, each through its own sub-transform (policy and anchor).
This assumes that the sub-loaders only overwrite state.params and do not modify the rest of the state.
- policy: kauldron.checkpoints.partial_loader.InitTransform
- anchor: kauldron.checkpoints.partial_loader.InitTransform | None = None
- transform(state: kauldron.train.train_step.TrainState)
Transform the state by updating it with pre-trained values.
Notes:
- A transform may modify the state values but should NOT modify its structure, shapes, or dtypes.
- A transform should correctly propagate the sharding information from the given state.
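One way to verify the structure, shape, and dtype contract in these notes is to compare the parameter trees before and after a transform. This is a minimal sketch, assuming params are nested dicts of NumPy arrays; check_transform_invariants is a hypothetical helper, not part of gemma or kauldron. It does not check sharding; with jax.Array leaves one could additionally compare each leaf's .sharding attribute.

```python
import numpy as np


def _flatten(tree, prefix=()):
    """Yield (path, leaf) pairs for a nested dict whose leaves are arrays."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            yield from _flatten(value, prefix + (key,))
    else:
        yield prefix, tree


def check_transform_invariants(before, after):
    """Raise ValueError if `after` differs from `before` in structure, shape, or dtype."""
    old_leaves = dict(_flatten(before))
    new_leaves = dict(_flatten(after))
    if old_leaves.keys() != new_leaves.keys():
        # Symmetric difference lists the paths that were added or removed.
        raise ValueError(f"structure changed: {sorted(old_leaves.keys() ^ new_leaves.keys())}")
    for path, old in old_leaves.items():
        old, new = np.asarray(old), np.asarray(new_leaves[path])
        if old.shape != new.shape:
            raise ValueError(f"{path}: shape {old.shape} -> {new.shape}")
        if old.dtype != new.dtype:
            raise ValueError(f"{path}: dtype {old.dtype} -> {new.dtype}")


params = {"policy": {"w": np.zeros((2, 3), np.float32)}}
# Passes: only the values changed.
check_transform_invariants(params, {"policy": {"w": np.ones((2, 3), np.float32)}})
```

Replacing the float32 leaf with a float64 or (3, 2) array, or dropping the "w" key, would raise ValueError.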
- Parameters:
state – The state object to transform
- Returns:
The updated state