Prepare model and data
import torch as t
from auto_circuit.data import load_datasets_from_json
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.types import PruneScores
from auto_circuit.utils.graph_utils import patchable_model
from auto_circuit.utils.misc import repo_path_to_abs_path
from auto_circuit.visualize import draw_seq_graph
# Prefer the GPU when one is present; otherwise run on the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

# GPT-2, loaded through AutoCircuit's experiment helper.
model = load_tl_model("gpt2", device)

# IOI prompt templates, resolved to an absolute path by the repo helper.
path = repo_path_to_abs_path("datasets/ioi/ioi_vanilla_template_prompts.json")

# Tokenized train/test loaders (BOS prepended, 16 prompts per batch).
# NOTE(review): train_test_size is assumed to be (n_train, n_test) — confirm.
train_loader, test_loader = load_datasets_from_json(
    model=model,
    path=path,
    device=device,
    batch_size=16,
    prepend_bos=True,
    train_test_size=(128, 128),
)

# Replace `model` with its patchable wrapper; the keyword options select
# how the edge graph is built and which output slice is scored.
model = patchable_model(
    model,
    device=device,
    factorized=True,
    separate_qkv=True,
    slice_output="last_seq",
)
Edge Attribution Patching circuit discovery
Mask gradients are a faster way to compute Edge Attribution Patching; see the reference documentation for more details.
attribution_scores: PruneScores = mask_gradient_prune_scores(
model=model,
dataloader=train_loader,
official_edges=None,
grad_function="logit",
answer_function="avg_diff",
mask_val=0.0,
)
For the full set of available methods, see the reference documentation.
Visualize the circuit
Blue edges represent positive contributions to the output, red edges represent negative contributions, and the thickness of each edge represents the magnitude of its contribution.
Token-specific circuit discovery
AutoCircuit can construct a computation graph that differentiates between different token positions. All prompts in the dataset must have the same sequence length.
Get prompt sequence length
Set `return_seq_length=True` in `load_datasets_from_json` to return the sequence length of the prompts.
# At this point `model` holds the result of the earlier `patchable_model`
# call. Reload an unpatched GPT-2 so the token-specific graph below is
# built from a plain model rather than by patching an already-patched one.
model = load_tl_model("gpt2", device)

# Rebuild the loaders, this time also asking for the prompt sequence
# length (all prompts must share it); it is read back as
# `test_loader.seq_len` when the token-specific graph is created.
train_loader, test_loader = load_datasets_from_json(
    model=model,
    path=path,
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(128, 128),
    return_seq_length=True,
)
Create the computation graph with token-specific edges
Set `seq_len` to the sequence length of the prompts in `patchable_model`.
# Wrap the model again, this time passing the shared prompt length so the
# graph distinguishes token positions.
model = patchable_model(
    model,
    device=device,
    seq_len=test_loader.seq_len,
    factorized=True,
    separate_qkv=True,
    slice_output="last_seq",
)
Edge Attribution Patching circuit discovery
# Same mask-gradient Edge Attribution Patching as before, now scoring the
# token-specific edges of the sequence-aware model.
attribution_scores: PruneScores = mask_gradient_prune_scores(
    model=model,
    dataloader=train_loader,
    official_edges=None,
    mask_val=0.0,
    answer_function="avg_diff",
    grad_function="logit",
)