3) Other Features
Ablation types
- Resample (a.k.a. patching)
- Zero
- Mean (calculated over a batch or PromptDataset)
See AblationType for more details.
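The three ablation types differ only in what value replaces an ablated node's output. The following is a minimal sketch using plain Python lists in place of tensors. The helper `ablation_values` and its `kind` strings are hypothetical illustrations, not the library's API, and taking the mean over the corrupt activations of a single batch is an assumption (the library can also average over a whole PromptDataset).

```python
from typing import List

Vector = List[float]


def ablation_values(corrupt_acts: List[Vector], kind: str) -> List[Vector]:
    """Return replacement activations for one node, one vector per prompt.

    Hypothetical helper: `corrupt_acts` holds the node's activations on the
    corrupt prompts of a batch (one vector per prompt).
    """
    if kind == "resample":
        # Resample (patching): each prompt gets the activation from its own corrupt prompt.
        return [list(v) for v in corrupt_acts]
    if kind == "zero":
        # Zero: each activation is replaced by a zero vector of the same size.
        return [[0.0] * len(v) for v in corrupt_acts]
    if kind == "mean":
        # Mean: every prompt gets the element-wise mean over the batch (assumed scope).
        n = len(corrupt_acts)
        mean = [sum(col) / n for col in zip(*corrupt_acts)]
        return [list(mean) for _ in corrupt_acts]
    raise ValueError(f"unknown ablation type: {kind}")


batch = [[1.0, 2.0], [3.0, 6.0]]
print(ablation_values(batch, "resample"))  # [[1.0, 2.0], [3.0, 6.0]]
print(ablation_values(batch, "zero"))      # [[0.0, 0.0], [0.0, 0.0]]
print(ablation_values(batch, "mean"))      # [[2.0, 4.0], [2.0, 4.0]]
```

Resample ablation preserves per-prompt variation, while zero and mean ablation replace every prompt with the same value.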
Automatic KV caching
When tail_divergence is True, load_datasets_from_json automatically computes the KV cache for the common prefix of all of the prompts in the dataset and removes that prefix from the prompts.
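```python
from auto_circuit.data import load_datasets_from_json

# model, path and device are defined earlier in the tutorial.
train_loader, test_loader = load_datasets_from_json(
    model=model,
    path=path,
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(128, 128),
    tail_divergence=True,  # cache the shared prompt prefix and strip it from the prompts
)
```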
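The prefix-stripping step can be sketched in plain Python on token-id lists. The helper `split_common_prefix` and the token ids are hypothetical; the sketch only shows how a shared prefix is found and removed, not how the library builds the cache itself. The idea is that the prefix's keys and values are computed once and reused, so each later forward pass only has to process the divergent tails.

```python
from typing import List, Tuple

Tokens = List[int]


def split_common_prefix(prompts: List[Tokens]) -> Tuple[Tokens, List[Tokens]]:
    """Return (shared prefix, prompts with that prefix removed)."""
    if not prompts:
        return [], []
    prefix_len = 0
    # Walk position by position while every prompt has the same token there.
    for column in zip(*prompts):
        if all(tok == column[0] for tok in column):
            prefix_len += 1
        else:
            break
    prefix = prompts[0][:prefix_len]
    return prefix, [p[prefix_len:] for p in prompts]


# Arbitrary token ids for three prompts that share their first three tokens.
prompts = [[0, 10, 11, 12, 7], [0, 10, 11, 12, 9], [0, 10, 11, 13, 4]]
prefix, tails = split_common_prefix(prompts)
print(prefix)  # [0, 10, 11]
print(tails)   # [[12, 7], [12, 9], [13, 4]]

# Token positions processed per forward pass without and with a prefix cache.
print(sum(len(p) for p in prompts))  # 15
print(sum(len(t) for t in tails))    # 6
```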
The KV caches are stored in the kv_cache attribute of each PromptDataloader. Pass the caches to the patchable_model function to use them automatically.
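```python
from auto_circuit.utils.graph_utils import patchable_model

model = patchable_model(
    model,
    factorized=True,
    slice_output="last_seq",
    separate_qkv=True,
    # Prefix caches computed by load_datasets_from_json(..., tail_divergence=True)
    kv_caches=(train_loader.kv_cache, test_loader.kv_cache),
    device=device,
)
```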
Automatically patch multiple circuits
To patch multiple circuits of increasing size (each larger circuit includes edges down to a lower PruneScores threshold), use the run_circuits function.
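```python
from typing import Dict

from auto_circuit.prune import run_circuits
from auto_circuit.types import AblationType, CircuitOutputs, PatchType, PruneScores
from auto_circuit.utils.graph_utils import edge_counts_util

# Score each edge; higher-scoring edges enter the circuit first.
patch_edges: Dict[str, float] = {
    "Resid Start->MLP 1": 1.0,
    "MLP 1->MLP 2": 2.0,
    "MLP 1->MLP 3": 1.0,
    "MLP 2->A5.2.Q": 2.0,
    "MLP 3->A5.2.Q": 1.0,
    "A5.2->Resid End": 1.0,
}
ps: PruneScores = model.circuit_prune_scores(edge_dict=patch_edges)

# Evaluate one circuit per edge count derived from the prune scores.
circuit_outs: CircuitOutputs = run_circuits(
    model=model,
    dataloader=test_loader,
    test_edge_counts=edge_counts_util(model.edges, prune_scores=ps),
    prune_scores=ps,
    patch_type=PatchType.EDGE_PATCH,
    ablation_type=AblationType.RESAMPLE,
)
```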
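The nesting of circuits can be illustrated in plain Python. The helper `nested_circuits` is hypothetical; it mirrors the idea that each distinct score acts as a threshold, so circuits grow as the threshold decreases. The exact edge counts returned by edge_counts_util (for example, whether an empty circuit is included) may differ from this sketch.

```python
from typing import Dict, List, Set


def nested_circuits(scores: Dict[str, float]) -> List[Set[str]]:
    """Return circuits of increasing size, one per distinct score threshold.

    Each circuit holds every edge scoring >= its threshold, so each circuit
    contains all the edges of the previous (smaller) one.
    """
    thresholds = sorted(set(scores.values()), reverse=True)
    return [{edge for edge, s in scores.items() if s >= t} for t in thresholds]


patch_edges = {
    "Resid Start->MLP 1": 1.0,
    "MLP 1->MLP 2": 2.0,
    "MLP 1->MLP 3": 1.0,
    "MLP 2->A5.2.Q": 2.0,
    "MLP 3->A5.2.Q": 1.0,
    "A5.2->Resid End": 1.0,
}
circuits = nested_circuits(patch_edges)
print([len(c) for c in circuits])  # [2, 6]
```

With the scores above, the first circuit holds only the two edges scored 2.0, and the second adds the four edges scored 1.0.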