Types
auto_circuit.types
Attributes
AblationMeasurements
module-attribute
AblationMeasurements = Dict[AblationType, PruneMetricMeasurements]
A dictionary mapping from AblationTypes to
PruneMetricMeasurements.
AlgoMeasurements
module-attribute
AlgoMeasurements = Dict[AlgoKey, Measurements]
A dictionary mapping from AlgoKeys to
Measurements.
AlgoPruneScores
module-attribute
AlgoPruneScores = Dict[AlgoKey, PruneScores]
A dictionary mapping from AlgoKeys to
PruneScores.
AutoencoderInput
module-attribute
The activation in each layer that is replaced by an autoencoder reconstruction.
BatchOutputs
module-attribute
BatchOutputs = Dict[BatchKey, Tensor]
A dictionary mapping from BatchKeys to output
tensors.
CircuitOutputs
module-attribute
CircuitOutputs = Dict[int, BatchOutputs]
A dictionary mapping from the number of pruned edges to
BatchOutputs.
MaskFn
module-attribute
Determines how mask values are used to ablate edges.
If None, the mask value is used directly to interpolate between the original and
ablated values, i.e. 0.0 means the original value and 1.0 means the ablated value.
If "hard_concrete", the mask value parameterizes a "HardConcrete" distribution
(Louizos et al., 2017,
Cao et al., 2021)
which is sampled to interpolate between the original and ablated values. The
HardConcrete distribution allows us to optimize a continuous variable while still
allowing the mask to take values exactly equal to 0.0 or 1.0. The stochasticity also
helps to reduce problems from vanishing gradients.
If "sigmoid", the mask value is passed through a sigmoid function and then used to
interpolate between the original and ablated values.
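The three mask modes described above can be sketched in plain Python as follows. This is an illustrative sketch, not the library's implementation: the helper names (`effective_mask`, `hard_concrete_sample`, `interpolate`) are hypothetical, the hard-concrete sampler follows the formulation in Louizos et al. with commonly used constants, and treating the mask value as the distribution's log-alpha parameter is an assumption.

```python
import math
import random
from typing import Optional


def hard_concrete_sample(
    log_alpha: float,
    rng: random.Random,
    beta: float = 2 / 3,   # temperature (assumed default)
    gamma: float = -0.1,   # lower stretch limit (assumed default)
    zeta: float = 1.1,     # upper stretch limit (assumed default)
) -> float:
    """Draw one sample in [0, 1] from a hard-concrete distribution."""
    u = rng.uniform(1e-6, 1 - 1e-6)  # avoid log(0)
    logit = (math.log(u) - math.log(1 - u) + log_alpha) / beta
    s = 1 / (1 + math.exp(-logit))
    stretched = s * (zeta - gamma) + gamma  # stretch to (gamma, zeta)
    return min(1.0, max(0.0, stretched))    # clamp, so exact 0.0 / 1.0 occur


def effective_mask(
    mask_value: float,
    mask_fn: Optional[str] = None,
    rng: Optional[random.Random] = None,
) -> float:
    """Turn a raw mask value into an interpolation weight."""
    if mask_fn is None:
        return mask_value
    if mask_fn == "sigmoid":
        return 1 / (1 + math.exp(-mask_value))
    if mask_fn == "hard_concrete":
        return hard_concrete_sample(mask_value, rng or random.Random())
    raise ValueError(f"unknown mask_fn: {mask_fn!r}")


def interpolate(original: float, ablated: float, mask: float) -> float:
    """0.0 keeps the original value, 1.0 uses the ablated value."""
    return (1 - mask) * original + mask * ablated
```

For example, `interpolate(2.0, 10.0, effective_mask(0.0, "sigmoid"))` lands halfway between the original and ablated values, because a sigmoid maps 0.0 to 0.5.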
Measurements
module-attribute
List of X and Y measurements. X is often the number of edges in the circuit and Y is often some measure of faithfulness.
OutputSlice
module-attribute
The slice of the output that is considered for task evaluation. For example,
"last_seq" will consider the last token's output in transformer models.
PruneMetricKey
module-attribute
A string that uniquely identifies a
PruneMetric.
PruneMetricMeasurements
module-attribute
PruneMetricMeasurements = Dict[PruneMetricKey, TaskMeasurements]
A dictionary mapping from PruneMetricKeys to
TaskMeasurements.
PruneScores
module-attribute
TaskMeasurements
module-attribute
TaskMeasurements = Dict[TaskKey, AlgoMeasurements]
A dictionary mapping from TaskKeys to
AlgoMeasurements.
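The measurement aliases nest four dictionaries deep: AblationMeasurements → PruneMetricMeasurements → TaskMeasurements → AlgoMeasurements → Measurements. The sketch below mirrors that nesting with stand-in types so it runs without the library. Using plain strings for every key, the example key values, and the (x, y) tuple shape for Measurements are all assumptions made for illustration.

```python
from typing import Dict, Iterator, List, Tuple

# Stand-in aliases mirroring the nesting documented above.
# Plain str keys replace AblationType, PruneMetricKey, TaskKey and AlgoKey.
Measurements = List[Tuple[float, float]]  # assumed (x, y) pairs
AlgoMeasurements = Dict[str, Measurements]
TaskMeasurements = Dict[str, AlgoMeasurements]
PruneMetricMeasurements = Dict[str, TaskMeasurements]
AblationMeasurements = Dict[str, PruneMetricMeasurements]

# Hypothetical results: every key and number here is made up.
results: AblationMeasurements = {
    "RESAMPLE": {
        "example_metric": {
            "example_task": {
                "example_algo": [(10, 0.8), (100, 0.2)],
            }
        }
    }
}


def flatten(
    measurements: AblationMeasurements,
) -> Iterator[Tuple[str, str, str, str, float, float]]:
    """Yield one flat row per (x, y) point, outermost key first."""
    for ablation, by_metric in measurements.items():
        for metric, by_task in by_metric.items():
            for task, by_algo in by_task.items():
                for algo, points in by_algo.items():
                    for x, y in points:
                        yield ablation, metric, task, algo, x, y


rows = list(flatten(results))
```

Flattening in this way is convenient when the nested results need to be loaded into a table for plotting.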
TaskPruneScores
module-attribute
TaskPruneScores = Dict[TaskKey, AlgoPruneScores]
A dictionary mapping from TaskKeys to
AlgoPruneScores.
TestEdges
module-attribute
TestEdges = EdgeCounts | List[int | float]
Determines the set of edge counts (the numbers of edges to prune) to test. This value is passed as a
parameter to edge_counts_util.
If a list of integers, then these are the edge counts that will be used. If a list of floats, then these proportions of the total number of edges will be used.
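The integer and float cases can be sketched with a small helper. This is not edge_counts_util itself: `resolve_edge_counts` is a hypothetical name, and rounding proportions to the nearest integer is an assumption about how fractional counts would be handled.

```python
from typing import List, Union


def resolve_edge_counts(
    test_edges: List[Union[int, float]], n_edges: int
) -> List[int]:
    """Turn a TestEdges-style list into concrete edge counts."""
    if all(isinstance(v, int) for v in test_edges):
        # Integers are taken as edge counts directly.
        return list(test_edges)
    if all(isinstance(v, float) for v in test_edges):
        # Floats are proportions of the total number of edges.
        return [round(v * n_edges) for v in test_edges]
    raise TypeError("expected a list of all ints or all floats")
```

With 200 edges in total, `resolve_edge_counts([0.25, 0.5, 1.0], 200)` gives counts of 50, 100 and 200, while `resolve_edge_counts([1, 10, 100], 200)` is returned unchanged.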
Classes
AblationType
Bases: Enum
Type of activation with which to replace an original activation during a forward pass.
Attributes
BATCH_ALL_TOK_MEAN
class-attribute
instance-attribute
Compute the mean over all tokens in the current batch.
BATCH_TOKENWISE_MEAN
class-attribute
instance-attribute
Compute the token-wise mean over the current input batch.
RESAMPLE
class-attribute
instance-attribute
Use the corresponding activation from the forward pass of the corrupt input.
TOKENWISE_MEAN_CLEAN
class-attribute
instance-attribute
Compute the token-wise mean of the clean input over the entire dataset.
TOKENWISE_MEAN_CLEAN_AND_CORRUPT
class-attribute
instance-attribute
Compute the token-wise mean of the clean and corrupt inputs over the entire dataset.
TOKENWISE_MEAN_CORRUPT
class-attribute
instance-attribute
Compute the token-wise mean of the corrupt input over the entire dataset.
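To make the difference between the ablation types concrete, the sketch below computes three of them on a toy batch, using one scalar activation per token so it runs in plain Python. The function names are hypothetical, and reading "token-wise mean" as a mean over the batch dimension at each token position is an assumption about the intended semantics.

```python
from statistics import fmean
from typing import List

Acts = List[List[float]]  # indexed as [batch][token]

clean: Acts = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
corrupt: Acts = [[0.0, 0.0, 0.0], [9.0, 9.0, 9.0]]


def batch_all_tok_mean(acts: Acts) -> float:
    """One mean over every token of every example in the batch."""
    return fmean(v for row in acts for v in row)


def batch_tokenwise_mean(acts: Acts) -> List[float]:
    """One mean per token position, averaged over the batch."""
    return [fmean(position) for position in zip(*acts)]


def resample(clean_acts: Acts, corrupt_acts: Acts) -> Acts:
    """Replace each clean activation with the matching corrupt one."""
    return [list(row) for row in corrupt_acts]


all_tok = batch_all_tok_mean(clean)      # a single scalar
tokenwise = batch_tokenwise_mean(clean)  # one value per position
resampled = resample(clean, corrupt)     # same shape as the input
```

The three TOKENWISE_MEAN_* variants would apply the same per-position average over the whole dataset (clean inputs, corrupt inputs, or both) rather than over the current batch.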
DestNode
dataclass
Edge
dataclass
A directed edge from a SrcNode to a
DestNode in the computational graph of the model
used for ablation.
An edge also has an optional sequence index that specifies the token position when the
PatchableModel has seq_len
set (not None).
Attributes
patch_idx
property
The index of the edge in the patch_mask or
PruneScores tensor of the dest node.
seq_idx
class-attribute
instance-attribute
The sequence index of the edge.
Functions
patch_mask
prune_score
prune_score(prune_scores: PruneScores) -> Tensor
The score of the edge in the given
PruneScores.
EdgeCounts
Bases: Enum
Special values for TestEdges that get computed
at runtime.
Attributes
GROUPS
class-attribute
instance-attribute
Group edges by PruneScores and cumulatively add
the number of edges in each group in descending order by score.
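The grouping described for GROUPS can be sketched as follows, on a flat list of scores. `grouped_edge_counts` is a hypothetical helper, not the library's function, and grouping on exact score equality is an assumption.

```python
from collections import Counter
from itertools import accumulate
from typing import List


def grouped_edge_counts(scores: List[float]) -> List[int]:
    """Cumulative edge counts, one per distinct score, highest score first."""
    group_sizes = Counter(scores)  # score -> number of edges with that score
    sizes_desc = [group_sizes[s] for s in sorted(group_sizes, reverse=True)]
    return list(accumulate(sizes_desc))


counts = grouped_edge_counts([0.9, 0.9, 0.5, 0.1, 0.1, 0.1])
```

Here the groups have sizes 2, 1 and 3, so the edge counts to test are 2, 3 and 6: edges that share a score are always added or removed together.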
Node
dataclass
Node(name: str, module_name: str, layer: int, head_idx: Optional[int] = None, head_dim: Optional[int] = None, weight: Optional[str] = None, weight_head_dim: Optional[int] = None)
A node in the computational graph of the model used for ablation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | The name of the node. | required |
| module_name | str | The name of the PyTorch module in the model that the node is in. Modules can have multiple nodes, for example, the multi-head attention module in a transformer model has a node for each head. | required |
| layer | int | The layer of the model that the node is in. Transformer blocks count as 2 layers (one for the attention layer and one for the MLP layer) because we want to connect nodes in the attention layer to nodes in the subsequent MLP layer. | required |
| head_idx | Optional[int] | The index of the head in the multi-head attention module that the node is in. | None |
| head_dim | Optional[int] | The dimension of the head in the multi-head attention layer that the node is in. | None |
| weight | Optional[str] | The name of the weight in the module that corresponds to the node. Not currently used, but could be used by a circuit finding algorithm. | None |
| weight_head_dim | Optional[int] | The dimension of the head in the weight tensor that corresponds to the node. Not currently used, but could be used by a circuit finding algorithm. | None |
Functions
module
module(model: Any) -> PatchWrapper
Get the PatchWrapper for this
node.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Any | The model that the node is in. | required |
Returns:
| Type | Description |
|---|---|
PatchWrapper
|
The |