gm.data.make_seq2seq_fields
- gemma.gm.data.make_seq2seq_fields(
- prompt: kauldron.ktyping.array_type_meta.Int['prompt_len'],
- response: kauldron.ktyping.array_type_meta.Int['response_len'],
Create the model input, target and loss_mask.
From prompt and response token ids, generate the model input, target and loss_mask.
Example:
# Input:
prompt = [10, 11, 12, 13],
response = [20, 21, 1],  # Here, response ends with EOS token.

# Output:
out.input =       [10, 11, 12, 13, 20, 21],
out.target =      [11, 12, 13, 20, 21,  1],
out.target_mask = [ 0,  0,  0,  1,  1,  1],
Note
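Input and target are the same sequence (prompt followed by response), shifted by one token relative to each other.
The last token of the sequence is dropped from the input, since there is no target for it. Likewise, the first token is dropped from the target.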
- Parameters:
prompt – The prompt tokens.
response – The response tokens.
- Returns:
The input, target and mask, all of length prompt_len + response_len - 1.
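
The behavior described above can be sketched in plain Python. This is an illustration only, not the library's implementation: the names `make_seq2seq_fields_sketch` and `Seq2SeqFieldsSketch` are hypothetical, plain lists stand in for the integer arrays, and a non-empty prompt is assumed.

```python
from typing import NamedTuple


class Seq2SeqFieldsSketch(NamedTuple):
    """Hypothetical container; field names follow the documented example."""

    input: list[int]
    target: list[int]
    target_mask: list[int]


def make_seq2seq_fields_sketch(
    prompt: list[int], response: list[int]
) -> Seq2SeqFieldsSketch:
    """Builds input, target and mask from prompt and response token ids."""
    if not prompt:
        # Assumption of this sketch: the prompt holds at least one token.
        raise ValueError("prompt must not be empty")

    # Full sequence: prompt followed by response.
    sequence = list(prompt) + list(response)

    # target[i] is sequence[i + 1]. It is a response token once
    # i + 1 >= len(prompt), so the first len(prompt) - 1 positions are masked.
    target_mask = [0] * (len(prompt) - 1) + [1] * len(response)

    return Seq2SeqFieldsSketch(
        input=sequence[:-1],   # drop the last token (it has no target)
        target=sequence[1:],   # same sequence, shifted by one token
        target_mask=target_mask,
    )


out = make_seq2seq_fields_sketch(prompt=[10, 11, 12, 13], response=[20, 21, 1])
print(out.input)        # [10, 11, 12, 13, 20, 21]
print(out.target)       # [11, 12, 13, 20, 21, 1]
print(out.target_mask)  # [0, 0, 0, 1, 1, 1]
```

All three fields have length `prompt_len + response_len - 1` (here 4 + 3 - 1 = 6), and the mask is set only where the target is a response token, so a loss weighted by it ignores the prompt positions.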