Patch wrapper
auto_circuit.utils.patch_wrapper
Attributes
Classes
PatchWrapperImpl
PatchWrapperImpl(module_name: str, module: Module, head_dim: Optional[int] = None, seq_dim: Optional[int] = None, is_src: bool = False, src_idxs: Optional[slice] = None, is_dest: bool = False, patch_mask: Optional[Tensor] = None, in_srcs: Optional[slice] = None)
Bases: PatchWrapper
PyTorch module that wraps another module, a Node in the computation graph of the model. Implements the abstract PatchWrapper class, which exists to work around circular import issues.

If the wrapped module is a SrcNode, the tensor self.curr_src_outs (a single instance of which is shared by all PatchWrappers in the model) is updated with the output of the wrapped module.

If the wrapped module is a DestNode, the input to the wrapped module is adjusted in order to interpolate the activations of the incoming edges between the default activations (self.curr_src_outs) and the ablated activations (self.patch_src_outs).
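The source and destination roles described above can be illustrated without PyTorch. What follows is a minimal sketch, not the library's implementation: it uses scalar activations in plain Python lists, and the class `ToyWrapper` with its `forward_src` and `forward_dest` methods is hypothetical. It assumes that a destination's input is the sum of its incoming source outputs, and that a mask value of 0 keeps the default activation while 1 substitutes the ablated one.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyWrapper:
    """Hypothetical scalar stand-in for a patch wrapper (not part of auto_circuit)."""

    curr_src_outs: List[float]   # shared list: source outputs from the current pass
    patch_src_outs: List[float]  # shared list: ablated source outputs
    src_idxs: Optional[slice] = None  # slot(s) a source writes its output to
    in_srcs: Optional[slice] = None   # slots of the sources feeding a destination
    patch_mask: List[float] = field(default_factory=list)  # one weight per incoming edge

    def forward_src(self, output: float) -> float:
        # A source records its output in the shared list and passes it on unchanged.
        assert self.src_idxs is not None
        n_slots = len(self.curr_src_outs[self.src_idxs])
        self.curr_src_outs[self.src_idxs] = [output] * n_slots
        return output

    def forward_dest(self, dest_input: float) -> float:
        # A destination shifts each incoming edge from its default activation
        # toward its ablated activation by the edge's mask weight.
        assert self.in_srcs is not None
        curr = self.curr_src_outs[self.in_srcs]
        patch = self.patch_src_outs[self.in_srcs]
        delta = sum(m * (p - c) for m, p, c in zip(self.patch_mask, patch, curr))
        return dest_input + delta


# Both lists are shared by every wrapper, mirroring the single shared tensor.
curr = [0.0, 0.0]
patch = [10.0, 20.0]
src_a = ToyWrapper(curr, patch, src_idxs=slice(0, 1))
src_b = ToyWrapper(curr, patch, src_idxs=slice(1, 2))
dest = ToyWrapper(curr, patch, in_srcs=slice(0, 2), patch_mask=[0.0, 0.5])

a = src_a.forward_src(1.0)
b = src_b.forward_src(2.0)
patched = dest.forward_dest(a + b)  # edge a untouched, edge b moved halfway to 20.0
```

In this sketch the interpolation is applied as an additive correction, `mask * (ablated - default)`, which is equivalent to replacing each edge's contribution with `(1 - mask) * default + mask * ablated` when the destination input is a sum of edge contributions.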
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module_name | str | Name of the wrapped module. | required |
module | Module | The module to wrap. | required |
head_dim | Optional[int] | The dimension along which to split the heads. In TransformerLens | None |
seq_dim | Optional[int] | The sequence dimension of the model. This is the dimension on which new inputs are concatenated. In transformers, this is | None |
is_src | bool | Whether the wrapped module is a SrcNode. | False |
src_idxs | Optional[slice] | The slice of the list of indices of | None |
is_dest | bool | Whether the wrapped module is a DestNode. | False |
patch_mask | Optional[Tensor] | The mask that interpolates between the default activations (self.curr_src_outs) and the ablated activations (self.patch_src_outs). | None |
in_srcs | Optional[slice] | The slice of the list of indices of | None |
Source code in auto_circuit/utils/patch_wrapper.py
Functions
set_mask_batch_size
Set the batch size of the patch mask. Should only be used by the context manager set_mask_batch_size.
The current primary use case is to collect gradients on the patch mask for each input in the batch.
Warning
This is an experimental feature that breaks some parts of the library and should be used with caution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
batch_size | int \| None | The batch size of the patch mask. | required |
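To illustrate why a batch dimension on the mask matters for per-input gradients, here is a small sketch in plain Python. It is not the library's code: the functions `loss` and `numeric_grad` are hypothetical, activations are scalars, and central finite differences stand in for autograd. The assumption is that a mask shared across the batch yields a single gradient summed over all inputs, whereas a mask with one copy per input yields one gradient per input.

```python
from typing import Callable, List


def loss(mask: List[float], default: List[float], ablated: List[float]) -> float:
    # One mask entry per batch element; the loss sums the interpolated activations.
    return sum(d + m * (a - d) for m, d, a in zip(mask, default, ablated))


def numeric_grad(
    f: Callable[[List[float]], float], x: List[float], eps: float = 1e-6
) -> List[float]:
    # Central finite differences, used here in place of autograd.
    grads = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grads.append((f(hi) - f(lo)) / (2 * eps))
    return grads


default = [1.0, 2.0, 3.0]  # default activation for each of three inputs
ablated = [4.0, 4.0, 4.0]  # ablated activation for each input

# Shared mask: a single value repeated for every input gives one summed gradient.
shared_grad = numeric_grad(lambda m: loss(m * 3, default, ablated), [0.0])

# Batched mask: one entry per input gives a separate gradient for each input.
batched_grad = numeric_grad(lambda m: loss(m, default, ablated), [0.0] * 3)
```

Here the single shared gradient equals the sum of the three per-input gradients, so the per-input information is only recoverable when the mask carries its own batch dimension.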