
Patch wrapper

auto_circuit.utils.patch_wrapper

Classes

PatchWrapperImpl

PatchWrapperImpl(module_name: str, module: Module, head_dim: Optional[int] = None, seq_dim: Optional[int] = None, is_src: bool = False, src_idxs: Optional[slice] = None, is_dest: bool = False, patch_mask: Optional[Tensor] = None, in_srcs: Optional[slice] = None)

Bases: PatchWrapper

PyTorch module that wraps another module, a Node in the computation graph of the model. Implements the abstract PatchWrapper class, which exists to work around circular import issues.

If the wrapped module is a SrcNode, the tensor self.curr_src_outs (a single instance of which is shared by all PatchWrappers in the model) is updated with the output of the wrapped module.

If the wrapped module is a DestNode, the input to the wrapped module is adjusted in order to interpolate the activations of the incoming edges between the default activations (self.curr_src_outs) and the ablated activations (self.patch_src_outs).

Note

Most PatchWrappers are both SrcNodes and DestNodes.
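
To make the interpolation concrete, here is a minimal sketch, not the library's actual forward pass, of how a destination's incoming edge activations can be mixed between the default activations and the ablated activations with a patch mask. The tensor shapes, the function name, and the final sum over sources are illustrative assumptions.

import torch as t

def interpolate_incoming_edges(
    curr_src_outs: t.Tensor,   # [n_srcs, batch, seq, d_model] default activations
    patch_src_outs: t.Tensor,  # [n_srcs, batch, seq, d_model] ablated activations
    patch_mask: t.Tensor,      # [n_srcs] values in [0, 1]; 0 = default, 1 = ablated
) -> t.Tensor:
    # Broadcast the per-edge mask over the batch/seq/feature dimensions.
    mask = patch_mask.reshape(-1, *([1] * (curr_src_outs.ndim - 1)))
    mixed = (1 - mask) * curr_src_outs + mask * patch_src_outs
    # Illustrative: the destination receives the sum of its incoming edges.
    return mixed.sum(dim=0)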

Parameters:

module_name (str, required): Name of the wrapped module.

module (Module, required): The module to wrap.

head_dim (Optional[int], default None): The dimension along which to split the heads. In TransformerLens HookedTransformers this is 2 because the activations have shape [batch, seq_len, n_heads, head_dim].

seq_dim (Optional[int], default None): The sequence dimension of the model. This is the dimension on which new inputs are concatenated. In transformers, this is 1 because the activations are of shape [batch_size, seq_len, hidden_dim].

is_src (bool, default False): Whether the wrapped module is a SrcNode.

src_idxs (Optional[slice], default None): The slice of the list of indices of SrcNodes which output from this module. This is used to slice the shared curr_src_outs tensor when updating the activations of the current forward pass.

is_dest (bool, default False): Whether the wrapped module is a DestNode.

patch_mask (Optional[Tensor], default None): The mask that interpolates between the default activations (curr_src_outs) and the ablation activations (patch_src_outs).

in_srcs (Optional[slice], default None): The slice of the list of indices of SrcNodes which input to this module. This is used to slice the shared curr_src_outs tensor and the shared patch_src_outs tensor when interpolating the activations of the incoming edges.
Source code in auto_circuit/utils/patch_wrapper.py
def __init__(
    self,
    module_name: str,
    module: t.nn.Module,
    head_dim: Optional[int] = None,
    seq_dim: Optional[int] = None,
    is_src: bool = False,
    src_idxs: Optional[slice] = None,
    is_dest: bool = False,
    patch_mask: Optional[t.Tensor] = None,
    in_srcs: Optional[slice] = None,
):
    super().__init__()
    self.module_name: str = module_name
    self.module: t.nn.Module = module
    self.head_dim: Optional[int] = head_dim
    self.seq_dim: Optional[int] = seq_dim
    self.curr_src_outs: Optional[t.Tensor] = None
    self.in_srcs: Optional[slice] = in_srcs

    self.is_src = is_src
    if self.is_src:
        assert src_idxs is not None
        self.src_idxs: slice = src_idxs

    self.is_dest = is_dest
    if self.is_dest:
        assert patch_mask is not None
        self.patch_mask: t.nn.Parameter = t.nn.Parameter(patch_mask)
        self.patch_src_outs: Optional[t.Tensor] = None
        self.mask_fn: MaskFn = None
        self.dropout_layer: t.nn.Module = t.nn.Dropout(p=0.0)
    self.patch_mode = False
    self.batch_size = None

    assert head_dim is None or seq_dim is None or head_dim > seq_dim
    dims = range(1, max(head_dim if head_dim else 2, seq_dim if seq_dim else 2))
    self.dims = " ".join(["seq" if i == seq_dim else f"d{i}" for i in dims])
Functions
set_mask_batch_size
set_mask_batch_size(batch_size: int | None)

Set the batch size of the patch mask. This should only be used via the set_mask_batch_size context manager in auto_circuit.utils.graph_utils.

The current primary use case is to collect gradients on the patch mask for each input in the batch.

Warning

This is an experimental feature that breaks some parts of the library and should be used with caution.

Parameters:

batch_size (int | None, required): The batch size of the patch mask. Pass None to remove the batch dimension.
Source code in auto_circuit/utils/patch_wrapper.py
def set_mask_batch_size(self, batch_size: int | None):
    """
    Set the batch size of the patch mask. Should only be used by context manager
    [`set_mask_batch_size`][auto_circuit.utils.graph_utils.set_mask_batch_size]

    The current primary use case is to collect gradients on the patch mask for
    each input in the batch.

    Warning:
        This is an experimental feature that breaks some parts of the library and
        should be used with caution.

    Args:
        batch_size: The batch size of the patch mask.
    """
    if batch_size is None and self.batch_size is None:
        return
    if batch_size is None:  # removing batch dim
        self.patch_mask = t.nn.Parameter(self.patch_mask[0].clone())
    elif self.batch_size is None:  # adding batch_dim
        self.patch_mask = t.nn.Parameter(
            self.patch_mask.repeat(batch_size, *((1,) * self.patch_mask.ndim))
        )
    elif self.batch_size != batch_size:  # modifying batch dim
        self.patch_mask = t.nn.Parameter(
            self.patch_mask[0]
            .clone()
            .repeat(batch_size, *((1,) * self.patch_mask.ndim))
        )
    self.batch_size = batch_size
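
As a sketch of what the batch-size change does to the mask tensor, mirroring the repeat and indexing calls in the source above with illustrative shapes, adding a leading batch dimension is what lets a separate gradient be accumulated on the mask for each input in the batch.

import torch as t

mask = t.zeros(16, 12)                         # e.g. [seq_len, n_in_srcs], illustrative
batched = mask.repeat(8, *((1,) * mask.ndim))  # add batch dim -> [8, 16, 12]
assert batched.shape == (8, 16, 12)
restored = batched[0].clone()                  # remove batch dim -> [16, 12]
assert restored.shape == (16, 12)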
