Skip to content

Prepare model and data

import torch as t

from auto_circuit.data import load_datasets_from_json
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.types import PruneScores
from auto_circuit.utils.graph_utils import patchable_model
from auto_circuit.utils.misc import repo_path_to_abs_path
from auto_circuit.visualize import draw_seq_graph

# Use the GPU if PyTorch can see one; otherwise run on the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
# Load the "gpt2" model onto the selected device via AutoCircuit's loader.
model = load_tl_model("gpt2", device)

# Resolve the IOI prompt file relative to the repository root.
path = repo_path_to_abs_path("datasets/ioi/ioi_vanilla_template_prompts.json")
# Tokenize the prompts with a BOS token prepended and split them into
# train/test loaders (presumably 128 prompts each), in batches of 16.
train_loader, test_loader = load_datasets_from_json(
    model=model, path=path, device=device,
    train_test_size=(128, 128), batch_size=16, prepend_bos=True,
)

# Wrap the model so its computation-graph edges can be patched; only the
# last sequence position of the output is kept (slice_output="last_seq").
model = patchable_model(
    model, device=device,
    factorized=True, separate_qkv=True, slice_output="last_seq",
)

Edge Attribution Patching circuit discovery

Mask gradients are a faster way to compute Edge Attribution Patching; see the reference documentation for more details.

# Score every edge with mask gradients over the training prompts, using
# the "logit" gradient function and the "avg_diff" answer function.
# No official circuit edges are supplied (official_edges=None).
attribution_scores: PruneScores = mask_gradient_prune_scores(
    model=model, dataloader=train_loader,
    answer_function="avg_diff", grad_function="logit",
    mask_val=0.0, official_edges=None,
)
AutoCircuit supports a range of circuit discovery methods.

For the full set of available methods, see the reference documentation.

Visualize the circuit

# Draw the scored graph vertically with layer spacing; the positional 3.5
# is presumably the score threshold for drawing an edge (TODO confirm).
fig = draw_seq_graph(
    model,
    attribution_scores,
    3.5,
    orientation="v",
    layer_spacing=True,
)
Blue edges represent positive contributions to the output, red edges represent negative contributions, and the thickness of the edge represents the magnitude of the contribution.

Token specific circuit discovery

AutoCircuit can construct a computation graph that distinguishes between token positions. This requires every prompt in the dataset to have the same sequence length.

Get prompt sequence length

Pass return_seq_length=True to load_datasets_from_json so that the returned data loaders also carry the prompt sequence length (available as seq_len).

# Start again from a fresh model. Earlier in this script `model` was rebound
# to the output of patchable_model(...), and patchable_model is called again
# below. Wrapping the already-patched model a second time (and tokenizing with
# it here) is not what this token-specific example intends.
model = load_tl_model("gpt2", device)

# Reload the same prompts. return_seq_length=True makes the loaders carry the
# prompt sequence length (read below as test_loader.seq_len) and the
# per-position labels (read below as train_loader.seq_labels).
train_loader, test_loader = load_datasets_from_json(
    model=model,
    path=path,
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(128, 128),
    return_seq_length=True,
)

Create the computation graph with token specific edges

Pass the prompt sequence length to patchable_model through its seq_len argument.

# Build the patchable graph with token-specific edges: passing seq_len
# (taken from the loader) makes edges distinguish between token positions.
model = patchable_model(
    model,
    seq_len=test_loader.seq_len,
    device=device,
    factorized=True,
    separate_qkv=True,
    slice_output="last_seq",
)

Edge Attribution Patching circuit discovery

# Same mask-gradient edge scoring as the position-agnostic run, now applied
# to the token-specific graph.
attribution_scores: PruneScores = mask_gradient_prune_scores(
    model=model, dataloader=train_loader,
    answer_function="avg_diff", grad_function="logit",
    mask_val=0.0, official_edges=None,
)

Visualize the circuit

# Draw the token-specific graph, labelling sequence positions with the
# labels provided by the data loader.
fig = draw_seq_graph(
    model,
    attribution_scores,
    3.5,
    seq_labels=train_loader.seq_labels,
)