gm.data.Seq2SeqTask


class gemma.gm.data.Seq2SeqTask(*, in_prompt: kauldron.kontext.Key, in_response: kauldron.kontext.Key, out_input: kauldron.kontext.Key, out_target: kauldron.kontext.Key, out_target_mask: kauldron.kontext.Key, drop_inputs: bool = True, tokenizer: gemma.gm.text._tokenizer.Tokenizer, max_length: int, truncate: bool = False, sampling: bool = False)

Bases: grain._src.core.transforms.Map

Sequence-to-sequence task.

This task will:

  • Format the prompt and response to match the expected dialog template (e.g. add <start_of_turn>user, <end_of_turn>, …).

  • Tokenize the prompt and response.

  • Concatenate the prompt and response to create the model input and target (the target is the input shifted by one token).

  • Create the loss mask (0 for prompt tokens, 1 for response tokens).

  • Pad or truncate the input, target, and loss mask to max_length.
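The formatting, tokenization, concatenation, and masking steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation: the turn template below is an assumed Gemma-style layout, `toy_encode` is a hypothetical character-level tokenizer standing in for the real one (which may also add special tokens such as BOS), and `tokenize_pair` is a hypothetical helper name.

```python
from typing import List, Tuple

# Assumed Gemma-style turn markers; the library's exact template may differ.
PROMPT_TEMPLATE = "<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
RESPONSE_TEMPLATE = "{}<end_of_turn>"


def toy_encode(text: str) -> List[int]:
    """Hypothetical character-level tokenizer: one id per character."""
    return [ord(c) for c in text]


def tokenize_pair(prompt: str, response: str) -> Tuple[List[int], List[int]]:
    """Format both turns, tokenize them, and build a per-token loss mask."""
    prompt_ids = toy_encode(PROMPT_TEMPLATE.format(prompt))
    response_ids = toy_encode(RESPONSE_TEMPLATE.format(response))
    # Concatenate the prompt and response into a single token sequence.
    tokens = prompt_ids + response_ids
    # Loss mask per token: 0 for prompt tokens, 1 for response tokens.
    mask = [0] * len(prompt_ids) + [1] * len(response_ids)
    return tokens, mask


tokens, mask = tokenize_pair("Hello!", "Bonjour !")
```

Only the response tokens (including the closing <end_of_turn> marker in this assumed template) are marked with 1, so the loss ignores everything the user wrote.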

Example:

# Input:
{
    'prompt': 'Hello! I would love to visit France.',
    'response': "Bonjour ! J'adorerais visiter la France.",
}
# Output:
{
    'input': i32['max_length'],
    'target': i32['max_length 1'],  # shape (max_length, 1)
    'target_mask': bool['max_length 1'],  # shape (max_length, 1)
}

Note

  • Input and target are the same sequence shifted by one token.

  • The last token of the concatenated sequence is dropped from the input, since there is no following token for it to predict.
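The one-token shift described in this note can be sketched as follows, assuming the token ids and the per-token 0/1 loss mask are plain Python lists; `shift_by_one` is a hypothetical helper, not part of gm.data.

```python
from typing import List, Tuple


def shift_by_one(
    tokens: List[int], mask: List[int]
) -> Tuple[List[int], List[int], List[int]]:
    """Build (input, target, target_mask) for next-token prediction."""
    # The last token is dropped from the input: no token follows it to predict.
    inputs = tokens[:-1]
    # The first token is dropped from the target: no position predicts it.
    targets = tokens[1:]
    # The mask describes target tokens, so it is shifted together with them.
    target_mask = mask[1:]
    return inputs, targets, target_mask


# Prompt tokens 5, 6, 7 followed by response tokens 8, 9.
inputs, targets, target_mask = shift_by_one([5, 6, 7, 8, 9], [0, 0, 0, 1, 1])
# inputs      == [5, 6, 7, 8]
# targets     == [6, 7, 8, 9]
# target_mask == [0, 0, 1, 1]
```

The position holding the last prompt token (7) is trained to predict the first response token (8), which is why the mask is shifted along with the target rather than with the input.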

in_prompt

Input key of the prompt text.

Type:

Any

in_response

Input key of the response text.

Type:

Any

out_input

Output key for the model input tokens (added to the example dict).

Type:

Any

out_target

Output key for the target tokens (added to the example dict).

Type:

Any

out_target_mask

Output key for the loss mask over the target (added to the example dict).

Type:

Any

drop_inputs

If True, remove the input keys (in_prompt and in_response) from the output.

Type:

bool
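The in_* and out_* fields name dictionary entries rather than holding data, and drop_inputs decides whether the raw fields survive. A minimal sketch of that key handling, assuming flat string keys, a whitespace split in place of real tokenization, no padding, and the hypothetical helper `route_keys`:

```python
from typing import Any, Dict


def route_keys(
    element: Dict[str, Any],
    *,
    in_prompt: str,
    in_response: str,
    out_input: str,
    out_target: str,
    out_target_mask: str,
    drop_inputs: bool = True,
) -> Dict[str, Any]:
    """Hypothetical sketch: read the in_* keys, write the out_* keys."""
    prompt = element[in_prompt]
    response = element[in_response]
    # Stand-in tokenization: one "token" per whitespace-separated word.
    prompt_tokens = prompt.split()
    tokens = prompt_tokens + response.split()

    out = dict(element)  # Copy so the caller's dict is left untouched.
    if drop_inputs:
        # Remove the raw prompt/response fields from the output example.
        del out[in_prompt]
        del out[in_response]

    # Write the shifted sequences and the mask under the configured keys.
    out[out_input] = tokens[:-1]
    out[out_target] = tokens[1:]
    # Target i is tokens[i + 1]; it is in the loss if it is a response token.
    out[out_target_mask] = [
        i + 1 >= len(prompt_tokens) for i in range(len(tokens) - 1)
    ]
    return out


example = {"src": "Hello there", "dst": "Bonjour", "id": 7}
out = route_keys(
    example,
    in_prompt="src",
    in_response="dst",
    out_input="input",
    out_target="target",
    out_target_mask="loss_mask",
)
# out == {"id": 7, "input": ["Hello", "there"],
#         "target": ["there", "Bonjour"], "loss_mask": [False, True]}
```

Unrelated fields such as "id" pass through unchanged; with drop_inputs=False the "src" and "dst" entries would be kept as well.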

max_length

The maximum sequence length; examples are padded or truncated to this length.

Type:

int

truncate

Whether to truncate sequences to max_length. If False, sequences longer than max_length raise an error.

Type:

bool
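The interaction between max_length and truncate can be sketched as one helper applied to each of the input, target, and mask sequences. This is a hedged illustration: `pad_or_truncate` and the pad id 0 are hypothetical, and it assumes right padding and keeping the first max_length tokens when truncating, which may not match the library's choices.

```python
from typing import List

PAD_ID = 0  # Hypothetical pad id; also reads as "excluded" in a 0/1 mask.


def pad_or_truncate(tokens: List[int], max_length: int, truncate: bool) -> List[int]:
    """Hypothetical sketch: fit a sequence to exactly max_length entries."""
    if len(tokens) > max_length:
        if not truncate:
            # Mirrors truncate=False: over-long sequences are an error.
            raise ValueError(
                f"Sequence length {len(tokens)} exceeds max_length={max_length}"
            )
        # Assumed right truncation: keep the first max_length tokens.
        return tokens[:max_length]
    # Right-pad shorter sequences up to max_length.
    return tokens + [PAD_ID] * (max_length - len(tokens))


padded = pad_or_truncate([1, 2, 3], max_length=5, truncate=False)
# padded == [1, 2, 3, 0, 0]
```

Padding the loss mask with 0 keeps padded positions out of the loss, which is why a single pad value works for all three sequences in this sketch.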

sampling

If True, the dataset will yield the original prompt and response so they can be used inside gm.evals.SamplerEvaluator.

Type:

bool

in_prompt: kauldron.kontext.Key
in_response: kauldron.kontext.Key
out_input: kauldron.kontext.Key
out_target: kauldron.kontext.Key
out_target_mask: kauldron.kontext.Key
drop_inputs: bool = True
tokenizer: gemma.gm.text._tokenizer.Tokenizer
max_length: int
truncate: bool = False
sampling: bool = False
map(element)

Maps a single element.