gm.data.make_seq2seq_fields


gemma.gm.data.make_seq2seq_fields(
    prompt: kauldron.ktyping.array_type_meta.Int['prompt_len'],
    response: kauldron.ktyping.array_type_meta.Int['response_len'],
) -> gemma.gm.data._functional.Seq2SeqFields

Create the model input, target and loss_mask.

Given prompt and response token ids, build the model input, the next-token target, and a loss_mask that is 1 only where the target is a response token.

Example:

# Input:
prompt = [10, 11, 12, 13]
response = [20, 21, 1]  # Here, the response ends with the EOS token.

# Output:
out.input =       [10, 11, 12, 13, 20, 21]
out.target =      [11, 12, 13, 20, 21,  1]
out.target_mask = [ 0,  0,  0,  1,  1,  1]

Note

  • Input and target are the same sequence shifted by one token.

  • The last token of the full sequence is dropped from the input, since there is no following token to serve as its target. Likewise, the first prompt token is dropped from the target.
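The shift-by-one construction can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the gemma implementation: make_seq2seq_fields_sketch and Seq2SeqFieldsSketch are hypothetical names, and the integer 0/1 mask mirrors the example rather than whatever dtype the real Seq2SeqFields uses.

```python
from typing import NamedTuple

import numpy as np


class Seq2SeqFieldsSketch(NamedTuple):
    """Hypothetical stand-in for gemma.gm.data._functional.Seq2SeqFields."""
    input: np.ndarray
    target: np.ndarray
    target_mask: np.ndarray


def make_seq2seq_fields_sketch(prompt, response) -> Seq2SeqFieldsSketch:
    """Hypothetical sketch: build shifted input/target and a response-only mask."""
    prompt = np.asarray(prompt, dtype=np.int32)
    response = np.asarray(response, dtype=np.int32)
    # Full token sequence: prompt followed by response.
    tokens = np.concatenate([prompt, response])
    # 0 on prompt tokens, 1 on response tokens (integer mask, as in the example;
    # the real library may use a boolean dtype instead).
    mask = np.concatenate([
        np.zeros_like(prompt, dtype=np.int32),
        np.ones_like(response, dtype=np.int32),
    ])
    # Input drops the last token (nothing follows it to predict);
    # target drops the first token, so target[i] is the token after input[i].
    return Seq2SeqFieldsSketch(
        input=tokens[:-1],
        target=tokens[1:],
        target_mask=mask[1:],
    )


out = make_seq2seq_fields_sketch([10, 11, 12, 13], [20, 21, 1])
print(out.input.tolist())        # [10, 11, 12, 13, 20, 21]
print(out.target.tolist())       # [11, 12, 13, 20, 21, 1]
print(out.target_mask.tolist())  # [0, 0, 0, 1, 1, 1]
```

Building the mask over the full sequence and then slicing it exactly like the target keeps the two aligned, so each mask entry describes the token being predicted at that position.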

Parameters:
  • prompt – The prompt tokens.

  • response – The response tokens.

Returns:

The input, target and mask, all of length prompt_len + response_len - 1.
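A mask like this is typically applied when averaging a per-token loss, so that prompt positions do not contribute. The sketch below assumes logits of shape [seq_len, vocab_size] aligned with the input; masked_cross_entropy is a hypothetical helper, and gemma's own training loss may compute this differently.

```python
import numpy as np


def masked_cross_entropy(logits, target, target_mask):
    """Hypothetical helper: mean next-token cross-entropy where target_mask is 1."""
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Log-probability the model assigns to each target token.
    token_ll = np.take_along_axis(log_probs, target[:, None], axis=-1)[:, 0]
    mask = target_mask.astype(log_probs.dtype)
    # Prompt positions (mask 0) contribute nothing; guard against an all-zero mask.
    return -(token_ll * mask).sum() / np.maximum(mask.sum(), 1.0)


# Usage with the target/mask from the example: uniform logits over a
# 32-token vocabulary give a loss of log(32) on the 3 response positions.
target = np.array([11, 12, 13, 20, 21, 1])
target_mask = np.array([0, 0, 0, 1, 1, 1])
logits = np.zeros((6, 32))
loss = masked_cross_entropy(logits, target, target_mask)
```

Because the mask multiplies the per-token log-likelihood, changing the logits at a prompt position leaves the loss unchanged; only predictions of response tokens (including the final EOS) are trained.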