gm.math.apply_rope


gemma.gm.math.apply_rope(
inputs: jax.Array,
positions: jax.Array,
*,
base_frequency: int,
scale_factor: float = 1.0,
rope_proportion: float = 1.0,
) -> jax.Array

Applies rotary position embeddings (RoPE) to the inputs.

Let B denote batch size, L denote sequence length, N denote number of heads, and H denote head dimension. Note that H must be divisible by 2.

Parameters:
  • inputs – Array of shape [B, L, N, H].

  • positions – Array of shape [B, L].

  • base_frequency – Base frequency used to compute rotations.

  • scale_factor – The scale factor used for positional interpolation, allowing an expansion of sequence length beyond the pre-trained context length.

  • rope_proportion – The proportion of the head dimension that RoPE is applied to. The default of 1.0 covers the full head dimension.

Returns:

Array of shape [B, L, N, H].
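
The shape contract and the role of base_frequency and scale_factor can be illustrated with a minimal JAX sketch. This is not the library's source: rope_sketch is a hypothetical helper, the split-half pairing of head-dimension entries and the division of positions by scale_factor are assumptions, and rope_proportion is not modeled (the full head dimension is rotated).

```python
import jax.numpy as jnp


def rope_sketch(inputs, positions, *, base_frequency, scale_factor=1.0):
    """Illustrative RoPE over the full head dimension (hypothetical helper)."""
    head_dim = inputs.shape[-1]
    if head_dim % 2 != 0:
        raise ValueError("H must be divisible by 2.")
    # One timescale per rotated pair: base_frequency ** (2i / H), i in [0, H/2).
    exponents = 2.0 * jnp.arange(head_dim // 2) / head_dim
    timescale = base_frequency ** exponents  # [H/2]
    # Assumed form of positional interpolation: shrink positions by scale_factor.
    angles = positions[..., None] / scale_factor / timescale  # [B, L, H/2]
    angles = angles[:, :, None, :]  # [B, L, 1, H/2], broadcasts over heads
    sin, cos = jnp.sin(angles), jnp.cos(angles)
    # Assumed pairing: entry i of the first half rotates with entry i of the second.
    first, second = jnp.split(inputs, 2, axis=-1)
    rotated = jnp.concatenate(
        [first * cos - second * sin, second * cos + first * sin], axis=-1
    )
    return rotated.astype(inputs.dtype)


x = jnp.ones((2, 3, 4, 8), dtype=jnp.float32)  # [B, L, N, H]
pos = jnp.tile(jnp.arange(3), (2, 1))  # [B, L]
y = rope_sketch(x, pos, base_frequency=10_000)
```

Under these assumptions the output keeps the input shape [B, L, N, H], position 0 leaves its vectors unchanged, and each head vector keeps its norm because every pair undergoes a pure rotation. A scale_factor of 2.0 makes position 2p rotate by the same angles as position p does at the default of 1.0.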