peft.ModuleInterceptor
- class gemma.peft.ModuleInterceptor(replace_module_fn: Callable[[flax.linen.module.Module], flax.linen.module.Module])
Bases: gemma.peft._interceptors.Interceptor
Interceptor that captures all modules in order to replace them.
For each module, this interceptor calls the replace_module_fn function, which returns the module to use instead.
Example:
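    from flax import linen as nn
    from gemma import peft

    def _replace_dense_by_lora(module):
        # Wrap every Dense layer in its LoRA version; leave other modules as-is.
        if isinstance(module, nn.Dense):
            return peft.LoRADense(rank=3, wrapped=module)
        else:
            return module

    # Within the context, the dense layers are replaced by their LoRA version.
    # `model` and `x` are assumed to be defined elsewhere.
    with peft.ModuleInterceptor(_replace_dense_by_lora):
        y = model(x)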
- replace_module_fn: Callable[[flax.linen.module.Module], flax.linen.module.Module]
- interceptor(next_fun, args, kwargs, context: flax.linen.module.InterceptorContext)
Returns the names of the methods to intercept.
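The replacement contract described above can be illustrated with a plain-Python analogy. This is a minimal sketch, not gemma's or Flax's implementation: `Dense`, `LoRADense`, `Context`, `ToyModuleInterceptor`, and `call_with_interceptor` are all hypothetical stand-ins defined here. Only the shape of the `interceptor(next_fun, args, kwargs, context)` signature and the role of `replace_module_fn` are taken from the reference above.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Dense:
    """Hypothetical stand-in for a dense layer (not flax.linen.Dense)."""
    scale: float

    def __call__(self, x: float) -> float:
        return self.scale * x


@dataclass
class LoRADense:
    """Hypothetical stand-in for a LoRA wrapper around a Dense layer."""
    wrapped: Dense
    delta: float = 0.5  # toy "low-rank update", reduced to a scalar here

    def __call__(self, x: float) -> float:
        return self.wrapped(x) + self.delta * x


@dataclass
class Context:
    """Toy analogue of an interceptor context (assumed fields)."""
    module: Any
    method_name: str


class ToyModuleInterceptor:
    """Calls replace_module_fn on each intercepted module."""

    def __init__(self, replace_module_fn: Callable[[Any], Any]):
        self.replace_module_fn = replace_module_fn

    def interceptor(self, next_fun, args, kwargs, context: Context):
        new_module = self.replace_module_fn(context.module)
        if new_module is context.module:
            # No replacement requested: run the original method.
            return next_fun(*args, **kwargs)
        # Run the same method on the replacement module instead.
        return getattr(new_module, context.method_name)(*args, **kwargs)


def call_with_interceptor(icpt, module, *args, **kwargs):
    """Toy dispatcher that routes module(*args) through the interceptor."""
    context = Context(module=module, method_name="__call__")
    return icpt.interceptor(module.__call__, args, kwargs, context)


def replace_dense_by_lora(module):
    if isinstance(module, Dense):
        return LoRADense(wrapped=module)
    return module


icpt = ToyModuleInterceptor(replace_dense_by_lora)
dense = Dense(scale=2.0)
plain = dense(3.0)                                 # original module
adapted = call_with_interceptor(icpt, dense, 3.0)  # replaced module
```

In this toy version, the original `dense` object is never modified. The replacement only takes effect for calls routed through the interceptor, which mirrors how the `with` block in the example above scopes the substitution.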