Patchable model

auto_circuit.utils.patchable_model

Classes

PatchableModel

PatchableModel(nodes: Set[Node], srcs: Set[SrcNode], dests: Set[DestNode], edge_dict: Dict[int | None, List[Edge]], edges: Set[Edge], seq_dim: int, seq_len: Optional[int], wrappers: Set[PatchWrapperImpl], src_wrappers: Set[PatchWrapperImpl], dest_wrappers: Set[PatchWrapperImpl], out_slice: Tuple[slice | int, ...], is_factorized: bool, is_transformer: bool, separate_qkv: Optional[bool], kv_caches: Tuple[Optional[HookedTransformerKeyValueCache], ...], wrapped_model: Module)

Bases: Module

A model that can be ablated along individual edges in its computation graph.

This class has many of the same methods and attributes as TransformerLens' HookedTransformer. These are thin wrappers that pass through to the implementation in the wrapped model. The wrapped methods and attributes are:

  • forward
  • run_with_cache
  • add_hook
  • reset_hooks
  • cfg
  • tokenizer
  • input_to_embed
  • blocks
  • to_tokens
  • to_str_tokens
  • to_string
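
For example, a HookedTransformer can be wrapped with the patchable_model utility and then driven much like the original model. A minimal sketch, assuming the patchable_model signature shown in auto_circuit's README:

import torch as t
from transformer_lens import HookedTransformer
from auto_circuit.utils.graph_utils import patchable_model

device = t.device("cuda" if t.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gpt2", device=device)

# Wrap the model so that edges in its computation graph can be ablated.
model = patchable_model(
    model,
    factorized=True,          # enables edge ablation, not just node ablation
    slice_output="last_seq",  # out_slice: score only the final token position
    separate_qkv=True,        # separate Q, K and V destination nodes
    device=device,
)

# The TransformerLens pass-throughs still work as before.
tokens = model.to_tokens("The patchable model wraps the original.")
logits = model(tokens)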

Parameters:

  • nodes (Set[Node], required): The set of all nodes in the computation graph.
  • srcs (Set[SrcNode], required): The set of all source nodes in the computation graph.
  • dests (Set[DestNode], required): The set of all destination nodes in the computation graph.
  • edge_dict (Dict[int | None, List[Edge]], required): A dictionary mapping sequence positions to the edges at each position.
  • edges (Set[Edge], required): The set of all edges in the computation graph.
  • seq_dim (int, required): The sequence dimension of the model, i.e. the dimension along which new inputs are concatenated. In transformers this is 1, because activations have shape [batch_size, seq_len, hidden_dim].
  • seq_len (Optional[int], required): The sequence length of the model inputs. If None, all token positions are ablated simultaneously.
  • wrappers (Set[PatchWrapperImpl], required): The set of all PatchWrappers in the model.
  • src_wrappers (Set[PatchWrapperImpl], required): The set of all PatchWrappers that are source nodes.
  • dest_wrappers (Set[PatchWrapperImpl], required): The set of all PatchWrappers that are destination nodes.
  • out_slice (Tuple[slice | int, ...], required): The index or slice of the model output to be considered for the task.
  • is_factorized (bool, required): Whether the model is factorized, which enables edge ablation. If False, only node ablation is possible.
  • is_transformer (bool, required): Whether the model is a transformer.
  • separate_qkv (Optional[bool], required): Whether the model has separate query, key, and value inputs. Only used for transformers.
  • kv_caches (Tuple[Optional[HookedTransformerKeyValueCache], ...], required): The past key-value caches of the transformer, stored internally as a dictionary keyed by batch size. Only used for transformers.
  • wrapped_model (Module, required): The model being wrapped and made patchable.
Source code in auto_circuit/utils/patchable_model.py
def __init__(
    self,
    nodes: Set[Node],
    srcs: Set[SrcNode],
    dests: Set[DestNode],
    edge_dict: Dict[int | None, List[Edge]],
    edges: Set[Edge],
    seq_dim: int,
    seq_len: Optional[int],
    wrappers: Set[PatchWrapperImpl],
    src_wrappers: Set[PatchWrapperImpl],
    dest_wrappers: Set[PatchWrapperImpl],
    out_slice: Tuple[slice | int, ...],
    is_factorized: bool,
    is_transformer: bool,
    separate_qkv: Optional[bool],
    kv_caches: Tuple[Optional[HookedTransformerKeyValueCache], ...],
    wrapped_model: t.nn.Module,
) -> None:
    super().__init__()
    self.nodes = nodes
    self.srcs = srcs
    self.dests = dests
    self.edge_dict = edge_dict
    self.edges = edges
    self.n_edges = len(edges)
    # Index edges by sequence position and then by name for fast lookup.
    self.edge_name_dict = defaultdict(dict)
    for edge in edges:
        self.edge_name_dict[edge.seq_idx][edge.name] = edge
    self.seq_dim = seq_dim
    self.seq_len = seq_len
    self.wrappers = wrappers
    self.src_wrappers = src_wrappers
    self.dest_wrappers = dest_wrappers
    # Keep references to the patch masks stored on the destination wrappers.
    self.patch_masks = {}
    for dest_wrapper in self.dest_wrappers:
        self.patch_masks[dest_wrapper.module_name] = dest_wrapper.patch_mask
    self.out_slice = out_slice
    self.is_factorized = is_factorized
    self.is_transformer = is_transformer
    if is_transformer:
        assert separate_qkv is not None
    self.separate_qkv = separate_qkv
    # Collapse the kv_caches tuple into a dict keyed by batch size, or None if empty.
    if all([kv_cache is None for kv_cache in kv_caches]) or len(kv_caches) == 0:
        self.kv_caches = None
    else:
        self.kv_caches = {}
        for kv_cache in kv_caches:
            if kv_cache is not None:
                batch_size = kv_cache.previous_attention_mask.shape[0]
                self.kv_caches[batch_size] = kv_cache
    self.wrapped_model = wrapped_model
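
Note that the kv_caches tuple is collapsed into a dictionary keyed by batch size, so a cache is only found for inputs whose batch size exactly matches one of the caches. A sketch of building caches for two batch sizes, assuming TransformerLens's HookedTransformerKeyValueCache.init_cache constructor:

from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache

# One cache per batch size that the model will be run with.
kv_caches = tuple(
    HookedTransformerKeyValueCache.init_cache(model.cfg, device, batch_size=bs)
    for bs in (8, 16)
)
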
Functions
circuit_prune_scores
circuit_prune_scores(edges: Optional[Collection[Edge | str]] = None, edge_dict: Optional[Dict[Edge, float] | Dict[str, float]] = None, bool: bool = False) -> PruneScores

Convert a set of edges to a corresponding PruneScores object.

Parameters:

  • edges (Optional[Collection[Edge | str]], default None): The set of edges or edge names to convert to prune scores.
  • edge_dict (Optional[Dict[Edge, float] | Dict[str, float]], default None): A dictionary mapping edges or edge names to the prune score to assign to each.
  • bool (bool, default False): Whether to return the prune scores as boolean type tensors.

Returns:

  • PruneScores: The prune scores corresponding to the set of edges.

Source code in auto_circuit/utils/patchable_model.py
def circuit_prune_scores(
    self,
    edges: Optional[Collection[Edge | str]] = None,
    edge_dict: Optional[Dict[Edge, float] | Dict[str, float]] = None,
    bool: bool = False,
) -> PruneScores:
    """
    Convert a set of edges to a corresponding
    [`PruneScores`][auto_circuit.types.PruneScores] object.

    Args:
        edges: The set of edges or edge names to convert to prune scores.
        edge_dict: A dictionary mapping edges or edge names to prune scores.
        bool: Whether to return the prune scores as boolean type tensors.

    Returns:
        The prune scores corresponding to the set of edges.
    """
    ps = self.new_prune_scores()
    assert not (edges is None and edge_dict is None), "Must specify edges"

    # TODO: Raise an error if one of the edge names doesn't exist.
    if edges is not None:
        for edge in self.edges:
            if edge in edges or edge.name in edges:
                ps[edge.dest.module_name][edge.patch_idx] = 1.0
    else:
        assert edge_dict is not None
        for e in self.edges:
            if e in edge_dict.keys():
                ps[e.dest.module_name][e.patch_idx] = edge_dict[e]  # type: ignore
            if e.name in edge_dict.keys():
                ps[e.dest.module_name][e.patch_idx] = edge_dict[e.name]  # type: ignore
    if bool:
        return dict([(mod, mask.bool()) for (mod, mask) in ps.items()])
    else:
        return ps
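
A usage sketch (the edge names below are hypothetical; inspect model.edges for the names in your graph):

# Mark a circuit of edges with score 1.0, as boolean masks.
circuit = ["Resid Start->MLP 0", "MLP 0->Resid End"]  # hypothetical names
ps = model.circuit_prune_scores(edges=circuit, bool=True)

# Or assign an explicit score to each edge via edge_dict.
ps = model.circuit_prune_scores(edge_dict={"MLP 0->Resid End": 0.5})
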
current_patch_masks_as_prune_scores
current_patch_masks_as_prune_scores() -> PruneScores

Convert the current patch masks to a corresponding PruneScores object.

Returns:

  • PruneScores: The prune scores corresponding to the current patch masks.

Source code in auto_circuit/utils/patchable_model.py
def current_patch_masks_as_prune_scores(self) -> PruneScores:
    """
    Convert the current patch masks to a corresponding
    [`PruneScores`][auto_circuit.types.PruneScores] object.

    Returns:
        The prune scores corresponding to the current patch masks.
    """
    return dict([(mod, mask.data) for (mod, mask) in self.patch_masks.items()])
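
For example, after a mask-based circuit discovery method has optimized the patch masks in place, the learned values can be read off as prune scores:

ps = model.current_patch_masks_as_prune_scores()
for mod_name, scores in ps.items():
    print(mod_name, scores.shape)  # one score tensor per destination module
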
forward
forward(*args: Any, **kwargs: Any) -> Any

Wrapper around the forward method of the wrapped model. If kv_caches is not None and past_kv_cache is not already given as a keyword argument, the cache matching the input's batch size is passed to the wrapped model.

Source code in auto_circuit/utils/patchable_model.py
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """
    Wrapper around the forward method of the wrapped model. If `kv_caches` is not
    `None`, the KV cache is passed to the wrapped model as a keyword argument.
    """
    if self.kv_caches is None or "past_kv_cache" in kwargs:
        return self.wrapped_model(*args, **kwargs)
    else:
        batch_size = args[0].shape[0]
        kv = self.kv_caches[batch_size]
        return self.wrapped_model(*args, past_kv_cache=kv, **kwargs)
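
Because the cache lookup is keyed by batch size, a call only succeeds if a cache was built for the input's exact batch size. A sketch of the dispatch behavior:

# Suppose kv_caches were built for batch sizes 8 and 16.
tokens = model.to_tokens(["a prompt"] * 8)   # batch size 8
out = model(tokens)                          # the batch-size-8 cache is passed automatically
out = model(tokens, past_kv_cache=None)      # an explicit kwarg bypasses the lookup
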
input_to_embed
input_to_embed(*args: Any, **kwargs: Any) -> Any

Wrapper around the input_to_embed method of the wrapped TransformerLens HookedTransformer. If kv_caches is not None, the KV cache is passed to the wrapped model as a keyword argument.

Source code in auto_circuit/utils/patchable_model.py
def input_to_embed(self, *args: Any, **kwargs: Any) -> Any:
    """
    Wrapper around the `input_to_embed` method of the wrapped TransformerLens
    `HookedTransformer`. If `kv_caches` is not `None`, the KV cache is passed to the
    wrapped model as a keyword argument.
    """
    if self.kv_caches is None:
        return self.wrapped_model.input_to_embed(*args, **kwargs)
    else:
        batch_size = args[0].shape[0]
        kv = self.kv_caches[batch_size]
        return self.wrapped_model.input_to_embed(*args, past_kv_cache=kv, **kwargs)
new_prune_scores
new_prune_scores(init_val: float = 0.0) -> PruneScores

A new PruneScores instance with the same keys and shapes as the current patch masks, initialized to init_val.

Parameters:

  • init_val (float, default 0.0): The initial value to set all the prune scores to.

Returns:

  • PruneScores: A new PruneScores instance.

Source code in auto_circuit/utils/patchable_model.py
def new_prune_scores(self, init_val: float = 0.0) -> PruneScores:
    """
    A new [`PruneScores`][auto_circuit.types.PruneScores] instance with the same
    keys and shapes as the current patch masks, initialized to `init_val`.

    Args:
        init_val: The initial value to set all the prune scores to.

    Returns:
        A new [`PruneScores`][auto_circuit.types.PruneScores] instance.
    """
    prune_scores: PruneScores = {}
    for (mod_name, mask) in self.patch_masks.items():
        prune_scores[mod_name] = t.full_like(mask.data, init_val)
    return prune_scores
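
A sketch of typical use: initialize a fresh score tensor per destination module, then accumulate scores into it edge by edge:

ps = model.new_prune_scores(init_val=0.0)
for edge in model.edges:
    # Hypothetical scoring rule; real methods compute an attribution here.
    ps[edge.dest.module_name][edge.patch_idx] += 1.0
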
run_with_cache
run_with_cache(*args: Any, **kwargs: Any) -> Any

Wrapper around the run_with_cache method of the wrapped TransformerLens HookedTransformer. If kv_caches is not None, the KV cache is passed to the wrapped model as a keyword argument.

Source code in auto_circuit/utils/patchable_model.py
def run_with_cache(self, *args: Any, **kwargs: Any) -> Any:
    """
    Wrapper around the `run_with_cache` method of the wrapped TransformerLens
    `HookedTransformer`. If `kv_caches` is not `None`, the KV cache is passed to the
    wrapped model as a keyword argument.
    """
    if self.kv_caches is None:
        return self.wrapped_model.run_with_cache(*args, **kwargs)
    else:
        batch_size = args[0].shape[0]
        kv = self.kv_caches[batch_size]
        return self.wrapped_model.run_with_cache(*args, past_kv_cache=kv, **kwargs)
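
As with TransformerLens, run_with_cache returns the model output together with an activation cache. A brief sketch:

logits, cache = model.run_with_cache(model.to_tokens("hello"))
print(list(cache.keys())[:5])  # hook-point names from the wrapped model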