gm.data.ContrastiveTask
- class gemma.gm.data.ContrastiveTask(*, in_prompt: kontext.Key, in_chosen: kontext.Key, in_rejected: kontext.Key, out_tokens: kontext.Key, out_targets: kontext.Key, out_mask: kontext.Key, tokenizer: gemma.gm.text._tokenizer.Tokenizer, max_length: int, truncate: bool = False, drop_inputs: bool = True)
Bases: grain._src.core.transforms.Map

Creates the contrastive model inputs for a DPO-like (Direct Preference Optimization) loss.
Input:
{
    'prompt': 'How much are 2+2 ?',
    'chosen': 'Yes, this is 4.',
    'rejected': 'Of course, 2+2 is 42.',
}
Output:
{
    'tokens': i32['2 max_length'],
    'mask': bool['2 max_length'],
}
In the output, the chosen and rejected token IDs are stacked along the first axis, in that order: row 0 holds the chosen sequence and row 1 the rejected one.
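The stacking described above can be illustrated with a minimal pure-Python sketch. It is not the Gemma implementation and rests on several assumptions: the inputs are already tokenized, the prompt is prepended to each response, padding uses a hypothetical id 0, and the mask marks only the response tokens. It also leaves out out_targets. Plain lists stand in for the i32/bool arrays.

```python
from typing import Sequence

PAD_ID = 0  # Hypothetical padding id; the real tokenizer defines its own.


def stack_contrastive(
    prompt_ids: Sequence[int],
    chosen_ids: Sequence[int],
    rejected_ids: Sequence[int],
    max_length: int,
) -> dict[str, list]:
    """Builds 2 x max_length token and mask rows: [chosen, rejected]."""
    tokens, mask = [], []
    # Row 0 is the chosen response, row 1 the rejected one.
    for response_ids in (chosen_ids, rejected_ids):
        # Assumption: each row is prompt tokens followed by response tokens.
        seq = list(prompt_ids) + list(response_ids)
        if len(seq) > max_length:
            raise ValueError(f"length {len(seq)} exceeds max_length={max_length}")
        n_pad = max_length - len(seq)
        tokens.append(seq + [PAD_ID] * n_pad)
        # Assumed mask semantics: True on response tokens only.
        mask.append(
            [False] * len(prompt_ids) + [True] * len(response_ids) + [False] * n_pad
        )
    return {"tokens": tokens, "mask": mask}


out = stack_contrastive([2, 10, 11], [20, 21], [30, 31, 32], max_length=8)
print(out["tokens"][0])  # [2, 10, 11, 20, 21, 0, 0, 0]
```

Keeping both responses in one 2-row batch lets a single forward pass score the chosen and rejected sequences side by side, which is what a DPO-style loss compares.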
- in_prompt
Input key of the prompt text.
- Type:
Any
- in_chosen
Input key of the chosen (preferred) response.
- Type:
Any
- in_rejected
Input key of the rejected response.
- Type:
Any
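- out_tokens
Output key for the stacked token IDs (will be added to the example dict).
- Type:
Any
- out_targets
Output key for the targets (will be added to the example dict).
- Type:
Any
- out_mask
Output key for the mask (will be added to the example dict).
- Type:
Any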
- tokenizer
The tokenizer to use.
- Type:
gemma.gm.text._tokenizer.Tokenizer
- max_length
The maximum sequence length. Shorter examples are padded to this length; longer ones are truncated when truncate is True.
- Type:
int
- truncate
Whether to truncate sequences to max_length. If False, sequences longer than max_length raise an error.
- Type:
bool
- drop_inputs
If True, remove the input keys from the output.
- Type:
bool
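The padding and truncation policy described by max_length and truncate, and the effect of drop_inputs, can be sketched as follows. This is a hypothetical illustration, not the Gemma code: the helper names fit_to_length and finalize_example are made up, pad id 0 is assumed, and keeping the first max_length tokens when truncating is an assumption (the truncation side is not documented here).

```python
from typing import Any, Sequence


def fit_to_length(
    ids: Sequence[int], max_length: int, *, truncate: bool = False, pad_id: int = 0
) -> list[int]:
    """Pads ids to max_length; truncates or raises if the sequence is too long."""
    ids = list(ids)
    if len(ids) > max_length:
        if not truncate:
            raise ValueError(f"sequence length {len(ids)} > max_length={max_length}")
        ids = ids[:max_length]  # Assumption: keep the leading tokens.
    return ids + [pad_id] * (max_length - len(ids))


def finalize_example(
    example: dict[str, Any],
    input_keys: Sequence[str],
    outputs: dict[str, Any],
    *,
    drop_inputs: bool = True,
) -> dict[str, Any]:
    """Adds the output keys and optionally removes the input keys (without mutating example)."""
    result = dict(example)
    if drop_inputs:
        for key in input_keys:
            result.pop(key, None)
    result.update(outputs)
    return result


print(fit_to_length([5, 6, 7], 5))  # [5, 6, 7, 0, 0]
```

With truncate=False (the default), an over-long example fails loudly instead of silently losing tokens; with drop_inputs=True, only the non-input fields of the example survive alongside the new output keys.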
- in_prompt: kontext.Key
- in_chosen: kontext.Key
- in_rejected: kontext.Key
- out_tokens: kontext.Key
- out_targets: kontext.Key
- out_mask: kontext.Key
- tokenizer: gemma.gm.text._tokenizer.Tokenizer
- max_length: int
- truncate: bool = False
- drop_inputs: bool = True
- map(element)
Maps a single element (an example dict) to the contrastive inputs described above.