# Patchable model

`auto_circuit.utils.patchable_model`

## Classes

### PatchableModel

```python
PatchableModel(
    nodes: Set[Node],
    srcs: Set[SrcNode],
    dests: Set[DestNode],
    edge_dict: Dict[int | None, List[Edge]],
    edges: Set[Edge],
    seq_dim: int,
    seq_len: Optional[int],
    wrappers: Set[PatchWrapperImpl],
    src_wrappers: Set[PatchWrapperImpl],
    dest_wrappers: Set[PatchWrapperImpl],
    out_slice: Tuple[slice | int, ...],
    is_factorized: bool,
    is_transformer: bool,
    separate_qkv: Optional[bool],
    kv_caches: Tuple[Optional[HookedTransformerKeyValueCache], ...],
    wrapped_model: Module,
)
```
Bases: `Module`

A model that can be ablated along individual edges in its computation graph.

This class has many of the same methods and attributes as TransformerLens's
`HookedTransformer`. These are thin wrappers that pass through to the
implementation in the wrapped model:

- `forward`
- `run_with_cache`
- `add_hook`
- `reset_hooks`
- `cfg`
- `tokenizer`
- `input_to_embed`
- `blocks`
- `to_tokens`
- `to_str_tokens`
- `to_string`
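The pass-through pattern can be sketched in a few lines. Everything below is illustrative (toy class and method names, not the library's implementation): attribute and method access on the patchable wrapper simply delegates to the wrapped model.

```python
class ToyModel:
    """Stand-in for the wrapped HookedTransformer-like model."""

    def __init__(self):
        self.cfg = {"n_layers": 2}

    def forward(self, x):
        return x * 2

    def to_tokens(self, text):
        return list(text)


class ToyPatchable:
    """Illustrative stand-in for PatchableModel's delegation."""

    def __init__(self, wrapped_model):
        self.wrapped_model = wrapped_model

    @property
    def cfg(self):
        # Attribute access passes straight through to the wrapped model.
        return self.wrapped_model.cfg

    def forward(self, x):
        # Method calls likewise delegate unchanged.
        return self.wrapped_model.forward(x)

    def to_tokens(self, text):
        return self.wrapped_model.to_tokens(text)


model = ToyPatchable(ToyModel())
print(model.forward(3))       # 6
print(model.cfg["n_layers"])  # 2
```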
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `nodes` | `Set[Node]` | The set of all nodes in the computation graph. | *required* |
| `srcs` | `Set[SrcNode]` | The set of all source nodes in the computation graph. | *required* |
| `dests` | `Set[DestNode]` | The set of all destination nodes in the computation graph. | *required* |
| `edge_dict` | `Dict[int \| None, List[Edge]]` | A dictionary mapping sequence positions to the edges at that position. | *required* |
| `edges` | `Set[Edge]` | The set of all edges in the computation graph. | *required* |
| `seq_dim` | `int` | The sequence dimension of the model. This is the dimension on which new inputs are concatenated. In transformers, this is `1`. | *required* |
| `seq_len` | `Optional[int]` | The sequence length of the model inputs. If `None`, edges are not tracked separately at each token position. | *required* |
| `wrappers` | `Set[PatchWrapperImpl]` | The set of all `PatchWrapper` modules in the model. | *required* |
| `src_wrappers` | `Set[PatchWrapperImpl]` | The set of all `PatchWrapper` modules that wrap source nodes. | *required* |
| `dest_wrappers` | `Set[PatchWrapperImpl]` | The set of all `PatchWrapper` modules that wrap destination nodes. | *required* |
| `out_slice` | `Tuple[slice \| int, ...]` | Specifies the index/slice of the output of the model to be considered for the task. | *required* |
| `is_factorized` | `bool` | Whether the model is factorized, which enables Edge Ablation. Otherwise, only Node Ablation is possible. | *required* |
| `is_transformer` | `bool` | Whether the model is a transformer. | *required* |
| `separate_qkv` | `Optional[bool]` | Whether the model has separate query, key, and value inputs. Only used for transformers. | *required* |
| `kv_caches` | `Tuple[Optional[HookedTransformerKeyValueCache], ...]` | The past key-value caches of the transformer, keyed by batch size. Only used for transformers. | *required* |
| `wrapped_model` | `Module` | The model being wrapped and made patchable. | *required* |
Source code in auto_circuit/utils/patchable_model.py
## Functions

### circuit_prune_scores

```python
circuit_prune_scores(
    edges: Optional[Collection[Edge | str]] = None,
    edge_dict: Optional[Dict[Edge, float] | Dict[str, float]] = None,
    bool: bool = False,
) -> PruneScores
```
Convert a set of edges to a corresponding `PruneScores` object.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `edges` | `Optional[Collection[Edge \| str]]` | The set of edges or edge names to convert to prune scores. | `None` |
| `bool` | `bool` | Whether to return the prune scores as boolean-typed tensors. | `False` |

Returns:

| Type | Description |
|---|---|
| `PruneScores` | The prune scores corresponding to the set of edges. |
Source code in auto_circuit/utils/patchable_model.py
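Conceptually, a `PruneScores` object maps each destination module to a tensor of scores, one score per incoming edge. Converting a set of edges means marking the corresponding entries. A minimal self-contained sketch of that idea (plain dicts and lists stand in for the library's module names and tensors; none of the names below are the real API):

```python
def circuit_prune_scores_sketch(mask_shapes, edges, as_bool=False):
    """Toy version: mask_shapes maps a destination module name to its
    number of incoming edges; edges is a set of (dest_module, edge_index)
    pairs standing in for Edge objects."""
    # Start from all-zero scores, one slot per incoming edge.
    scores = {name: [0.0] * n for name, n in mask_shapes.items()}
    # Mark every edge in the circuit with a score of 1.0.
    for dest, idx in edges:
        scores[dest][idx] = 1.0
    if as_bool:
        # Optionally convert to boolean membership flags.
        scores = {name: [v != 0.0 for v in vals] for name, vals in scores.items()}
    return scores


ps = circuit_prune_scores_sketch({"blocks.0.attn": 3}, {("blocks.0.attn", 1)})
print(ps)  # {'blocks.0.attn': [0.0, 1.0, 0.0]}
```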
### current_patch_masks_as_prune_scores

```python
current_patch_masks_as_prune_scores() -> PruneScores
```

Convert the current patch masks to a corresponding `PruneScores` object.

Returns:

| Type | Description |
|---|---|
| `PruneScores` | The prune scores corresponding to the current patch masks. |
Source code in auto_circuit/utils/patchable_model.py
### forward

Wrapper around the forward method of the wrapped model. If `kv_caches` is not
`None`, the KV cache is passed to the wrapped model as a keyword argument.
Source code in auto_circuit/utils/patchable_model.py
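The conditional keyword-argument handling can be sketched as follows. This is a toy illustration, not the library's code; `past_kv_cache` is assumed here as the wrapped model's cache keyword (as in TransformerLens), and `toy_forward` is a hypothetical stand-in:

```python
def forward_sketch(wrapped_forward, x, kv_cache=None):
    # Only pass the cache through when one actually exists, so the
    # wrapped model's default behaviour is untouched otherwise.
    if kv_cache is not None:
        return wrapped_forward(x, past_kv_cache=kv_cache)
    return wrapped_forward(x)


def toy_forward(x, past_kv_cache=None):
    # Records whether a cache was received alongside the "output".
    return (x, past_kv_cache is not None)


print(forward_sketch(toy_forward, 1))                 # (1, False)
print(forward_sketch(toy_forward, 1, kv_cache="kv"))  # (1, True)
```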
### input_to_embed

Wrapper around the `input_to_embed` method of the wrapped TransformerLens
`HookedTransformer`. If `kv_caches` is not `None`, the KV cache is passed to
the wrapped model as a keyword argument.
Source code in auto_circuit/utils/patchable_model.py
### new_prune_scores

```python
new_prune_scores(init_val: float = 0.0) -> PruneScores
```

A new `PruneScores` instance with the same keys and shapes as the current
patch masks, initialized to `init_val`.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `init_val` | `float` | The initial value to set all the prune scores to. | `0.0` |

Returns:

| Type | Description |
|---|---|
| `PruneScores` | A new `PruneScores` instance initialized to `init_val`. |
Source code in auto_circuit/utils/patchable_model.py
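The idea is a fresh score container mirroring the patch-mask shapes. A minimal sketch (plain lists stand in for the tensors the real `PruneScores` holds; the function name is illustrative):

```python
def new_prune_scores_sketch(mask_shapes, init_val=0.0):
    """Toy version: mask_shapes maps a destination module name to its
    patch-mask length; return one same-shaped entry per mask, filled
    with init_val."""
    return {name: [init_val] * n for name, n in mask_shapes.items()}


ps = new_prune_scores_sketch({"blocks.0.attn": 3, "blocks.1.mlp": 2}, init_val=0.5)
print(ps)  # {'blocks.0.attn': [0.5, 0.5, 0.5], 'blocks.1.mlp': [0.5, 0.5]}
```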
### run_with_cache

Wrapper around the `run_with_cache` method of the wrapped TransformerLens
`HookedTransformer`. If `kv_caches` is not `None`, the KV cache is passed to
the wrapped model as a keyword argument.