peft.ModuleInterceptor

class gemma.peft.ModuleInterceptor(
replace_module_fn: Callable[[flax.linen.module.Module], flax.linen.module.Module],
)

Bases: gemma.peft._interceptors.Interceptor

Interceptor that captures all modules so that they can be replaced.

For each module, this interceptor calls the replace_module_fn function, which returns the module to use instead.

Example:

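from flax import linen as nn
from gemma import peft

def _replace_dense_by_lora(module):
  # Wrap every nn.Dense in a rank-3 LoRA adapter; keep other modules as-is.
  if isinstance(module, nn.Dense):
    return peft.LoRADense(rank=3, wrapped=module)
  else:
    return module

# Within the context, the dense layers are replaced by their LoRA version.
# (model, params and x are assumed to be defined elsewhere.)
with peft.ModuleInterceptor(_replace_dense_by_lora):
  y = model.apply({'params': params}, x)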

replace_module_fn: Callable[[flax.linen.module.Module], flax.linen.module.Module]

interceptor(
next_fun,
args,
kwargs,
context: flax.linen.module.InterceptorContext,
)

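Intercepts a call to a module method, following the interceptor signature used by flax.linen.intercept_methods: it receives the next function in the interceptor chain (next_fun), the call's positional and keyword arguments (args, kwargs), and the InterceptorContext, and returns the result of the call.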
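The (next_fun, args, kwargs, context) signature can be illustrated with a simplified, Flax-free sketch. InterceptorContext here is a hypothetical stand-in carrying only the module and method_name fields of the Flax class, and make_interceptor, call_intercepted, Dense and LoRADense are toy names. The sketch assumes a replacing interceptor forwards the call to the same-named method of the replacement module; gemma's actual implementation may differ.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class InterceptorContext:
  # Simplified stand-in for flax.linen.module.InterceptorContext.
  module: Any
  method_name: str


class Dense:
  # Toy layer (hypothetical): multiplies its input by a fixed scale.
  def __init__(self, scale: float):
    self.scale = scale

  def __call__(self, x: float) -> float:
    return self.scale * x


class LoRADense:
  # Toy wrapper (hypothetical): wrapped output plus a fixed adapter term.
  def __init__(self, wrapped: Dense, delta: float):
    self.wrapped = wrapped
    self.delta = delta

  def __call__(self, x: float) -> float:
    return self.wrapped(x) + self.delta * x


def make_interceptor(replace_module_fn: Callable[[Any], Any]):
  """Builds a function with the (next_fun, args, kwargs, context) signature."""

  def interceptor(next_fun, args, kwargs, context):
    new_module = replace_module_fn(context.module)
    if new_module is context.module:
      # Nothing to replace: continue down the chain with the original call.
      return next_fun(*args, **kwargs)
    # Otherwise run the same method on the replacement module.
    return getattr(new_module, context.method_name)(*args, **kwargs)

  return interceptor


def call_intercepted(interceptor, module, method_name, *args, **kwargs):
  # Simulates the framework routing one method call through an interceptor.
  context = InterceptorContext(module=module, method_name=method_name)
  next_fun = getattr(module, method_name)
  return interceptor(next_fun, args, kwargs, context)


def replace_dense(module):
  if isinstance(module, Dense):
    return LoRADense(wrapped=module, delta=0.5)
  return module


interceptor = make_interceptor(replace_dense)
layer = Dense(scale=2.0)
print(call_intercepted(interceptor, layer, "__call__", 3.0))  # 7.5
```

Delegating to next_fun whenever replace_module_fn returns the module unchanged leaves the original call path untouched for every module that is not replaced.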