
Types

auto_circuit.types

Attributes

AblationMeasurements module-attribute

AblationMeasurements = Dict[AblationType, PruneMetricMeasurements]

A dictionary mapping from AblationTypes to PruneMetricMeasurements.

AlgoKey module-attribute

AlgoKey = str

A string that uniquely identifies a PruneAlgo.

AlgoMeasurements module-attribute

AlgoMeasurements = Dict[AlgoKey, Measurements]

A dictionary mapping from AlgoKeys to Measurements.

AlgoPruneScores module-attribute

AlgoPruneScores = Dict[AlgoKey, PruneScores]

A dictionary mapping from AlgoKeys to PruneScores.

AutoencoderInput module-attribute

AutoencoderInput = Literal['mlp_post_act', 'resid_delta_mlp', 'resid']

The activation in each layer that is replaced by an autoencoder reconstruction.

BatchOutputs module-attribute

BatchOutputs = Dict[BatchKey, Tensor]

A dictionary mapping from BatchKeys to output tensors.

CircuitOutputs module-attribute

CircuitOutputs = Dict[int, BatchOutputs]

A dictionary mapping from the number of pruned edges to BatchOutputs.

MaskFn module-attribute

MaskFn = Optional[Literal['hard_concrete', 'sigmoid']]

Determines how mask values are used to ablate edges.

If None, the mask value is used directly to interpolate between the original and ablated values, i.e., 0.0 means the original value and 1.0 means the ablated value.

If "hard_concrete", the mask value parameterizes a "HardConcrete" distribution (Louizos et al., 2017; Cao et al., 2021), which is sampled to interpolate between the original and ablated values. The HardConcrete distribution allows us to optimize a continuous variable while still allowing the mask to take values exactly equal to 0.0 or 1.0. The stochasticity also helps to reduce problems from vanishing gradients.

If "sigmoid", the mask value is passed through a sigmoid function and then used to interpolate between the original and ablated values.
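The three options can be sketched for a single scalar activation as follows. This is a minimal illustration, not the library's implementation: the helper names are hypothetical, treating the mask value as the HardConcrete location parameter `log_alpha` is an assumption, and the constants (beta = 2/3, gamma = -0.1, zeta = 1.1) are the commonly used defaults from Louizos et al., not values taken from auto_circuit.

```python
import math
import random
from typing import Optional


def interpolate(original: float, ablated: float, weight: float) -> float:
    # weight 0.0 keeps the original value; weight 1.0 gives the ablated value
    return original + weight * (ablated - original)


def sigmoid(x: float) -> float:
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sample_hard_concrete(
    log_alpha: float,
    beta: float = 2 / 3,  # ASSUMED default constants (Louizos et al.)
    gamma: float = -0.1,
    zeta: float = 1.1,
    rng: Optional[random.Random] = None,
) -> float:
    rng = rng or random.Random()
    # Draw u in (0, 1), clipped away from 0 and 1 so the logs stay finite
    u = min(max(rng.random(), 1e-6), 1.0 - 1e-6)
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / beta)
    # Stretch to (gamma, zeta), then clamp so exactly 0.0 and 1.0 can occur
    return min(1.0, max(0.0, s * (zeta - gamma) + gamma))


def apply_mask(original, ablated, mask_value, mask_fn=None, rng=None) -> float:
    if mask_fn is None:
        weight = mask_value  # mask value used directly
    elif mask_fn == "sigmoid":
        weight = sigmoid(mask_value)
    elif mask_fn == "hard_concrete":
        weight = sample_hard_concrete(mask_value, rng=rng)  # stochastic
    else:
        raise ValueError(f"unknown mask_fn: {mask_fn!r}")
    return interpolate(original, ablated, weight)


print(apply_mask(2.0, 0.0, 0.25))            # 1.5
print(apply_mask(2.0, 0.0, 0.0, "sigmoid"))  # 1.0
```

With a very large or very small `log_alpha`, the clamped HardConcrete sample lands exactly on 1.0 or 0.0, which is the property the text above describes.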

Measurements module-attribute

Measurements = List[Tuple[int | float, int | float]]

List of X and Y measurements. X is often the number of edges in the circuit and Y is often some measure of faithfulness.

OutputSlice module-attribute

OutputSlice = Optional[Literal['last_seq', 'not_first_seq']]

The slice of the output that is considered for task evaluation. For example, "last_seq" will consider the last token's output in transformer models.
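For a transformer output of shape [batch, seq, vocab], the slices can be sketched with nested Python lists standing in for tensors. The helper `slice_output` is hypothetical, and reading "not_first_seq" as "every token position except the first" is inferred from the name.

```python
from typing import List, Optional

Output = List[List[List[float]]]  # [batch][seq][vocab], a stand-in for a Tensor


def slice_output(output: Output, output_slice: Optional[str]) -> list:
    if output_slice is None:
        return output  # keep every token position
    if output_slice == "last_seq":
        return [seq[-1] for seq in output]  # shape [batch][vocab]
    if output_slice == "not_first_seq":
        return [seq[1:] for seq in output]  # shape [batch][seq - 1][vocab]
    raise ValueError(f"unknown output_slice: {output_slice!r}")


out = [[[0.1, 0.9], [0.4, 0.6], [0.7, 0.3]]]  # batch=1, seq=3, vocab=2
print(slice_output(out, "last_seq"))       # [[0.7, 0.3]]
print(slice_output(out, "not_first_seq"))  # [[[0.4, 0.6], [0.7, 0.3]]]
```

With a PyTorch tensor of the same shape, the equivalent slices would be `output[:, -1]` and `output[:, 1:]`.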

PruneMetricKey module-attribute

PruneMetricKey = str

A string that uniquely identifies a PruneMetric.

PruneMetricMeasurements module-attribute

PruneMetricMeasurements = Dict[PruneMetricKey, TaskMeasurements]

A dictionary mapping from PruneMetricKeys to TaskMeasurements.

PruneScores module-attribute

PruneScores = Dict[str, Tensor]

A dictionary mapping from the module names of DestNodes to edge scores. The edge scores are stored as a tensor in which each value is the score of one incoming Edge.
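A PruneScores-like dictionary can be sketched with nested lists in place of tensors. The module names and scores below are made up for illustration. Flattening the dictionary yields one (module name, index, score) entry per incoming edge, which is a convenient form for ranking edges by score.

```python
from typing import Dict, Iterator, List, Tuple

Scores = Dict[str, list]  # stand-in for PruneScores: module name -> nested scores


def iter_scores(value, prefix: Tuple[int, ...] = ()) -> Iterator[Tuple[Tuple[int, ...], float]]:
    # Recursively yield (index tuple, score) for every scalar in a nested list
    if isinstance(value, list):
        for i, item in enumerate(value):
            yield from iter_scores(item, prefix + (i,))
    else:
        yield prefix, value


def ranked_edges(prune_scores: Scores) -> List[Tuple[str, Tuple[int, ...], float]]:
    entries = [
        (module_name, idx, score)
        for module_name, scores in prune_scores.items()
        for idx, score in iter_scores(scores)
    ]
    return sorted(entries, key=lambda e: e[2], reverse=True)


# Hypothetical scores: an attention module with 2 heads x 3 incoming edges,
# and an MLP module with 4 incoming edges.
prune_scores = {
    "blocks.1.attn": [[0.2, 0.9, 0.1], [0.5, 0.0, 0.3]],
    "blocks.1.mlp": [0.4, 0.8, 0.05, 0.6],
}
edges = ranked_edges(prune_scores)
print(len(edges))  # 10
print(edges[:2])   # [('blocks.1.attn', (0, 1), 0.9), ('blocks.1.mlp', (1,), 0.8)]
```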

TaskKey module-attribute

TaskKey = str

A string that uniquely identifies a Task.

TaskMeasurements module-attribute

TaskMeasurements = Dict[TaskKey, AlgoMeasurements]

A dictionary mapping from TaskKeys to AlgoMeasurements.
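The measurement aliases nest into one another: AblationMeasurements maps an AblationType to a PruneMetricKey, then a TaskKey, then an AlgoKey, and finally a Measurements list of (X, Y) pairs. The sketch below builds one such structure by hand; the locally defined enum, the string keys, and the numbers are all hypothetical.

```python
from enum import Enum
from typing import Dict, List, Tuple, Union


class AblationType(Enum):  # local stand-in with two of the real enum's members
    RESAMPLE = 1
    ZERO = 2


Number = Union[int, float]
Measurements = List[Tuple[Number, Number]]
AlgoMeasurements = Dict[str, Measurements]             # AlgoKey -> Measurements
TaskMeasurements = Dict[str, AlgoMeasurements]         # TaskKey -> AlgoMeasurements
PruneMetricMeasurements = Dict[str, TaskMeasurements]  # PruneMetricKey -> TaskMeasurements
AblationMeasurements = Dict[AblationType, PruneMetricMeasurements]

# Hypothetical keys and (number of edges, faithfulness) pairs, for illustration only
results: AblationMeasurements = {
    AblationType.RESAMPLE: {
        "logit_diff": {
            "ioi": {
                "algo_a": [(0, 0.0), (10, 0.55), (100, 0.92)],
                "algo_b": [(0, 0.0), (10, 0.05), (100, 0.31)],
            }
        }
    }
}

# Index from the outermost key down to one Measurements list
points = results[AblationType.RESAMPLE]["logit_diff"]["ioi"]["algo_a"]
xs, ys = zip(*points)
print(xs)  # (0, 10, 100)
print(ys)  # (0.0, 0.55, 0.92)
```

TaskPruneScores and AlgoPruneScores nest in the same way, with PruneScores at the innermost level instead of Measurements.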

TaskPruneScores module-attribute

TaskPruneScores = Dict[TaskKey, AlgoPruneScores]

A dictionary mapping from TaskKeys to AlgoPruneScores.

TestEdges module-attribute

TestEdges = EdgeCounts | List[int | float]

Determines which edge counts (numbers of edges to prune) to test. This value is passed as a parameter to edge_counts_util.

If a list of integers, these values are used directly as the edge counts. If a list of floats, each value is interpreted as a proportion of the total number of edges.
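The rule for explicit lists can be sketched as follows. `resolve_test_edges` is a hypothetical helper, not edge_counts_util, and rounding proportions to the nearest integer is an assumption; EdgeCounts values are left to the library.

```python
from typing import List, Union


def resolve_test_edges(test_edges: List[Union[int, float]], n_edges: int) -> List[int]:
    counts = []
    for value in test_edges:
        if isinstance(value, float):
            # A float is a proportion of the total number of edges (ASSUMED rounding)
            counts.append(round(value * n_edges))
        else:
            # An int is used directly as an edge count
            counts.append(value)
    return counts


print(resolve_test_edges([0, 5, 20], n_edges=200))            # [0, 5, 20]
print(resolve_test_edges([0.0, 0.1, 0.5, 1.0], n_edges=200))  # [0, 20, 100, 200]
```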

Classes

AblationType

Bases: Enum

The type of activation with which to replace an original activation during a forward pass.

Attributes
BATCH_ALL_TOK_MEAN class-attribute instance-attribute
BATCH_ALL_TOK_MEAN = 7

Compute the mean over all tokens in the current batch.

BATCH_TOKENWISE_MEAN class-attribute instance-attribute
BATCH_TOKENWISE_MEAN = 6

Compute the token-wise mean over the current input batch.

RESAMPLE class-attribute instance-attribute
RESAMPLE = 1

Use the corresponding activation from the forward pass of the corrupt input.

TOKENWISE_MEAN_CLEAN class-attribute instance-attribute
TOKENWISE_MEAN_CLEAN = 3

Compute the token-wise mean of the clean input over the entire dataset.

TOKENWISE_MEAN_CLEAN_AND_CORRUPT class-attribute instance-attribute
TOKENWISE_MEAN_CLEAN_AND_CORRUPT = 5

Compute the token-wise mean of the clean and corrupt inputs over the entire dataset.

TOKENWISE_MEAN_CORRUPT class-attribute instance-attribute
TOKENWISE_MEAN_CORRUPT = 4

Compute the token-wise mean of the corrupt input over the entire dataset.

ZERO class-attribute instance-attribute
ZERO = 2

Use a vector of zeros.
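The difference between the zero and batch-mean ablations can be sketched on a tiny [batch, seq, d_model] activation stored as nested lists. This only illustrates which axes each value is averaged over; it is not how auto_circuit computes them, and the function names are hypothetical.

```python
from typing import List

Acts = List[List[List[float]]]  # [batch][seq][d_model]


def zero_ablation(acts: Acts) -> Acts:
    # ZERO: replace every activation with a vector of zeros
    return [[[0.0 for _ in vec] for vec in seq] for seq in acts]


def batch_tokenwise_mean(acts: Acts) -> List[List[float]]:
    # BATCH_TOKENWISE_MEAN: average over the batch separately at each position
    n_batch, n_seq = len(acts), len(acts[0])
    return [
        [sum(vals) / n_batch for vals in zip(*(acts[b][s] for b in range(n_batch)))]
        for s in range(n_seq)
    ]


def batch_all_tok_mean(acts: Acts) -> List[float]:
    # BATCH_ALL_TOK_MEAN: average over the batch and all token positions together
    vectors = [vec for seq in acts for vec in seq]
    return [sum(vals) / len(vectors) for vals in zip(*vectors)]


acts = [
    [[1.0, 2.0], [3.0, 4.0]],  # example 0: two positions, d_model = 2
    [[5.0, 6.0], [7.0, 8.0]],  # example 1
]
print(batch_tokenwise_mean(acts))  # [[3.0, 4.0], [5.0, 6.0]]
print(batch_all_tok_mean(acts))    # [4.0, 5.0]
```

The TOKENWISE_MEAN_* values apply the same per-position averaging as BATCH_TOKENWISE_MEAN, but over the entire dataset (clean inputs, corrupt inputs, or both) rather than a single batch.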

DestNode dataclass

DestNode(name: str, module_name: str, layer: int, head_idx: Optional[int] = None, head_dim: Optional[int] = None, weight: Optional[str] = None, weight_head_dim: Optional[int] = None, min_src_idx: int = 0)

Bases: Node

A node that is the destination of an edge.

Edge dataclass

Edge(src: SrcNode, dest: DestNode, seq_idx: Optional[int] = None)

A directed edge from a SrcNode to a DestNode in the computational graph of the model used for ablation.

An edge can also have an optional sequence index (seq_idx), which specifies the token position when the PatchableModel's seq_len is not None.

Attributes
dest instance-attribute
dest: DestNode

The DestNode of the edge.

name property
name: str

The name of the edge. Equal to {src.name}->{dest.name}.

patch_idx property
patch_idx: Tuple[int, ...]

The index of the edge in the patch_mask or PruneScores tensor of the dest node.

seq_idx class-attribute instance-attribute
seq_idx: Optional[int] = None

The sequence index of the edge.

src instance-attribute
src: SrcNode

The SrcNode of the edge.

Functions
patch_mask
patch_mask(model: Any) -> Parameter

The patch_mask tensor of the dest node.

Source code in auto_circuit/types.py
def patch_mask(self, model: Any) -> t.nn.Parameter:
    """The `patch_mask` tensor of the `dest` node."""
    return self.dest.module(model).patch_mask
prune_score
prune_score(prune_scores: PruneScores) -> Tensor

The score of the edge in the given PruneScores.

Source code in auto_circuit/types.py
def prune_score(self, prune_scores: PruneScores) -> t.Tensor:
    """
    The score of the edge in the given
    [`PruneScores`][auto_circuit.types.PruneScores].
    """
    return prune_scores[self.dest.module_name][self.patch_idx]
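The name and prune_score members can be mirrored with simplified stand-in dataclasses. The body of patch_idx is not shown on this page. The version below indexes by the optional seq_idx, then the destination's head_idx, then the source offset src_idx - min_src_idx. This layout is an assumption inferred from the fields of SrcNode and DestNode. Nested lists stand in for tensors.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(frozen=True)
class SrcNode:  # simplified stand-in, not the library class
    name: str
    module_name: str
    src_idx: int = 0


@dataclass(frozen=True)
class DestNode:  # simplified stand-in, not the library class
    name: str
    module_name: str
    head_idx: Optional[int] = None
    min_src_idx: int = 0


@dataclass(frozen=True)
class Edge:
    src: SrcNode
    dest: DestNode
    seq_idx: Optional[int] = None

    @property
    def name(self) -> str:
        return f"{self.src.name}->{self.dest.name}"

    @property
    def patch_idx(self) -> Tuple[int, ...]:
        # ASSUMED layout: (seq_idx if set, head_idx if set, source offset)
        idx = [] if self.seq_idx is None else [self.seq_idx]
        if self.dest.head_idx is not None:
            idx.append(self.dest.head_idx)
        idx.append(self.src.src_idx - self.dest.min_src_idx)
        return tuple(idx)

    def prune_score(self, prune_scores: Dict[str, list]) -> float:
        # Same lookup as the source above, but walking nested lists
        value = prune_scores[self.dest.module_name]
        for i in self.patch_idx:
            value = value[i]
        return value


# Hypothetical node names, module names, and scores
edge = Edge(SrcNode("A0.1", "blocks.0.attn", src_idx=2), DestNode("MLP 1", "blocks.1.mlp"))
scores = {"blocks.1.mlp": [0.1, 0.7, 0.3]}
print(edge.name)                 # A0.1->MLP 1
print(edge.patch_idx)            # (2,)
print(edge.prune_score(scores))  # 0.3
```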

EdgeCounts

Bases: Enum

Special values for TestEdges that get computed at runtime.

Attributes
ALL class-attribute instance-attribute
ALL = 1

Test 0, 1, 2, ..., n_edges edges.

GROUPS class-attribute instance-attribute
GROUPS = 3

Group edges that share the same score in PruneScores, then add the groups in descending order of score, testing the cumulative number of edges after each group.

LOGARITHMIC class-attribute instance-attribute
LOGARITHMIC = 2

Test 0, 1, 2, ..., 10, 20, ..., 100, 200, ..., 1000, 2000, ... edges.
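The three modes can be sketched as functions of the total edge count n_edges (and, for GROUPS, the list of edge scores). These are hypothetical helpers written from the descriptions above. Starting from 0 in every mode and always ending the logarithmic sequence at n_edges are assumptions.

```python
from itertools import groupby
from typing import List


def all_counts(n_edges: int) -> List[int]:
    # ALL: 0, 1, 2, ..., n_edges
    return list(range(n_edges + 1))


def logarithmic_counts(n_edges: int) -> List[int]:
    # LOGARITHMIC: 0..10 in steps of 1, then 20..100 in steps of 10, then 200..1000, ...
    counts = list(range(min(10, n_edges) + 1))
    step = 10
    while step < n_edges:
        counts.extend(range(2 * step, min(10 * step, n_edges) + 1, step))
        step *= 10
    if counts[-1] != n_edges:
        counts.append(n_edges)  # ASSUMED: always include the full edge count
    return counts


def group_counts(scores: List[float]) -> List[int]:
    # GROUPS: sort scores descending, group equal scores, and record the
    # cumulative number of edges after adding each whole group
    counts = [0]
    for _, group in groupby(sorted(scores, reverse=True)):
        counts.append(counts[-1] + len(list(group)))
    return counts


print(all_counts(3))                                   # [0, 1, 2, 3]
print(logarithmic_counts(35))                          # [0, 1, ..., 10, 20, 30, 35]
print(group_counts([0.5, 0.9, 0.5, 0.1, 0.9, 0.9]))  # [0, 3, 5, 6]
```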

Node dataclass

Node(name: str, module_name: str, layer: int, head_idx: Optional[int] = None, head_dim: Optional[int] = None, weight: Optional[str] = None, weight_head_dim: Optional[int] = None)

A node in the computational graph of the model used for ablation.

Parameters:

name (str, required): The name of the node.

module_name (str, required): The name of the PyTorch module in the model that the node is in. Modules can have multiple nodes; for example, the multi-head attention module in a transformer model has a node for each head.

layer (int, required): The layer of the model that the node is in. Transformer blocks count as 2 layers (one for the attention layer and one for the MLP layer) because we want to connect nodes in the attention layer to nodes in the subsequent MLP layer.

head_idx (Optional[int], default None): The index of the head in the multi-head attention module that the node is in.

head_dim (Optional[int], default None): The dimension of the head in the multi-head attention layer that the node is in.

weight (Optional[str], default None): The name of the weight in the module that corresponds to the node. Not currently used, but could be used by a circuit-finding algorithm.

weight_head_dim (Optional[int], default None): The dimension of the head in the weight tensor that corresponds to the node. Not currently used, but could be used by a circuit-finding algorithm.
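Counting attention and MLP as separate layers lets possible edges be defined by layer order alone. The sketch below assumes a simple rule: a source node can feed a destination node only when the source's layer is strictly lower. It also uses a hypothetical numbering (embedding at layer 0, attention of block b at 2b + 1, MLP of block b at 2b + 2). Neither the rule nor the numbering is taken from auto_circuit.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class SimpleNode:  # simplified stand-in for Node
    name: str
    layer: int


def build_nodes(n_blocks: int, n_heads: int) -> List[SimpleNode]:
    nodes = [SimpleNode("Embed", 0)]
    for b in range(n_blocks):
        # ASSUMED numbering: attention heads at 2b + 1, MLP at 2b + 2
        nodes += [SimpleNode(f"A{b}.{h}", 2 * b + 1) for h in range(n_heads)]
        nodes.append(SimpleNode(f"MLP {b}", 2 * b + 2))
    return nodes


def possible_edges(nodes: List[SimpleNode]) -> List[Tuple[str, str]]:
    # ASSUMED rule: a source only writes to destinations in strictly later layers
    return [(s.name, d.name) for s in nodes for d in nodes if s.layer < d.layer]


nodes = build_nodes(n_blocks=1, n_heads=2)
print(possible_edges(nodes))
# [('Embed', 'A0.0'), ('Embed', 'A0.1'), ('Embed', 'MLP 0'), ('A0.0', 'MLP 0'), ('A0.1', 'MLP 0')]
```

If each block were a single layer, the heads and the MLP of block 0 would share a layer and the head-to-MLP edges would disappear, which is why transformer blocks count as two layers.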
Functions
module
module(model: Any) -> PatchWrapper

Get the PatchWrapper for this node.

Parameters:

model (Any, required): The model that the node is in.

Returns:

PatchWrapper: The PatchWrapper for this node.

Source code in auto_circuit/types.py
def module(self, model: Any) -> PatchWrapper:
    """
    Get the [`PatchWrapper`][auto_circuit.utils.patch_wrapper.PatchWrapper] for this
    node.

    Args:
        model: The model that the node is in.

    Returns:
        The `PatchWrapper` for this node.
    """
    patch_wrapper = module_by_name(model, self.module_name)
    assert isinstance(patch_wrapper, PatchWrapper)
    return patch_wrapper

PatchType

Bases: Enum

Whether to patch the edges in the circuit or the complement of the circuit.

Attributes
EDGE_PATCH class-attribute instance-attribute
EDGE_PATCH = 1

Patch the edges in the circuit.

TREE_PATCH class-attribute instance-attribute
TREE_PATCH = 2

Patch the edges not in the circuit.
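Which edges receive the ablated activation under each patch type can be sketched with plain sets of edge names. The enum is a local stand-in with the same members, and `edges_to_patch` is a hypothetical helper.

```python
from enum import Enum
from typing import Set


class PatchType(Enum):  # local stand-in with the same members as the real enum
    EDGE_PATCH = 1
    TREE_PATCH = 2


def edges_to_patch(all_edges: Set[str], circuit: Set[str], patch_type: PatchType) -> Set[str]:
    if patch_type == PatchType.EDGE_PATCH:
        return all_edges & circuit  # patch the edges in the circuit
    return all_edges - circuit      # patch the complement of the circuit


all_edges = {"A->B", "A->C", "B->C"}
circuit = {"A->C"}
print(sorted(edges_to_patch(all_edges, circuit, PatchType.EDGE_PATCH)))  # ['A->C']
print(sorted(edges_to_patch(all_edges, circuit, PatchType.TREE_PATCH)))  # ['A->B', 'B->C']
```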

PatchWrapper

PatchWrapper()

Bases: Module, ABC

Abstract class for a wrapper around a module that can be patched.

Source code in auto_circuit/types.py
def __init__(self):
    super().__init__()

SrcNode dataclass

SrcNode(name: str, module_name: str, layer: int, head_idx: Optional[int] = None, head_dim: Optional[int] = None, weight: Optional[str] = None, weight_head_dim: Optional[int] = None, src_idx: int = 0)

Bases: Node

A node that is the source of an edge.

Functions