gm.math.apply_rope
gemma.gm.math.apply_rope(
    inputs: jax.Array,
    positions: jax.Array,
    *,
    base_frequency: int,
    scale_factor: float = 1.0,
    rope_proportion: float = 1.0,
)
Applies rotary position embeddings (RoPE) to the inputs.
Let B denote batch size, L denote sequence length, N denote number of heads, and H denote head dimension. Note that H must be divisible by 2.
- Parameters:
inputs – Array of shape [B, L, N, H].
positions – Array of shape [B, L].
base_frequency – Base frequency used to compute rotations.
scale_factor – The scale factor used for positional interpolation, which extends the usable sequence length beyond the pre-trained context length.
rope_proportion – The proportion of the head dimension to apply RoPE to.
- Returns:
Array of shape [B, L, N, H].
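To make the parameters above concrete, here is a minimal JAX sketch of one common way to compute RoPE with the same argument names. It is an illustration under stated assumptions, not the library's implementation: the function name `rope_sketch` is hypothetical, the head dimension is assumed to be split into two halves that are rotated against each other, `scale_factor` is assumed to divide the positions, and `rope_proportion` is assumed to select how many of the H/2 frequency pairs are rotated (the rest are left unchanged).

```python
import jax.numpy as jnp


def rope_sketch(inputs, positions, *, base_frequency,
                scale_factor=1.0, rope_proportion=1.0):
    """Illustrative RoPE (hypothetical helper, not gemma's implementation).

    inputs:    [B, L, N, H], with H divisible by 2.
    positions: [B, L].
    """
    head_dim = inputs.shape[-1]
    if head_dim % 2 != 0:
        raise ValueError("H must be divisible by 2.")
    half = head_dim // 2

    # Assumption: rope_proportion picks how many of the H/2 pairs get rotated.
    num_rotated = int(rope_proportion * half)

    # One frequency per pair: base_frequency ** (-2i / H) for i in [0, H/2).
    exponents = (2.0 / head_dim) * jnp.arange(half, dtype=jnp.float32)
    inv_timescale = 1.0 / (base_frequency ** exponents)
    # Pairs beyond num_rotated get frequency 0, i.e. a rotation by angle 0.
    inv_timescale = jnp.where(jnp.arange(half) < num_rotated, inv_timescale, 0.0)

    # Assumption: positional interpolation divides positions by scale_factor.
    scaled_pos = positions.astype(jnp.float32) / scale_factor
    angles = scaled_pos[..., None] * inv_timescale      # [B, L, H/2]
    angles = angles[:, :, None, :]                      # [B, L, 1, H/2]
    sin, cos = jnp.sin(angles), jnp.cos(angles)

    # Rotate each (first[i], second[i]) pair by its angle.
    first, second = jnp.split(inputs, 2, axis=-1)
    rotated = jnp.concatenate(
        [first * cos - second * sin, second * cos + first * sin], axis=-1
    )
    return rotated.astype(inputs.dtype)


# Usage: B=2, L=5, N=4, H=8.
x = jnp.ones((2, 5, 4, 8))
pos = jnp.broadcast_to(jnp.arange(5), (2, 5))
y = rope_sketch(x, pos, base_frequency=10_000)
```

Under these assumptions the output keeps the shape [B, L, N, H], tokens at position 0 pass through unchanged, and each head vector keeps its norm because every pair is only rotated.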