gm.evals.SamplerEvaluator


class gemma.gm.evals.SamplerEvaluator(**kwargs)

Bases: kauldron.evals.evaluators.EvaluatorBase

Sampling evaluator.

The evaluator expects a dataset containing a Seq2SeqTask transform.

cache_length

Cache length to use. This is the maximum number of tokens the conversation can have (prompts, answers, images for all turns). Setting this to a fixed value avoids re-compilation between turns.

Type:

int
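Because cache_length bounds the whole conversation, the prompt and answer tokens of every turn accumulate against the same budget. The sketch below tracks that running total. The `Turn` record and the `check_conversation_fits` helper are hypothetical illustrations, not part of the Gemma API, and it assumes the token count of each prompt and answer is already known.

```python
from dataclasses import dataclass


@dataclass
class Turn:
    """Hypothetical token counts for one conversation turn."""
    prompt_tokens: int  # prompt tokens, including any image tokens
    answer_tokens: int  # tokens generated for the answer


def check_conversation_fits(turns: list[Turn], cache_length: int = 4096) -> list[int]:
    """Return the cumulative token count after each turn.

    Raises ValueError once the running total exceeds cache_length,
    because all turns share one fixed-size cache.
    """
    totals = []
    used = 0
    for i, turn in enumerate(turns):
        used += turn.prompt_tokens + turn.answer_tokens
        if used > cache_length:
            raise ValueError(
                f"turn {i} brings the conversation to {used} tokens, "
                f"over cache_length={cache_length}"
            )
        totals.append(used)
    return totals


print(check_conversation_fits([Turn(1200, 300), Turn(800, 400)]))  # [1500, 2700]
```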

max_new_tokens

Maximum number of new tokens to generate. In total, the model processes input_length + max_new_tokens tokens.

Type:

int
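The input_length + max_new_tokens total follows from the shape of autoregressive decoding: the prompt is consumed first, then one token is appended per step. The toy loop below illustrates this. `generate` and `count_up` are hypothetical stand-ins for the real sampler and model, and the early stop on an end-of-sequence id is an assumption about typical sampler behavior, under which max_new_tokens acts as an upper bound.

```python
from typing import Callable, Optional


def generate(
    prompt: list[int],
    next_token: Callable[[list[int]], int],
    max_new_tokens: int,
    eos_id: Optional[int] = None,
) -> list[int]:
    """Toy autoregressive loop: append one predicted token per step."""
    tokens = list(prompt)  # start from the input_length prompt tokens
    for _ in range(max_new_tokens):  # then at most max_new_tokens decoding steps
        token = next_token(tokens)
        tokens.append(token)
        if eos_id is not None and token == eos_id:
            break  # assumed early stop on an end-of-sequence id
    return tokens


def count_up(tokens: list[int]) -> int:
    """Hypothetical stand-in for a model: predicts last token + 1."""
    return tokens[-1] + 1


out = generate([1, 2, 3], count_up, max_new_tokens=4)
print(out, len(out))  # [1, 2, 3, 4, 5, 6, 7] 7
```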

pad_length

Pad length for the input. This ensures the prompt always has the same length during sampling, which can help avoid re-compilation.

Type:

int | None
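Re-compilation matters because jax.jit traces and compiles a new version of a function for each distinct input shape. The sketch below simulates that effect in plain Python: `ShapeKeyedCache` records one "compilation" per distinct prompt length, so unpadded prompts of three lengths cost three compilations while padded prompts cost one. `pad_to_length`, right-padding, and pad id 0 are illustrative assumptions, not the evaluator's actual padding scheme.

```python
def pad_to_length(tokens: list[int], pad_length: int, pad_id: int = 0) -> list[int]:
    """Right-pad a token list to exactly pad_length (hypothetical scheme)."""
    if len(tokens) > pad_length:
        raise ValueError(f"prompt has {len(tokens)} tokens, more than pad_length={pad_length}")
    return tokens + [pad_id] * (pad_length - len(tokens))


class ShapeKeyedCache:
    """Simulates a jit cache: one 'compilation' per distinct input length."""

    def __init__(self) -> None:
        self.compiled_lengths: set[int] = set()

    def run(self, tokens: list[int]) -> None:
        self.compiled_lengths.add(len(tokens))  # a new length would trigger a new compile


prompts = [[5, 6], [7, 8, 9], [1, 2, 3, 4]]

unpadded = ShapeKeyedCache()
for p in prompts:
    unpadded.run(p)

padded = ShapeKeyedCache()
for p in prompts:
    padded.run(pad_to_length(p, pad_length=8))

print(len(unpadded.compiled_lengths), len(padded.compiled_lengths))  # 3 1
```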

num_batches

Number of batches. If None, sample the entire dataset.

Type:

int | None
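The num_batches semantics (a cap, with None meaning the entire dataset) map directly onto itertools.islice, whose stop argument also accepts None to mean "no limit". `take_batches` is a hypothetical helper for illustration; it does not show how the evaluator bounds its own loop.

```python
import itertools
from typing import Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def take_batches(batches: Iterable[T], num_batches: Optional[int] = None) -> Iterator[T]:
    """Yield at most num_batches items; None yields the whole iterable."""
    return itertools.islice(batches, num_batches)  # islice treats stop=None as unbounded


print(list(take_batches(range(10), 3)))  # [0, 1, 2]
print(list(take_batches(range(5))))      # [0, 1, 2, 3, 4]
```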

ds

Dataset to evaluate on. Note that the dataset must be unbatched and contain raw str fields.

Type:

kauldron.data.pipelines.Pipeline
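To make "unbatched, with raw str fields" concrete, the sketch below contrasts a single example holding one Python string per field with a batched example holding lists of strings. The field names "prompt" and "response" and the `is_unbatched_text_example` check are hypothetical; the keys the Seq2SeqTask transform actually reads depend on its configuration.

```python
from typing import Any, Mapping


def is_unbatched_text_example(example: Mapping[str, Any], text_keys: tuple[str, ...]) -> bool:
    """Check that every text field holds a single raw str, not a batch of them."""
    return all(isinstance(example.get(k), str) for k in text_keys)


keys = ("prompt", "response")  # hypothetical field names
single = {"prompt": "Translate to French: hello", "response": "bonjour"}
batched = {"prompt": ["hello", "goodbye"], "response": ["bonjour", "au revoir"]}

print(is_unbatched_text_example(single, keys))   # True
print(is_unbatched_text_example(batched, keys))  # False
```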

model

The model to use.

Type:

flax.linen.module.Module

losses

Losses to compute. Losses and metrics can access the prediction text through the key preds.text.

Type:

collections.abc.Mapping[str, kauldron.losses.base.Loss]

metrics

Metrics to compute. Losses and metrics can access the prediction text through the key preds.text.

Type:

collections.abc.Mapping[str, kauldron.metrics.base.Metric]
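Losses and metrics select their inputs by a dotted key, with the generated text available under preds.text. The sketch below mimics that lookup over a plain nested dict and feeds it into a toy exact-match score. `resolve_key`, `exact_match`, the context layout, and the batch.response target key are all hypothetical; real metrics are kauldron.metrics.base.Metric instances, which this sketch does not reproduce.

```python
from typing import Any, Mapping


def resolve_key(context: Mapping[str, Any], key: str) -> Any:
    """Follow a dotted key such as 'preds.text' through nested mappings."""
    value: Any = context
    for part in key.split("."):
        value = value[part]
    return value


def exact_match(
    context: Mapping[str, Any],
    pred_key: str = "preds.text",
    target_key: str = "batch.response",  # hypothetical target key
) -> float:
    """Fraction of predictions equal to their target after stripping whitespace."""
    preds = resolve_key(context, pred_key)
    targets = resolve_key(context, target_key)
    if not targets:
        return 0.0
    matches = sum(p.strip() == t.strip() for p, t in zip(preds, targets))
    return matches / len(targets)


context = {
    "batch": {"response": ["bonjour", "au revoir"]},
    "preds": {"text": ["bonjour", "salut"]},
}
print(resolve_key(context, "preds.text"))  # ['bonjour', 'salut']
print(exact_match(context))  # 0.5
```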

summaries

Optional summaries to write.

Type:

collections.abc.Mapping[str, kauldron.metrics.base.Metric]

cache_length: int = 4096
max_new_tokens: int
pad_length: int | None = None
num_batches: int | None = None
cache: bool = False
ds: kauldron.data.pipelines.Pipeline = _FakeRootCfg('cfg.eval_ds')
model: flax.linen.module.Module = _FakeRootCfg('cfg.model')
losses: collections.abc.Mapping[str, kauldron.losses.base.Loss]
metrics: collections.abc.Mapping[str, kauldron.metrics.base.Metric]
summaries: collections.abc.Mapping[str, kauldron.metrics.base.Metric]
evaluate(
    state: kauldron.train.train_step.TrainState,
    step: int,
) -> Any

Run this evaluator, then write and optionally return the results.
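The one-line description (run, write, optionally return) suggests a loop of roughly the following shape: collect prompts and targets, sample a completion per prompt, score the predictions, hand the results to a writer, and return them. `evaluate_sketch` and all of its arguments are hypothetical; the real method takes a TrainState and a step, and its batching, caching, and Seq2SeqTask handling are not modeled here.

```python
from typing import Callable, Iterable, Mapping


def evaluate_sketch(
    examples: Iterable[Mapping[str, str]],
    sample: Callable[[str], str],
    metrics: Mapping[str, Callable[[list[str], list[str]], float]],
    write: Callable[[int, Mapping[str, float]], None],
    step: int,
) -> dict[str, float]:
    """Hypothetical evaluation loop: sample, score, write, return."""
    prompts, targets = [], []
    for ex in examples:
        prompts.append(ex["prompt"])
        targets.append(ex["response"])
    preds = [sample(p) for p in prompts]  # generate text for each prompt
    results = {name: fn(preds, targets) for name, fn in metrics.items()}
    write(step, results)  # e.g. forward to a metrics writer
    return results


logged = []
results = evaluate_sketch(
    examples=[{"prompt": "2+2=", "response": "4"}, {"prompt": "3+3=", "response": "6"}],
    sample=lambda prompt: "4",  # toy sampler that always answers "4"
    metrics={"accuracy": lambda preds, targets: sum(p == t for p, t in zip(preds, targets)) / len(targets)},
    write=lambda step, res: logged.append((step, dict(res))),
    step=1000,
)
print(results)  # {'accuracy': 0.5}
```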

property ds_iter: kauldron.data.data_utils.IterableDataset

Iterate over the examples.