Sharding#

Sharding for Gemma models. This example runs inference on Gemma 3 27B on a TPU v2, using 8 devices.

!pip install -q gemma
# Common imports
import jax
import jax.numpy as jnp

# Gemma imports
from gemma import gm
from kauldron import kd

For this colab, make sure you are connected to a TPU kernel: select Change runtime type > v2-8 TPU to get access to multiple accelerators. JAX should then report multiple devices.

jax.device_count()
8

Load the model and the params. Here we load the 27B model.

model = gm.nn.Gemma3_27B()
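
A rough back-of-the-envelope estimate helps explain why the weights of a model this size are sharded rather than copied to every device. The sketch below is only arithmetic under stated assumptions: the parameter count is the nominal figure implied by the "27B" name, and 2 bytes per parameter assumes bfloat16 storage. Neither number is taken from the checkpoint itself.

```python
# Back-of-the-envelope estimate; every number here is an assumption.
NUM_PARAMS = 27e9      # nominal count implied by the "27B" name
BYTES_PER_PARAM = 2    # assumes bfloat16 storage
NUM_DEVICES = 8        # matches jax.device_count() above

# Memory needed to hold one full copy of the weights.
total_gib = NUM_PARAMS * BYTES_PER_PARAM / 2**30

# Memory per device if the weights are split evenly across all devices.
per_device_gib = total_gib / NUM_DEVICES

print(f"full copy: {total_gib:.1f} GiB, per device: {per_device_gib:.1f} GiB")
```

The estimate ignores activations and the sampling cache, so the real per-device footprint is higher than the weights alone.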

When restoring the weights, you can pass a sharding= argument to gm.ckpts.load_params. Here we use the naive kd.sharding.FSDPSharding heuristic.

params = gm.ckpts.load_params(
    gm.ckpts.CheckpointPath.GEMMA3_27B_IT,
    sharding=kd.sharding.FSDPSharding(),
)
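
To make the idea of FSDP-style sharding concrete, here is a minimal sketch in plain JAX of splitting one array across all devices. It is not how kd.sharding.FSDPSharding is implemented, and which dimension that heuristic picks for each parameter is not shown here. The mesh axis name 'fsdp' and the toy weight shape are assumptions made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning every available device (8 on a v2-8, 1 on CPU).
mesh = Mesh(np.array(jax.devices()), axis_names=('fsdp',))

# Hypothetical weight matrix; the first dim is divisible by the device count.
n = jax.device_count()
weight = jnp.ones((4 * n, 16))

# Split the first dimension across the 'fsdp' axis, keep the second whole.
sharded = jax.device_put(weight, NamedSharding(mesh, P('fsdp', None)))

# Each device now holds a (4, 16) slice of the full (4 * n, 16) array.
shard_shapes = [s.data.shape for s in sharded.addressable_shards]
print(sharded.shape, shard_shapes)
```

The global shape of the array is unchanged. Only the physical placement differs, so downstream code can keep treating it as one array.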

Sampling#

Test the sharding using the gm.text.ChatSampler.

sampler = gm.text.ChatSampler(model=model, params=params)

out = sampler.chat('Tell me an unknown interesting biology fact about the brain.')
print(out)
Okay, here's a fascinating and relatively unknown fact about the brain:

**Your brain actively "cleans itself" during sleep with a system called the glymphatic system, and this cleaning process is *much* more efficient when you sleep on your side.**

Here's the breakdown:

* **The Glymphatic System:** For a long time, it was thought the brain didn't have a traditional lymphatic system (which clears waste from the body).  But in 2012, researchers discovered the glymphatic system. It's essentially a brain-wide waste clearance system that uses cerebrospinal fluid (CSF) to flush out metabolic waste products that build up during waking hours – things like amyloid-beta, a protein associated with Alzheimer's disease.
* **How it Works:** CSF flows *along* arteries into the brain tissue and then drains out along veins. This flow is significantly enhanced during sleep.
* **Side Sleeping is Key:**  Research (particularly studies using MRI scans) has shown that sleeping on your side – *especially* the left side – is the most effective position for clearing waste from the brain. This is because the lateral position allows gravity to assist the flow of CSF and facilitates the clearance of interstitial fluid (the fluid between brain cells) and waste products.  Sleeping on your back is *less* effective, and sleeping on your stomach is the *least* effective.

**Why is this relatively unknown?** The glymphatic system is a relatively recent discovery, and research is still ongoing.  It's also a complex system to study.  

**Source/Further Reading:**

*   **Oregon Health & Science University (OHSU) - The Brain's Cleaning System:** [https://www.ohsu.edu/news/2013/08/brain-cleanses-itself-during-sleep](https://www.ohsu.edu/news/2013/08/brain-cleanses-itself-during-sleep)
*   **ScienceAlert - Scientists Discover How You Can Optimize Your Brain's Cleaning System While You Sleep:** [https://www.sciencealert.com/scientists-discover-how-you-can-optimize-your-brain-s-cleaning-system-while-you-sleep](https://www.sciencealert.com/scientists-discover-how-you-can

Even though we’ve learned something new, the model still hallucinated URLs, which shows the limitations of current systems.

Calling the model directly#

It’s also possible to call the model directly. For this, the input first has to be encoded manually.

tokenizer = gm.text.Gemma3Tokenizer()
prompt = tokenizer.encode('My name is', add_bos=True)  # /!\ Don't forget to add the BOS token
prompt = jnp.asarray(prompt)

When using sharding, the input also has to be sharded. Here, we have a single prompt, so we use kd.sharding.REPLICATED sharding so that each device gets a copy of the prompt.

During training, the prompts are usually batched and padded, then sharded using kd.sharding.FIRST_DIM, so that the prompts are distributed across the devices.

prompt = kd.sharding.with_sharding_constraint(prompt, kd.sharding.REPLICATED)
# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)
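
The difference between a replicated input and a first-dimension-sharded batch, as described above, can be sketched in plain JAX. This is an illustration of the two layouts, not the kauldron implementation of kd.sharding.REPLICATED or kd.sharding.FIRST_DIM. The token ids, the padding id 0, and the mesh axis name 'data' are all hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
n = jax.device_count()

# Hypothetical variable-length prompts, two per device.
PAD_ID = 0  # assumed padding id, for illustration only
prompts = [[2] + [10 + i] * (1 + i % 3) for i in range(2 * n)]

# Pad every prompt to the longest one so they stack into a single batch.
max_len = max(len(p) for p in prompts)
batch = jnp.asarray([p + [PAD_ID] * (max_len - len(p)) for p in prompts])

# Replicated: an empty PartitionSpec gives every device the full prompt.
replicated = jax.device_put(batch[0], NamedSharding(mesh, P()))

# First-dim sharding: the batch axis is split, so each device gets 2 prompts.
first_dim = jax.device_put(batch, NamedSharding(mesh, P('data', None)))

print([s.data.shape for s in replicated.addressable_shards])
print([s.data.shape for s in first_dim.addressable_shards])
```

Splitting the first dimension requires the batch size to be divisible by the number of devices, which is why the sketch builds exactly two prompts per device.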


# Sample a token from the predicted logits
next_token = jax.random.categorical(
    jax.random.key(1),
    out.logits
)
tokenizer.decode(next_token)
' Mary'

You can also display the next-token probabilities.

tokenizer.plot_logits(out.logits)
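
To inspect the probabilities numerically rather than as a plot, one option is to apply a softmax to the logits and keep the top few entries. The sketch below uses a made-up 5-entry logits vector as a stand-in for out.logits, so the values and the vocabulary size are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

# Hypothetical logits over a 5-token vocabulary (stand-in for out.logits).
logits = jnp.array([2.0, 0.5, -1.0, 3.0, 0.0])

# Softmax turns logits into a probability distribution over the vocabulary.
probs = jax.nn.softmax(logits)

# Keep the 3 most likely token ids and their probabilities, in sorted order.
top_probs, top_ids = jax.lax.top_k(probs, k=3)

# Sampling draws from the same distribution, directly from the logits.
next_token = jax.random.categorical(jax.random.key(1), logits)

for token_id, p in zip(top_ids.tolist(), top_probs.tolist()):
    print(f"token {token_id}: p={p:.3f}")
```

With real model outputs, the ids in top_ids could be passed through tokenizer.decode to see the corresponding text pieces.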

Training#

To use sharding during training, set the sharding= argument of the trainer, for example:

trainer = kd.train.Trainer(
    ...,
    sharding=kd.sharding.ShardingStrategy(
        params=kd.sharding.FSDPSharding(),
    ),
    ...,
)

See a full example at: https://github.com/google-deepmind/gemma/tree/main/examples/sharding.py