Tool Use


This demo shows how to use tools with the Gemma library.

Note: The Gemma 1, 2, and 3 models were not specifically trained for tool use, so this is more of a proof of concept than an officially supported feature.

!pip install -q gemma
# Common imports
import os
import datetime

# Gemma imports
from gemma import gm

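By default, JAX preallocates 75% of the total GPU memory. This can be overridden with the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable, which must be set before the first JAX operation runs. See GPU memory allocation: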

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

Load the model and the params.

model = gm.nn.Gemma3_4B()

params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)

Using existing tools

If you're familiar with the sampling tutorial, tool use differs from it in two ways:

  1. Use gm.text.ToolSampler instead of gm.text.ChatSampler.

  2. Pass the tools you want to use to the sampler through the tools= argument.

For example:

sampler = gm.text.ToolSampler(
    model=model,
    params=params,
    tools=[
        gm.tools.Calculator(),
        gm.tools.FileExplorer(),
    ],
    print_stream=True,
)

output = sampler.chat('I have a serie `Sn+1 = cos(Sn) * 2`. Using the calculator, compute the steps 0-4 for S0 = 3')
Thought: I need to compute S0, S1, S2, S3, and S4 using the given recursive formula Sn+1 = cos(Sn) * 2 and S0 = 3.
Let's start with S0 = 3.
S1 = cos(S0) * 2 = cos(3) * 2
S2 = cos(S1) * 2 = cos(cos(3) * 2) * 2
S3 = cos(S2) * 2 = cos(cos(cos(3) * 2) * 2) * 2
S4 = cos(S3) * 2 = cos(cos(cos(cos(3) * 2) * 2)) * 2

I will use the calculator to compute these values.
{"tool_name": "calculator", "expression": "cos(3) * 2"}

[Tool result: -1.9799849932008908]

Thought: Now I need to compute S1 = cos(S0) * 2 = cos(3) * 2 = -1.9799849932008908 * 2
{"tool_name": "calculator", "expression": "-1.9799849932008908 * 2"}

[Tool result: -3.9599699864017817]

Thought: Now I need to compute S2 = cos(S1) * 2 = cos(-3.9599699864017817) * 2
{"tool_name": "calculator", "expression": "cos(-3.9599699864017817) * 2"}

[Tool result: -1.3668134299076982]

Thought: Now I need to compute S3 = cos(S2) * 2 = cos(-1.3668134299076982) * 2
{"tool_name": "calculator", "expression": "cos(-1.3668134299076982) * 2"}

[Tool result: 0.4051424976130353]

Thought: Now I need to compute S4 = cos(S3) * 2 = cos(0.4051424976130353) * 2
{"tool_name": "calculator", "expression": "cos(0.4051424976130353) * 2"}

[Tool result: 1.8380924822033438]

The values are: S0 = 3, S1 = -3.9599699864017817, S2 = -1.3668134299076982, S3 = 0.4051424976130353, S4 = 1.8380924822033438
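Note that the model's arithmetic above is wrong. Its first tool call, cos(3) * 2 ≈ -1.98, is already S1. The model then multiplies that value by 2 a second time and carries the error forward, so S1 through S4 in its final answer are incorrect. A plain-Python check of the recurrence, with no model involved, is a cheap way to validate tool-use outputs like this one. The function name `iterate_series` is illustrative only and is not part of the Gemma library.

```python
import math


def iterate_series(s0: float, steps: int) -> list[float]:
    """Return [S0, S1, ..., S_steps] for the recurrence S(n+1) = cos(S(n)) * 2."""
    values = [s0]
    for _ in range(steps):
        # Each term depends only on the previous one.
        values.append(math.cos(values[-1]) * 2)
    return values


for n, s in enumerate(iterate_series(3.0, 4)):
    print(f"S{n} = {s}")
```

Because cos is bounded by 1 in absolute value, every term after S0 must lie in [-2, 2]. This alone rules out the model's S1 = -3.96.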

Note: Only the final model answer is returned. The full conversation history, including all intermediate tool calls and their outputs, is available through the sampler.turns property.
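In the transcript above, each tool call is a single-line JSON object. Its "tool_name" key selects the tool, and its remaining keys are the arguments. The result is echoed back to the model as [Tool result: ...]. The sketch below mimics that round trip in plain Python to make the protocol concrete. It is a hypothetical illustration, not how gm.text.ToolSampler is actually implemented. The names `dispatch_tool_call` and `TOOLS`, and the toy `calculator` that only exposes arithmetic and `cos`, are all assumptions.

```python
import json
import math


def calculator(expression: str) -> str:
    # Toy evaluator: builtins are hidden and only `cos` is exposed.
    # This is NOT a security sandbox; never eval untrusted input like this in production.
    return str(eval(expression, {"__builtins__": {}}, {"cos": math.cos}))


# Hypothetical registry mapping a tool name to a callable taking keyword arguments.
TOOLS = {"calculator": calculator}


def dispatch_tool_call(line: str) -> str:
    """Parse a JSON tool call, run the matching tool, and format the result."""
    request = json.loads(line)
    name = request.pop("tool_name")  # the remaining keys become keyword arguments
    result = TOOLS[name](**request)
    return f"[Tool result: {result}]"


print(dispatch_tool_call('{"tool_name": "calculator", "expression": "cos(3) * 2"}'))
```

The keys of the JSON object map directly onto keyword arguments. This mirrors the rule for custom tools described next: the argument names of call must match the keys in tool_kwargs.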

Creating your own tool

To create your own tool, inherit from the gm.tools.Tool class and provide:

  • A DESCRIPTION and an EXAMPLE, so the model knows how to use your tool.

  • An implementation of the call method. call can take arbitrary **kwargs, but the argument names must match the keys defined in tool_kwargs and tool_kwargs_doc.

class DateTime(gm.tools.Tool):
  """Tool to access the current date."""

  DESCRIPTION = 'Access the current date, time,...'
  EXAMPLE = gm.tools.Example(
      query='Which day of the week are we today ?',
      thought='The `datetime.strftime` format uses %a for the day of the week',
      tool_kwargs={'format': '%a'},
      tool_kwargs_doc={'format': '<ANY datetime.strftime expression>'},
      result='Sat',
      answer='Today is Saturday.',
  )

  def call(self, format: str) -> str:
    dt = datetime.datetime.now()
    return dt.strftime(format)
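Because call is an ordinary method, its logic can be checked without loading the model. The sketch below uses a hypothetical helper, `format_datetime`, that mirrors DateTime.call but takes the datetime as a parameter so that the output is deterministic. It uses only the standard library. The two format strings are the ones that appear in the EXAMPLE and in the transcript further down.

```python
import datetime


def format_datetime(format: str, now: datetime.datetime) -> str:
    """Same logic as DateTime.call, but with an injectable `now` for testing."""
    return now.strftime(format)


fixed = datetime.datetime(2025, 6, 7)  # a Saturday
print(format_datetime("%a", fixed))        # Sat
print(format_datetime("%Y-%m-%d", fixed))  # 2025-06-07
```

Injecting the clock keeps the real tool simple while making its formatting behavior easy to unit-test.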

The tool can then be used in the sampler:

sampler = gm.text.ToolSampler(
    model=model,
    params=params,
    tools=[
        DateTime(),
    ],
    print_stream=True,
)

output = sampler.chat('Which date are we today ?')
Thought: I need to get the current date.
{"tool_name": "datetime", "format": "%Y-%m-%d"}

[Tool result: 2025-06-06]

Today is June 6th, 2025.

Next steps

  • See our multimodal example to query the model with images.

  • See our finetuning example to train Gemma on your custom task.