4) How it Works
Prune Scores
The most important data structure to understand in AutoCircuit is the PruneScores object. This type is a map from the module_name of each DestNode in the model to a tensor whose entries are the attribution scores of the edges that point to that node.
To get the score of a particular Edge, take the tensor for the edge's destination node and index into it at the edge's patch_idx.
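As an illustration, here is a minimal sketch of that lookup in plain Python. It assumes an edge exposes its destination node (with a module_name) and its patch_idx. DestNodeStub and EdgeStub are hypothetical stand-ins rather than AutoCircuit's own types, plain lists take the place of PyTorch tensors, and the module names and scores are made up.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class DestNodeStub:
    """Hypothetical stand-in for a DestNode (only the field needed here)."""
    module_name: str


@dataclass(frozen=True)
class EdgeStub:
    """Hypothetical stand-in for an Edge (only the fields needed here)."""
    dest: DestNodeStub
    patch_idx: int  # position of this edge in its destination's score tensor


# PruneScores-like map: one score vector per destination module name.
# Lists stand in for torch tensors; names and values are invented.
prune_scores: Dict[str, List[float]] = {
    "attn_0": [0.10, -0.30],
    "mlp_1": [0.05, 0.70, -0.20],
}


def edge_score(scores: Dict[str, List[float]], edge: EdgeStub) -> float:
    """Return the attribution score of `edge`."""
    # Select the tensor for the edge's destination node, then index by patch_idx.
    return scores[edge.dest.module_name][edge.patch_idx]


edge = EdgeStub(dest=DestNodeStub("mlp_1"), patch_idx=1)
print(edge_score(prune_scores, edge))  # 0.7
```

Keying by the destination node means every edge into the same node shares one tensor, so a single module_name lookup plus an integer index is enough to reach any edge's score.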
Patch Masks
Each DestNode is wrapped by a PatchWrapper that contains a patch_mask PyTorch Parameter. This tensor corresponds exactly to the tensor in the PruneScores object that is indexed by the DestNode's module_name.
When the patch_mode context manager is active, the patch_mask value for each edge interpolates between the default value of the edge in the current forward pass and the value of the edge in patch_src_outs.
AutoCircuit provides helper functions for reading the current mask value of a particular edge.
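AutoCircuit's own helper functions are not reproduced here. As a hypothetical stand-in, the sketch below shows the idea: each wrapper's patch_mask has one entry per incoming edge, laid out like the matching PruneScores tensor, so the same module_name and patch_idx that locate an edge's score also locate its current mask value. DestNodeStub, EdgeStub, PatchWrapperStub and get_mask_value are simplified inventions (plain lists instead of a PyTorch Parameter), not the library's API.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass(frozen=True)
class DestNodeStub:
    """Hypothetical stand-in for a DestNode."""
    module_name: str


@dataclass(frozen=True)
class EdgeStub:
    """Hypothetical stand-in for an Edge."""
    dest: DestNodeStub
    patch_idx: int


@dataclass
class PatchWrapperStub:
    """Hypothetical wrapper holding one mask entry per incoming edge."""
    module_name: str
    patch_mask: List[float] = field(default_factory=list)


def get_mask_value(wrappers: Dict[str, PatchWrapperStub], edge: EdgeStub) -> float:
    """Read the current mask value of `edge` (hypothetical helper)."""
    wrapper = wrappers[edge.dest.module_name]
    return wrapper.patch_mask[edge.patch_idx]


# Invented scores; each wrapper's mask is built with the same length as the
# PruneScores entry for its module, so both are indexed the same way.
prune_scores: Dict[str, List[float]] = {
    "attn_0": [0.10, -0.30],
    "mlp_1": [0.05, 0.70, -0.20],
}
wrappers = {
    name: PatchWrapperStub(name, [0.0] * len(scores))
    for name, scores in prune_scores.items()
}

edge = EdgeStub(DestNodeStub("mlp_1"), patch_idx=1)
wrappers["mlp_1"].patch_mask[edge.patch_idx] = 1.0  # patch this one edge

print(get_mask_value(wrappers, edge))                    # 1.0
print(prune_scores[edge.dest.module_name][edge.patch_idx])  # 0.7
```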
For a more thorough explanation of how patching works, see the announcement post for this library.