gm.data.ContrastiveTask

class gemma.gm.data.ContrastiveTask(
    *,
    in_prompt: kauldron.kontext.Key,
    in_chosen: kauldron.kontext.Key,
    in_rejected: kauldron.kontext.Key,
    out_tokens: kauldron.kontext.Key,
    out_targets: kauldron.kontext.Key,
    out_mask: kauldron.kontext.Key,
    tokenizer: gemma.gm.text._tokenizer.Tokenizer,
    max_length: int,
    truncate: bool = False,
    drop_inputs: bool = True,
)

Bases: grain._src.core.transforms.Map

Creates the contrastive model inputs for DPO-like loss.

Input:

{
    'prompt': 'How much are 2+2 ?',
    'chosen': 'Yes, this is 4.',
    'rejected': 'Of course, 2+2 is 42.',
}

Output:

{
    'tokens': i32['2 max_length'],
    'mask': bool['2 max_length'],
}

In the output, the chosen and rejected token ids are stacked along the leading axis, in that order: row 0 holds the chosen sequence and row 1 the rejected one.
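
To make the stacking concrete, here is a minimal sketch in plain Python. It is not the library's implementation: `toy_encode`, `PAD_ID`, and `make_contrastive_fields` are hypothetical names, the whitespace tokenizer only stands in for a real `Tokenizer`, and treating `mask` as True on response tokens only is an assumption that this page does not confirm.

```python
PAD_ID = 0  # Hypothetical padding id; a real tokenizer defines its own.

_VOCAB: dict[str, int] = {}


def toy_encode(text: str) -> list[int]:
    """Hypothetical whitespace tokenizer: each new word gets the next id (from 1)."""
    return [_VOCAB.setdefault(word, len(_VOCAB) + 1) for word in text.split()]


def make_contrastive_fields(example: dict[str, str], max_length: int) -> dict[str, list]:
    """Builds the [chosen, rejected] rows, each padded to max_length."""
    prompt = toy_encode(example["prompt"])
    tokens, mask = [], []
    for key in ("chosen", "rejected"):  # Order matters: chosen is row 0.
        response = toy_encode(example[key])
        row = prompt + response
        pad = max_length - len(row)
        if pad < 0:
            raise ValueError(f"Sequence of length {len(row)} exceeds {max_length}.")
        tokens.append(row + [PAD_ID] * pad)
        # Assumption: the mask is True on response tokens only.
        mask.append([False] * len(prompt) + [True] * len(response) + [False] * pad)
    return {"tokens": tokens, "mask": mask}


fields = make_contrastive_fields(
    {
        "prompt": "How much are 2+2 ?",
        "chosen": "Yes, this is 4.",
        "rejected": "Of course, 2+2 is 42.",
    },
    max_length=12,
)
```

Both rows share the same prompt prefix and have the same length, which is what allows them to be stacked into a single `2 max_length` array.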

in_prompt

Key of the prompt text in the input example dict (e.g. 'prompt').

Type:

Any

in_chosen

Key of the chosen (preferred) response in the input example dict (e.g. 'chosen').

Type:

Any

in_rejected

Key of the rejected response in the input example dict (e.g. 'rejected').

Type:

Any

out_tokens

Output key for the stacked token ids (will be added to the example dict).

Type:

Any

out_targets

Output key for the targets (will be added to the example dict).

Type:

Any

out_mask

Output key for the mask (will be added to the example dict).

Type:

Any

tokenizer

The tokenizer to use.

Type:

gemma.gm.text._tokenizer.Tokenizer

max_length

The maximum sequence length. Shorter examples are padded to this length; longer examples are truncated to it only when truncate is True.

Type:

int

truncate

Whether to truncate sequences to max_length. If False, sequences longer than max_length raise an error.

Type:

bool
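
The interaction between max_length and truncate can be sketched as follows. This is an illustration under stated assumptions, not the library's code: `pad_or_truncate` and `pad_id` are hypothetical names, the exception type is a guess (the page only says "an error"), and truncating from the end of the sequence is also an assumption.

```python
def pad_or_truncate(
    ids: list[int],
    max_length: int,
    *,
    truncate: bool = False,
    pad_id: int = 0,  # Hypothetical padding id.
) -> list[int]:
    """Returns a copy of ids with exactly max_length entries."""
    if len(ids) > max_length:
        if not truncate:
            # Assumption: the real exception type is not stated in the docs.
            raise ValueError(
                f"Sequence of length {len(ids)} exceeds max_length={max_length}."
            )
        # Assumption: extra tokens are dropped from the end.
        return ids[:max_length]
    # Short sequences are right-padded up to max_length.
    return ids + [pad_id] * (max_length - len(ids))


padded = pad_or_truncate([1, 2, 3], 5)
clipped = pad_or_truncate([1, 2, 3, 4, 5, 6], 4, truncate=True)
```

With the default truncate=False, an over-long example fails loudly instead of being silently clipped, which can help catch a max_length that is too small for the dataset.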

drop_inputs

If True, remove the input keys from the output.

Type:

bool
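
The effect of drop_inputs on the example dict can be sketched in a few lines. `finalize_example` is a hypothetical helper used only for illustration, and keeping keys that are neither inputs nor outputs is an assumption.

```python
def finalize_example(
    example: dict,
    outputs: dict,
    *,
    input_keys: tuple[str, ...],
    drop_inputs: bool = True,
) -> dict:
    """Merges outputs into the example, optionally removing the input keys."""
    merged = {**example, **outputs}
    if drop_inputs:
        for key in input_keys:
            merged.pop(key, None)
    return merged


example = {"prompt": "p", "chosen": "c", "rejected": "r", "id": 7}
outputs = {"tokens": [[1, 2], [1, 3]], "mask": [[False, True], [False, True]]}
keys = ("prompt", "chosen", "rejected")

dropped = finalize_example(example, outputs, input_keys=keys)
kept = finalize_example(example, outputs, input_keys=keys, drop_inputs=False)
```

Dropping the raw strings by default leaves only array-like fields in the example, which is usually what a batching step expects.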

map(element)

Maps a single element.