4) How it Works
Prune Scores
The most important data structure to understand in AutoCircuit is the
PruneScores object. This type is a map whose keys are the
module_names of the DestNodes
in the model and whose values are tensors holding the attribution
scores for each edge that points to that node.
To access the score for a particular Edge, index
into the tensor for its DestNode with the edge's
patch_idx.
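As a rough illustration of this layout, the sketch below uses nested Python lists in place of tensors so that it runs without PyTorch. The module names, the shapes, and the score_at helper are hypothetical stand-ins rather than AutoCircuit's API; with real tensors, indexing by the patch_idx tuple would do the same job directly.

```python
from typing import Dict, Tuple

# Stand-in for a PruneScores object: DestNode module_name -> scores.
# Hypothetical layout: [head_idx][src_idx] for the first entry,
# [src_idx] for the second. Names and shapes are made up for illustration.
prune_scores: Dict[str, list] = {
    "blocks.0.attn_in": [[0.1, 0.0, 0.7], [0.2, 0.5, 0.0]],
    "blocks.0.mlp_in": [0.3, 0.9, 0.4],
}


def score_at(scores: Dict[str, list], module_name: str, patch_idx: Tuple[int, ...]) -> float:
    """Look up one edge's score (hypothetical helper, not part of the library).

    Walks the nested list one index at a time, which mimics tensor[patch_idx].
    """
    value = scores[module_name]
    for i in patch_idx:
        value = value[i]
    return value


# An edge into head 1 of the attention input, coming from source node 1.
edge_score = score_at(prune_scores, "blocks.0.attn_in", (1, 1))

# An edge into the MLP input, coming from source node 2.
mlp_edge_score = score_at(prune_scores, "blocks.0.mlp_in", (2,))
```

The point of the two-level lookup is that the module_name selects the destination, and the patch_idx selects the individual edge into that destination.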
Patch Masks
Each DestNode is wrapped by a
PatchWrapper that contains a
patch_mask PyTorch Parameter. This tensor corresponds exactly to the entry in the
PruneScores object that is keyed by the
DestNode's module_name.
When the patch_mode context manager is active, the value of the patch_mask for each
edge interpolates between the default value of the edge in the current forward pass
and the value of the edge in patch_src_outs.
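One way to read this interpolation is as a linear blend, where a mask value of 0 leaves the edge at its default value and a mask value of 1 replaces it with the patched value. The exact formula is an assumption here, since the text only says that the mask interpolates between the two, and interpolate_edge is a hypothetical helper that works on plain floats rather than tensors.

```python
def interpolate_edge(default: float, patched: float, mask: float) -> float:
    """Linear blend (assumed form): mask=0 -> default, mask=1 -> patched."""
    return default + mask * (patched - default)


# Three illustrative edges into one destination node.
default_vals = [1.0, 2.0, 4.0]  # edge values in the current forward pass
patched_vals = [3.0, 6.0, 0.0]  # edge values taken from the patch source
masks = [0.0, 1.0, 0.5]         # one mask entry per edge

blended = [
    interpolate_edge(d, p, m)
    for d, p, m in zip(default_vals, patched_vals, masks)
]
# blended == [1.0, 6.0, 2.0]: unpatched, fully patched, halfway between
```

Under this reading, a mask of all zeros reproduces the ordinary forward pass, and fractional values give a continuous relaxation between keeping and patching an edge.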
There are helper functions to access the current mask value for a particular edge.
For a more thorough explanation of how patching works, see the announcement post for this library.