gm.losses.DpoLoss
- class gemma.gm.losses.DpoLoss(*, step: kontext.Key = 'step', mask: Optional[kontext.Key] = None, weight: int | float | Schedule = 1.0, normalize_by: Literal['mask', 'values'] = 'mask', tau: float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.1, label_smoothing: float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.0, tokens: kontext.Key = '__KEY_REQUIRED__', sequence_mask: kontext.Key = '__KEY_REQUIRED__', policy_logits: kontext.Key = '__KEY_REQUIRED__', anchor_logits: kontext.Key = '__KEY_REQUIRED__')
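Bases: kauldron.losses.base.Loss
DPO (Direct Preference Optimization) loss, computed from the logits of a policy model and an anchor (reference) model.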
- tau
The temperature of the loss, given either as a constant or as a function of the training step.
- Type:
float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']]
- label_smoothing
The label smoothing to apply to the loss, given either as a constant or as a function of the training step.
- Type:
float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']]
- tokens
The key to the tokens to predict.
- Type:
Any
- sequence_mask
The key to the sequence mask.
- Type:
Any
- policy_logits
The key to the policy logits.
- Type:
Any
- anchor_logits
The key to the anchor logits.
- Type:
Any
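For reference, the standard DPO objective with a temperature and label smoothing (the "conservative DPO" variant) is written below. How `tau` (τ) and `label_smoothing` (ε) enter this class's loss is an assumption based on that standard formulation, not taken from the source. Here x is the prompt, y_w and y_l are the chosen and rejected responses, π_θ is the policy and π_ref the anchor model.

```latex
h = \bigl[\log \pi_\theta(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_w \mid x)\bigr]
  - \bigl[\log \pi_\theta(y_l \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)\bigr]

\mathcal{L} = -(1 - \varepsilon)\,\log \sigma(\tau h) - \varepsilon\,\log \sigma(-\tau h)
```

With ε = 0 this reduces to the original DPO loss, −log σ(τh).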
- tau: float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.1
- label_smoothing: float | Callable[[int], float | jaxtyping.Float[Array, ''] | jaxtyping.Float[ndarray, '']] = 0.0
- tokens: kontext.Key = '__KEY_REQUIRED__'
- sequence_mask: kontext.Key = '__KEY_REQUIRED__'
- policy_logits: kontext.Key = '__KEY_REQUIRED__'
- anchor_logits: kontext.Key = '__KEY_REQUIRED__'
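The types of `tau` and `label_smoothing` accept either a constant or a function of the training step. The sketch below shows one plausible way such a value could be resolved at a given step; `resolve_hparam` and `linear_warmup` are hypothetical helpers for illustration, not part of the Gemma or Kauldron API.

```python
from typing import Callable, Union

# A hyperparameter is either a constant or a function of the training step.
HParam = Union[float, Callable[[int], float]]


def resolve_hparam(value: HParam, step: int) -> float:
    """Hypothetical helper: evaluate a step schedule, or pass a constant through."""
    if callable(value):
        return float(value(step))
    return float(value)


def linear_warmup(start: float, end: float, num_steps: int) -> Callable[[int], float]:
    """Hypothetical schedule: interpolate linearly from `start` to `end`, then hold."""

    def schedule(step: int) -> float:
        frac = min(max(step, 0), num_steps) / num_steps
        return start + (end - start) * frac

    return schedule


tau = linear_warmup(0.05, 0.1, num_steps=1000)
print(resolve_hparam(0.1, step=500))  # constant: 0.1
print(resolve_hparam(tau, step=500))  # schedule, halfway through warmup
```

A constant is the simplest choice; a schedule lets the strength of the preference signal change over training.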
- get_values(*, tokens: kauldron.ktyping.array_type_meta.Int['*B N L'], sequence_mask: kauldron.ktyping.array_type_meta.Bool['*B N L'], policy_logits: kauldron.ktyping.array_type_meta.Float['*B N L V'], anchor_logits: kauldron.ktyping.array_type_meta.Float['*B N L V'])
Computes the DPO loss. In the shape annotations, *B denotes the batch dimensions, N the number of responses per example, L the sequence length, and V the vocabulary size.
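The NumPy sketch below shows one plausible computation behind `get_values`, under stated assumptions: N = 2 with index 0 the chosen and index 1 the rejected response, `tokens` already aligned with the logits (position l of the logits scores `tokens[..., l]`), and the standard conservative-DPO form of label smoothing. `dpo_values` and its helpers are hypothetical names, NumPy stands in for `jax.numpy`, and the real method may differ (for example in its output shape).

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable log-softmax over the vocabulary axis.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def log_sigmoid(x: np.ndarray) -> np.ndarray:
    # log(sigmoid(x)) = -log(1 + exp(-x)), computed stably.
    return -np.logaddexp(0.0, -x)


def sequence_logprob(logits, tokens, mask):
    """Sum of log p(token) over the masked positions of each sequence.

    logits: [B, N, L, V], tokens: [B, N, L] (int), mask: [B, N, L] (bool).
    Returns: [B, N].
    """
    logp = log_softmax(logits, axis=-1)
    # Pick the log-probability assigned to each target token.
    token_logp = np.take_along_axis(logp, tokens[..., None], axis=-1)[..., 0]
    return np.where(mask, token_logp, 0.0).sum(axis=-1)


def dpo_values(tokens, sequence_mask, policy_logits, anchor_logits,
               tau=0.1, label_smoothing=0.0):
    """Hypothetical per-example DPO loss, assuming N == 2 (chosen, rejected)."""
    policy = sequence_logprob(policy_logits, tokens, sequence_mask)
    anchor = sequence_logprob(anchor_logits, tokens, sequence_mask)
    # Log-ratio of policy vs. anchor for each response: [B, N].
    ratio = policy - anchor
    # Margin between chosen (index 0) and rejected (index 1): [B].
    h = ratio[..., 0] - ratio[..., 1]
    return (-(1.0 - label_smoothing) * log_sigmoid(tau * h)
            - label_smoothing * log_sigmoid(-tau * h))
```

When the policy and anchor logits are identical, the margin is zero and every example's loss equals log 2, which makes a quick sanity check.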
- empty() -> kauldron.metrics.base.Metric.State
- get_state(*args, mask: kauldron.ktyping.array_type_meta.Array['...'] | None = None, step: int | None = None, **kwargs)
Computes the loss state, taking care of masking and loss weighting.
By default, Loss.State is AllReduceMean, which keeps track of a single scalar loss value and averages it correctly even when masks are used.
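Why a sum/count state matters can be seen with two batches that have different numbers of unmasked values. This plain NumPy sketch (not Kauldron code) compares averaging per-batch means with accumulating totals and counts first.

```python
import numpy as np

# Two batches of per-example losses, with masks marking the valid entries.
batches = [
    (np.array([1.0, 2.0, 3.0, 4.0]), np.array([1, 1, 0, 0])),      # 2 valid values
    (np.array([10.0, 20.0, 30.0, 40.0]), np.array([1, 1, 1, 1])),  # 4 valid values
]

# Naive: average each batch over its valid entries, then average the means.
batch_means = [(v * m).sum() / m.sum() for v, m in batches]
naive = float(np.mean(batch_means))

# Sum/count: accumulate masked totals and counts, divide once at the end.
total = sum(float((v * m).sum()) for v, m in batches)
count = sum(float(m.sum()) for v, m in batches)
correct = total / count

print(naive)    # 13.25
print(correct)  # 103 / 6, about 17.17
```

The naive result over-weights the batch with fewer valid entries; dividing once at the end weights every unmasked value equally.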
- Parameters:
*args – Positional arguments to be passed on to get_values.
mask – An optional mask to exclude some of the loss values from the total. The shape of this mask needs to be broadcastable to the shape of the values returned from get_values. A value of 1 includes the corresponding value and 0 excludes it.
step – The current step, used to compute the loss weight if self.weight is set to a schedule. Otherwise, step is ignored.
**kwargs – Keyword arguments to be passed on to get_values.
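The role of `mask` and `step` can be illustrated with a small function that broadcasts the mask, resolves a step-dependent weight, and returns a masked total and count. `masked_weighted_state` is a hypothetical name, and this is a simplified sketch under that assumption: it does not model `normalize_by` or the exact internals of Kauldron's `get_state`.

```python
from typing import Callable, Optional, Union

import numpy as np


def masked_weighted_state(
    values: np.ndarray,
    mask: Optional[np.ndarray] = None,
    weight: Union[float, Callable[[int], float]] = 1.0,
    step: Optional[int] = None,
) -> tuple[float, float]:
    """Hypothetical illustration: returns (weighted masked total, count)."""
    # A missing mask means every value counts.
    if mask is None:
        mask = np.ones_like(values)
    # The mask only needs to be broadcastable to the shape of `values`.
    mask = np.broadcast_to(mask, values.shape).astype(values.dtype)
    # A schedule is evaluated at `step`; a constant weight ignores `step`.
    w = weight(step) if callable(weight) else weight
    total = float((w * values * mask).sum())
    count = float(mask.sum())
    return total, count


values = np.array([[1.0, 2.0], [3.0, 4.0]])  # shape [B, 2]
mask = np.array([[1.0], [0.0]])              # shape [B, 1], broadcast over the last axis
total, count = masked_weighted_state(
    values, mask, weight=lambda s: 0.5 if s < 100 else 1.0, step=10
)
print(total / count)  # 0.5 * (1 + 2) / 2 = 0.75
```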
- Returns:
An instance of Loss.State (AllReduceMean by default), which keeps track of a single scalar loss value and averages it correctly even when masks are used. The final loss value is obtained by calling state.compute(). Optionally, the state can first be reduced (to remove the device dimension after pmap) or merged with other (previous) loss states.
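The lifecycle of a loss state (start empty, reduce over devices, merge with earlier states, then compute) can be illustrated with a small stand-in. `MeanState` is a hypothetical class that tracks a running total and count; it is not Kauldron's `AllReduceMean`, whose actual fields and methods may differ.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class MeanState:
    """Hypothetical stand-in for a mean-tracking loss state (not Kauldron's class)."""

    total: np.ndarray
    count: np.ndarray

    @classmethod
    def empty(cls) -> "MeanState":
        # Identity element for `merge`: nothing accumulated yet.
        return cls(total=np.zeros(()), count=np.zeros(()))

    def merge(self, other: "MeanState") -> "MeanState":
        # Combining states adds totals and counts.
        return MeanState(self.total + other.total, self.count + other.count)

    def reduce(self) -> "MeanState":
        # Collapse a leading device axis (e.g. after pmap) by summation.
        return MeanState(self.total.sum(axis=0), self.count.sum(axis=0))

    def compute(self) -> float:
        # The final scalar is computed once, from the accumulated sums.
        return float(self.total / self.count)


# Per-device states from two steps on two devices (leading axis = device).
step1 = MeanState(total=np.array([3.0, 5.0]), count=np.array([2.0, 2.0])).reduce()
step2 = MeanState(total=np.array([4.0, 6.0]), count=np.array([1.0, 1.0])).reduce()
state = MeanState.empty().merge(step1).merge(step2)
print(state.compute())  # (3 + 5 + 4 + 6) / (2 + 2 + 1 + 1) = 3.0
```

Keeping sums rather than averages is what makes reduce and merge order-independent and exact.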