3) Other Features

Ablation types

  • Resample (a.k.a. patching)
  • Zero
  • Mean (calculated over a batch or PromptDataset)

See AblationType for more details.

ablations = src_ablations(model, test_loader, AblationType.RESAMPLE)
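
To make the three ablation types concrete, here is a minimal, library-independent sketch of what each one substitutes for a component's clean activation. The function name ablation_values and the plain-list representation are hypothetical and chosen for illustration only; this is not how src_ablations is implemented.

```python
from statistics import fmean
from typing import List, Sequence


def ablation_values(
    kind: str,
    corrupt: Sequence[float],
    batch: Sequence[Sequence[float]],
) -> List[float]:
    """Return the vector that replaces a component's clean activation.

    kind    -- "zero", "mean" or "resample" (hypothetical string labels)
    corrupt -- the activation of the same component on a different input
    batch   -- activations of the same component over a batch of inputs
    """
    if kind == "zero":
        # Zero ablation: replace the activation with zeros.
        return [0.0] * len(corrupt)
    if kind == "mean":
        # Mean ablation: per-dimension average over the batch.
        return [fmean(column) for column in zip(*batch)]
    if kind == "resample":
        # Resample ablation (patching): reuse the activation from another input.
        return list(corrupt)
    raise ValueError(f"unknown ablation kind: {kind!r}")


batch = [[1.0, 2.0], [3.0, 6.0]]
corrupt = [5.0, 7.0]
zero = ablation_values("zero", corrupt, batch)
mean = ablation_values("mean", corrupt, batch)
resample = ablation_values("resample", corrupt, batch)
```

In this sketch the mean is taken over a single batch; averaging over a whole PromptDataset would follow the same pattern with more rows.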

Automatic KV caching

When tail_divergence is True, load_datasets_from_json automatically computes the KV cache for the prefix shared by all prompts in the dataset and removes that prefix from the prompts.

train_loader, test_loader = load_datasets_from_json(
    model=model,
    path=path,
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(128, 128),
    tail_divergence=True,
)
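
The prefix-splitting idea behind tail_divergence can be sketched as follows. This is an illustration under stated assumptions, not the library's code: prompts are represented as lists of token ids, and the helper name split_common_prefix is hypothetical. The shared prefix is the part whose keys and values only need to be computed once.

```python
from typing import List, Sequence, Tuple


def split_common_prefix(
    prompts: Sequence[Sequence[int]],
) -> Tuple[List[int], List[List[int]]]:
    """Split tokenized prompts into (shared prefix, per-prompt tails)."""
    if not prompts:
        return [], []
    prefix_len = 0
    # zip stops at the shortest prompt, so the prefix never overruns any prompt.
    for column in zip(*prompts):
        if len(set(column)) != 1:
            break  # first position where the prompts disagree
        prefix_len += 1
    prefix = list(prompts[0][:prefix_len])
    tails = [list(p[prefix_len:]) for p in prompts]
    return prefix, tails


# Hypothetical token ids: the first two tokens are shared by every prompt.
prompts = [[1, 2, 3, 4], [1, 2, 3, 9], [1, 2, 7, 8]]
prefix, tails = split_common_prefix(prompts)
```

Only the tails then need to be run through the model for each prompt, which is where the saving comes from when prompts share a long prefix.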

The KV caches are stored in the kv_cache attribute of each PromptDataLoader. Pass them to the patchable_model function through the kv_caches argument to use them automatically.

model = patchable_model(
    model,
    factorized=True,
    slice_output="last_seq",
    separate_qkv=True,
    kv_caches=(train_loader.kv_cache, test_loader.kv_cache),
    device=device,
)

Automatically patch multiple circuits

To patch multiple circuits of increasing size (that is, with edges included in order of decreasing prune score), use the run_circuits function.

patch_edges: Dict[str, float] = {
    "Resid Start->MLP 1": 1.0,
    "MLP 1->MLP 2": 2.0,
    "MLP 1->MLP 3": 1.0,
    "MLP 2->A5.2.Q": 2.0,
    "MLP 3->A5.2.Q": 1.0,
    "A5.2->Resid End": 1.0,
}
ps: PruneScores = model.circuit_prune_scores(edge_dict=patch_edges)

circuit_outs: CircuitOutputs = run_circuits(
    model=model,
    dataloader=test_loader,
    test_edge_counts=edge_counts_util(model.edges, prune_scores=ps),
    prune_scores=ps,
    patch_type=PatchType.EDGE_PATCH,
    ablation_type=AblationType.RESAMPLE,
)
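
As a rough illustration of how a set of prune scores induces circuits of increasing size, the sketch below ranks edges by score and takes the top k for each requested edge count. The helper circuits_by_size is hypothetical, and the tie-breaking rule (ties keep insertion order) is an assumption; run_circuits and edge_counts_util may select edge counts and resolve ties differently.

```python
from typing import Dict, List


def circuits_by_size(
    scores: Dict[str, float], edge_counts: List[int]
) -> Dict[int, List[str]]:
    """Map each edge count k to the k highest-scoring edges."""
    # sorted() is stable, so edges with equal scores keep their dict order.
    ranked = sorted(scores, key=lambda edge: scores[edge], reverse=True)
    return {k: ranked[:k] for k in edge_counts}


patch_edges = {
    "Resid Start->MLP 1": 1.0,
    "MLP 1->MLP 2": 2.0,
    "MLP 1->MLP 3": 1.0,
    "MLP 2->A5.2.Q": 2.0,
    "MLP 3->A5.2.Q": 1.0,
    "A5.2->Resid End": 1.0,
}
circuits = circuits_by_size(patch_edges, edge_counts=[2, 6])
```

Because every circuit is a prefix of the same ranking, each smaller circuit is contained in every larger one.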

Measure circuit metrics

kl_divs = measure_kl_div(model, test_loader, circuit_outs)

For a full list of metrics, see the reference documentation.
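
For readers unfamiliar with the metric, the following is a minimal standalone sketch of a KL divergence between two output distributions given as logits. It is not the implementation of measure_kl_div: the helper names are hypothetical, and the direction used here, KL(reference || circuit), is an assumption for illustration.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over a single vector of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def kl_div(ref_logits: Sequence[float], circuit_logits: Sequence[float]) -> float:
    """KL(P || Q) with P = softmax(ref_logits), Q = softmax(circuit_logits)."""
    log_p = log_softmax(ref_logits)
    log_q = log_softmax(circuit_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# P = [0.5, 0.5] and Q = [0.25, 0.75], so KL(P || Q) = 0.5 * ln(4/3).
example = kl_div([0.0, 0.0], [0.0, math.log(3.0)])
identical = kl_div([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
```

A value of zero means the two distributions are identical; larger values mean the circuit's output distribution departs further from the reference.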