Edge attribution patching
auto_circuit.prune_algos.edge_attribution_patching
edge_attribution_patching_prune_scores
edge_attribution_patching_prune_scores(model: PatchableModel, dataloader: PromptDataLoader, official_edges: Optional[Set[Edge]], answer_diff: bool = True) -> PruneScores
Prune scores by Edge Attribution patching.
This is an exact replication of the technique introduced in "Attribution Patching Outperforms Automated Circuit Discovery" (Syed et al. (2023)), as implemented in their codebase.
It is equivalent to mask_gradient_prune_scores with grad_function="logit" and mask_val=0.0. We verify that the output is exactly the same in test_edge_attribution_patching.py. This implementation is much slower than mask_gradient_prune_scores, so we don't use it in practice, but it is useful for validating the correctness of the fast implementation.
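The equivalence can be illustrated on a toy example. The sketch below is hypothetical and does not use auto_circuit. It computes EAP scores, (clean_act - corrupt_act) dotted with the gradient of a metric with respect to the downstream node's input on the clean run, for two edges A->D and B->D. It then computes the gradient of the same metric with respect to a per-edge mask evaluated at 0. It assumes the mask linearly interpolates each edge from its clean output (m = 0) to its corrupt output (m = 1), which may not match auto_circuit's exact masking convention. Under that assumption the two scores differ only in sign.

```python
from typing import Dict, List

Vector = List[float]

# Hypothetical toy graph: upstream nodes "A" and "B" each send one edge to a
# downstream node "D", whose input is the sum of the upstream outputs.
CLEAN: Dict[str, Vector] = {"A": [1.0, 2.0], "B": [0.5, -1.0]}
CORRUPT: Dict[str, Vector] = {"A": [0.0, 2.0], "B": [0.5, 1.0]}


def metric(d_in: Vector) -> float:
    """Stand-in for a logit metric: a simple nonlinear function of D's input."""
    return d_in[0] ** 2 + 3.0 * d_in[1]


def metric_grad(d_in: Vector) -> Vector:
    """Analytic gradient of `metric` with respect to D's input."""
    return [2.0 * d_in[0], 3.0]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def sum_vectors(vectors: List[Vector]) -> Vector:
    return [sum(components) for components in zip(*vectors)]


def eap_scores(clean: Dict[str, Vector], corrupt: Dict[str, Vector]) -> Dict[str, float]:
    """EAP score per edge: (clean - corrupt) . dMetric/d(D input), on the clean run."""
    grad = metric_grad(sum_vectors(list(clean.values())))
    return {
        edge: dot([c - k for c, k in zip(clean[edge], corrupt[edge])], grad)
        for edge in clean
    }


def mask_gradient_scores(
    clean: Dict[str, Vector], corrupt: Dict[str, Vector], eps: float = 1e-4
) -> Dict[str, float]:
    """Central-difference gradient of the metric w.r.t. each edge mask at mask = 0.

    Assumption: D's input is sum_u [clean_u + m_u * (corrupt_u - clean_u)],
    so m_u = 0 leaves edge u clean and m_u = 1 fully patches it.
    """

    def d_input(masks: Dict[str, float]) -> Vector:
        return sum_vectors(
            [[c + masks[u] * (k - c) for c, k in zip(clean[u], corrupt[u])] for u in clean]
        )

    scores = {}
    for edge in clean:
        # Perturb only this edge's mask around 0; all other masks stay at 0.
        plus = {u: (eps if u == edge else 0.0) for u in clean}
        minus = {u: (-eps if u == edge else 0.0) for u in clean}
        scores[edge] = (metric(d_input(plus)) - metric(d_input(minus))) / (2 * eps)
    return scores


eap = eap_scores(CLEAN, CORRUPT)
mask_grad = mask_gradient_scores(CLEAN, CORRUPT)
for edge in eap:
    print(edge, eap[edge], round(mask_grad[edge], 6))
```

A real implementation would use automatic differentiation over batches of prompts rather than finite differences on a hand-written metric; the finite differences here only keep the sketch dependency-free.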
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PatchableModel | The model to find the circuit for. | required |
dataloader | PromptDataLoader | The dataloader to use for input and ablation. | required |
official_edges | Optional[Set[Edge]] | Not used. | required |
Returns:
Type | Description |
---|---|
PruneScores | An ordering of the edges by importance to the task. Importance is equal to the absolute value of the score assigned to the edge. |
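Because importance is the magnitude of the score, selecting the most important edges amounts to sorting by absolute value. This minimal sketch uses a flat dictionary from hypothetical edge names to floats as a stand-in for PruneScores, whose actual structure may differ; rank_edges and top_k_edges are illustrative helpers, not auto_circuit functions.

```python
from typing import Dict, List, Tuple


def rank_edges(scores: Dict[str, float]) -> List[Tuple[str, float]]:
    """Sort edges from most to least important, where importance = |score|."""
    return sorted(scores.items(), key=lambda item: abs(item[1]), reverse=True)


def top_k_edges(scores: Dict[str, float], k: int) -> List[str]:
    """Names of the k edges with the largest-magnitude scores."""
    return [edge for edge, _ in rank_edges(scores)[:k]]


# Hypothetical edge names and scores; a large negative score still counts
# as highly important.
scores = {"A->D": 3.0, "B->D": -6.0, "C->D": 0.5}
print(rank_edges(scores))      # → [('B->D', -6.0), ('A->D', 3.0), ('C->D', 0.5)]
print(top_k_edges(scores, 2))  # → ['B->D', 'A->D']
```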
Note
The implementation here uses clean_act - corrupt_act, as described in the paper, rather than corrupt_act - clean_act, as in the authors' implementation. It doesn't matter either way, as we only consider the magnitude of the scores.
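Switching the sign convention negates every score, so any ranking by magnitude is unchanged. The hypothetical sketch below (made-up activations for three edges into one downstream node sharing a single gradient vector; not auto_circuit code) makes this concrete.

```python
from typing import Dict, List


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def edge_scores(
    clean: Dict[str, List[float]],
    corrupt: Dict[str, List[float]],
    grad: List[float],
    clean_minus_corrupt: bool = True,
) -> Dict[str, float]:
    """Per-edge attribution scores under either sign convention.

    corrupt - clean is computed as -(clean - corrupt).
    """
    sign = 1.0 if clean_minus_corrupt else -1.0
    return {
        edge: sign * dot([c - k for c, k in zip(clean[edge], corrupt[edge])], grad)
        for edge in clean
    }


def ranking(scores: Dict[str, float]) -> List[str]:
    """Edge names ordered by decreasing |score|."""
    return sorted(scores, key=lambda edge: abs(scores[edge]), reverse=True)


# Hypothetical activations for three edges into the same downstream node.
GRAD = [1.0, -2.0]
CLEAN = {"A->D": [1.0, 0.0], "B->D": [0.0, 1.0], "C->D": [0.5, 0.5]}
CORRUPT = {"A->D": [0.0, 0.0], "B->D": [0.0, 3.0], "C->D": [0.0, 0.0]}

paper = edge_scores(CLEAN, CORRUPT, GRAD, clean_minus_corrupt=True)
authors = edge_scores(CLEAN, CORRUPT, GRAD, clean_minus_corrupt=False)
print(paper)    # → {'A->D': 1.0, 'B->D': 4.0, 'C->D': -0.5}
print(authors)  # → {'A->D': -1.0, 'B->D': -4.0, 'C->D': 0.5}
print(ranking(paper) == ranking(authors))  # → True
```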
Source code in auto_circuit/prune_algos/edge_attribution_patching.py