Prune
auto_circuit.prune
Functions
run_circuits
run_circuits(model: PatchableModel, dataloader: PromptDataLoader, test_edge_counts: List[int], prune_scores: PruneScores, patch_type: PatchType = PatchType.EDGE_PATCH, ablation_type: AblationType = AblationType.RESAMPLE, reverse_clean_corrupt: bool = False, render_graph: bool = False, render_score_threshold: bool = False, render_file_path: Optional[str] = None) -> CircuitOutputs
Run the model, pruning edges based on the given prune_scores. Runs the model over the given dataloader for each test_edge_count.
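To illustrate the idea of one pass per edge count, here is a minimal conceptual sketch. It is not the auto_circuit implementation: it assumes a toy flat dictionary of edge names to scores, assumes edges are ranked by absolute score, and uses the hypothetical helpers `select_top_edges` and `sweep_edge_counts`. Whether the selected edges end up patched or left intact depends on patch_type, which the sketch does not model.

```python
from typing import Dict, List


def select_top_edges(scores: Dict[str, float], edge_count: int) -> List[str]:
    """Return the names of the `edge_count` edges with the largest absolute score."""
    # Rank edge names by |score|, highest first (assumed ordering, for illustration).
    ranked = sorted(scores, key=lambda name: abs(scores[name]), reverse=True)
    return ranked[:edge_count]


def sweep_edge_counts(
    scores: Dict[str, float], test_edge_counts: List[int]
) -> Dict[int, List[str]]:
    """Build one edge selection per requested edge count."""
    return {count: select_top_edges(scores, count) for count in test_edge_counts}


# Toy scores for four hypothetical edges.
toy_scores = {"a->b": 0.9, "a->c": -0.5, "b->d": 0.1, "c->d": 0.3}
selections = sweep_edge_counts(toy_scores, [0, 2, 4])

for count, edges in selections.items():
    print(count, edges)
```

The result is keyed by edge count, which mirrors the shape of the documented return value: one entry per element of test_edge_counts.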
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PatchableModel | The model to run | required |
dataloader | PromptDataLoader | The dataloader to use for input and patches | required |
test_edge_counts | List[int] | The numbers of edges to prune. | required |
prune_scores | PruneScores | The scores that determine the ordering of edges for pruning | required |
patch_type | PatchType | Whether to patch the circuit or the complement. | EDGE_PATCH |
ablation_type | AblationType | The type of ablation to use. | RESAMPLE |
reverse_clean_corrupt | bool | Reverse clean and corrupt (for input and patches). | False |
render_graph | bool | Whether to render the graph using | False |
render_score_threshold | bool | Edge score threshold, if | False |
render_file_path | Optional[str] | Path to save the rendered graph, if | None |
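Since test_edge_counts is a plain List[int], one way to build it for a sweep is with roughly log-spaced values, so that small circuits are sampled more densely than large ones. The sketch below is an assumption for illustration only: `log_spaced_edge_counts` is a hypothetical helper, not part of auto_circuit, and the total number of edges is supplied by the caller.

```python
from typing import List


def log_spaced_edge_counts(n_edges: int, n_points: int) -> List[int]:
    """Return sorted, de-duplicated edge counts from 0 up to n_edges.

    The non-zero counts are spaced roughly evenly on a log scale.
    """
    if n_edges < 1:
        raise ValueError("n_edges must be at least 1")
    if n_points < 2:
        raise ValueError("n_points must be at least 2")

    # Always include the two extremes: no edges and every edge.
    counts = {0, n_edges}
    for i in range(n_points):
        # The exponent runs from 0 to 1, so the value runs from 1 to n_edges.
        counts.add(round(n_edges ** (i / (n_points - 1))))
    return sorted(counts)


test_edge_counts = log_spaced_edge_counts(n_edges=1000, n_points=4)
print(test_edge_counts)
```

A list built this way can be passed as the test_edge_counts argument; a hand-written list of integers works equally well.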
Returns:
Type | Description |
---|---|
CircuitOutputs | A dictionary mapping from the number of pruned edges to a |