Types
auto_circuit.types
Attributes
AblationMeasurements
module-attribute
AblationMeasurements = Dict[AblationType, PruneMetricMeasurements]
A dictionary mapping from AblationTypes to PruneMetricMeasurements.
AlgoMeasurements
module-attribute
AlgoMeasurements = Dict[AlgoKey, Measurements]
A dictionary mapping from AlgoKeys to Measurements.
AlgoPruneScores
module-attribute
AlgoPruneScores = Dict[AlgoKey, PruneScores]
A dictionary mapping from AlgoKeys to PruneScores.
AutoencoderInput
module-attribute
The activation in each layer that is replaced by an autoencoder reconstruction.
BatchOutputs
module-attribute
BatchOutputs = Dict[BatchKey, Tensor]
A dictionary mapping from BatchKeys to output tensors.
CircuitOutputs
module-attribute
CircuitOutputs = Dict[int, BatchOutputs]
A dictionary mapping from the number of pruned edges to BatchOutputs.
MaskFn
module-attribute
Determines how mask values are used to ablate edges.
If None, the mask value is used directly to interpolate between the original and ablated values, i.e., 0.0 means the original value and 1.0 means the ablated value.
If "hard_concrete", the mask value parameterizes a "HardConcrete" distribution (Louizos et al., 2017; Cao et al., 2021), which is sampled to interpolate between the original and ablated values. The HardConcrete distribution allows us to optimize a continuous variable while still allowing the mask to take values exactly equal to 0.0 or 1.0. The stochasticity also helps reduce problems from vanishing gradients.
If "sigmoid", the mask value is passed through a sigmoid function and then used to interpolate between the original and ablated values.
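A minimal sketch of how the three MaskFn options could turn a raw mask value into an interpolation weight, using plain Python floats instead of tensors. The helper names (apply_mask, hard_concrete_sample) are hypothetical, and the HardConcrete constants (beta = 2/3, gamma = -0.1, zeta = 1.1) are the common defaults from Louizos et al. rather than values confirmed for auto_circuit.

```python
import math
import random
from typing import Optional


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def hard_concrete_sample(
    log_alpha: float,
    rng: random.Random,
    beta: float = 2 / 3,
    gamma: float = -0.1,
    zeta: float = 1.1,
) -> float:
    # Sample a stretched and clamped binary concrete variable (Louizos et al., 2017).
    # beta, gamma and zeta are assumed defaults, not confirmed library values.
    u = rng.uniform(1e-6, 1.0 - 1e-6)
    s = _sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / beta)
    stretched = s * (zeta - gamma) + gamma
    # Clamping is what lets a sample land exactly on 0.0 or 1.0.
    return min(1.0, max(0.0, stretched))


def apply_mask(
    orig: float,
    ablated: float,
    mask: float,
    mask_fn: Optional[str] = None,
    rng: Optional[random.Random] = None,
) -> float:
    # Hypothetical helper: convert the raw mask value into a weight in [0, 1].
    if mask_fn is None:
        weight = mask
    elif mask_fn == "hard_concrete":
        weight = hard_concrete_sample(mask, rng or random.Random())
    elif mask_fn == "sigmoid":
        weight = _sigmoid(mask)
    else:
        raise ValueError(f"Unknown mask_fn: {mask_fn!r}")
    # weight 0.0 -> original value, weight 1.0 -> ablated value.
    return orig + weight * (ablated - orig)
```

In a real training loop the mask would be a learnable tensor and a fresh HardConcrete sample would be drawn on each forward pass; the sketch only shows the per-value arithmetic.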
Measurements
module-attribute
List of X and Y measurements. X is often the number of edges in the circuit and Y is often some measure of faithfulness.
OutputSlice
module-attribute
The slice of the output that is considered for task evaluation. For example, "last_seq" will consider the last token's output in transformer models.
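A small sketch of what "last_seq" slicing means for a batch of sequence outputs, using nested lists shaped batch x seq x vocab in place of tensors. The function slice_output is hypothetical, and treating None as "use the full output" is an assumption made for illustration.

```python
from typing import List, Optional

# Toy model output: batch x seq x vocab, as nested lists of floats.
Output = List[List[List[float]]]


def slice_output(output: Output, output_slice: Optional[str]) -> list:
    # "last_seq": keep only the final token position of each sequence.
    if output_slice == "last_seq":
        return [seq[-1] for seq in output]
    # None: evaluate on the full output (assumed behaviour).
    if output_slice is None:
        return output
    raise ValueError(f"Unsupported output_slice: {output_slice!r}")
```

With tensors, the "last_seq" case corresponds to indexing output[:, -1, :].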
PruneMetricKey
module-attribute
A string that uniquely identifies a PruneMetric.
PruneMetricMeasurements
module-attribute
PruneMetricMeasurements = Dict[PruneMetricKey, TaskMeasurements]
A dictionary mapping from PruneMetricKeys to TaskMeasurements.
PruneScores
module-attribute
TaskMeasurements
module-attribute
TaskMeasurements = Dict[TaskKey, AlgoMeasurements]
A dictionary mapping from TaskKeys to AlgoMeasurements.
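The measurement aliases nest into one another: AblationMeasurements holds PruneMetricMeasurements, which hold TaskMeasurements, which hold AlgoMeasurements, which hold Measurements. The sketch below rebuilds that chain with local aliases. AblationKind is a hypothetical stand-in for AblationType, the keys are assumed to be strings, Measurements is assumed to be a list of (x, y) pairs, and all key names and values are made up.

```python
from enum import Enum
from typing import Dict, List, Tuple


class AblationKind(Enum):
    # Hypothetical stand-in for AblationType, using two documented member names.
    RESAMPLE = "resample"
    TOKENWISE_MEAN_CORRUPT = "tokenwise_mean_corrupt"


# Assumed shape: (x, y) pairs such as (edge count, faithfulness).
Measurements = List[Tuple[float, float]]
AlgoMeasurements = Dict[str, Measurements]  # AlgoKey -> Measurements
TaskMeasurements = Dict[str, AlgoMeasurements]  # TaskKey -> AlgoMeasurements
PruneMetricMeasurements = Dict[str, TaskMeasurements]  # PruneMetricKey -> ...
AblationMeasurements = Dict[AblationKind, PruneMetricMeasurements]

# Hypothetical keys and values, for illustration only.
results: AblationMeasurements = {
    AblationKind.RESAMPLE: {
        "logit_diff": {
            "ioi": {
                "acdc": [(10, 0.42), (100, 0.91)],
            },
        },
    },
}

# Index from the outermost key down to the innermost list of (x, y) pairs.
points = results[AblationKind.RESAMPLE]["logit_diff"]["ioi"]["acdc"]
```

TaskPruneScores and AlgoPruneScores nest the same way, with PruneScores at the innermost level.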
TaskPruneScores
module-attribute
TaskPruneScores = Dict[TaskKey, AlgoPruneScores]
A dictionary mapping from TaskKeys to AlgoPruneScores.
TestEdges
module-attribute
TestEdges = EdgeCounts | List[int | float]
Determines the set of edge counts (numbers of edges to prune) to test. This value is used as a parameter to edge_counts_util.
If a list of integers, these are used directly as the edge counts. If a list of floats, each value is treated as a proportion of the total number of edges.
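A minimal sketch of how a TestEdges list could be resolved into concrete edge counts. resolve_test_edges is a hypothetical helper, not edge_counts_util itself; rounding proportions with round() and rejecting mixed lists are assumptions.

```python
from typing import List, Union


def resolve_test_edges(
    test_edges: List[Union[int, float]], total_edges: int
) -> List[int]:
    # All integers: use the values directly as edge counts.
    if all(isinstance(v, int) for v in test_edges):
        return list(test_edges)
    # All floats: treat each value as a proportion of the total edge count.
    if all(isinstance(v, float) for v in test_edges):
        return [round(p * total_edges) for p in test_edges]
    raise ValueError("test_edges must be all ints or all floats")
```

For example, with 200 edges in total, [0.0, 0.5, 1.0] resolves to [0, 100, 200].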
Classes
AblationType
Bases: Enum
Type of activation with which to replace an original activation during a forward pass.
Attributes
BATCH_ALL_TOK_MEAN
class-attribute
instance-attribute
Compute the mean over all tokens in the current batch.
BATCH_TOKENWISE_MEAN
class-attribute
instance-attribute
Compute the token-wise mean over the current input batch.
RESAMPLE
class-attribute
instance-attribute
Use the corresponding activation from the forward pass of the corrupt input.
TOKENWISE_MEAN_CLEAN
class-attribute
instance-attribute
Compute the token-wise mean of the clean input over the entire dataset.
TOKENWISE_MEAN_CLEAN_AND_CORRUPT
class-attribute
instance-attribute
Compute the token-wise mean of the clean and corrupt inputs over the entire dataset.
TOKENWISE_MEAN_CORRUPT
class-attribute
instance-attribute
Compute the token-wise mean of the corrupt input over the entire dataset.
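The difference between the token-wise and all-token means is easiest to see on a tiny example. This sketch uses one scalar activation per token (batch x seq nested lists) instead of full activation tensors; the function names are hypothetical. Under these assumptions, BATCH_TOKENWISE_MEAN corresponds to tokenwise_mean over the current batch, BATCH_ALL_TOK_MEAN to all_token_mean, the TOKENWISE_MEAN_* variants to tokenwise_mean over the whole clean, corrupt, or combined dataset, and RESAMPLE to reusing the corrupt activations.

```python
from typing import List

# Toy activations: batch x seq, one scalar per token position.
Acts = List[List[float]]


def tokenwise_mean(acts: Acts) -> List[float]:
    # Average over the batch dimension, separately for each token position.
    batch = len(acts)
    return [sum(seq[i] for seq in acts) / batch for i in range(len(acts[0]))]


def all_token_mean(acts: Acts) -> float:
    # A single average over every token of every sequence.
    flat = [x for seq in acts for x in seq]
    return sum(flat) / len(flat)


def resample(corrupt_acts: Acts) -> Acts:
    # Reuse the activations from the corrupt forward pass unchanged.
    return corrupt_acts


clean = [[1.0, 2.0], [3.0, 6.0]]
corrupt = [[5.0, 0.0], [7.0, 4.0]]
clean_means = tokenwise_mean(clean)  # one mean per position
combined_means = tokenwise_mean(clean + corrupt)  # clean and corrupt pooled
```

Here clean_means is [2.0, 4.0], all_token_mean(clean) is 3.0, and combined_means is [4.0, 3.0].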
DestNode
dataclass
Edge
dataclass
A directed edge from a SrcNode to a DestNode in the computational graph of the model used for ablation, with an optional sequence index that specifies the token position when the PatchableModel's seq_len is not None.
Attributes
patch_idx
property
The index of the edge in the patch_mask or PruneScores tensor of the dest node.
seq_idx
class-attribute
instance-attribute
The sequence index of the edge.
Functions
patch_mask
prune_score
prune_score(prune_scores: PruneScores) -> Tensor
The score of the edge in the given PruneScores.
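A sketch of how patch_idx can be used to read an edge's score out of its dest node's score tensor. ToyNode, ToyEdge and ToyPruneScores are hypothetical simplifications: scores are keyed by the dest node's module_name, stored as nested lists, and indexed by (seq_idx, src_idx) or just (src_idx,). That layout is an assumption, not the library's documented one.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

# Hypothetical stand-in for PruneScores: dest module name -> nested score lists.
ToyPruneScores = Dict[str, Any]


@dataclass(frozen=True)
class ToyNode:
    name: str
    module_name: str
    src_idx: int = 0  # hypothetical position of this node as an edge source


@dataclass(frozen=True)
class ToyEdge:
    src: ToyNode
    dest: ToyNode
    seq_idx: Optional[int] = None

    @property
    def patch_idx(self) -> Tuple[int, ...]:
        # Assumed layout: prepend the token position when one is given.
        if self.seq_idx is None:
            return (self.src.src_idx,)
        return (self.seq_idx, self.src.src_idx)

    def prune_score(self, prune_scores: ToyPruneScores) -> float:
        # Look up the dest node's scores, then index them with patch_idx.
        scores = prune_scores[self.dest.module_name]
        for i in self.patch_idx:
            scores = scores[i]
        return scores


embed = ToyNode("embed", "embed", src_idx=0)
head = ToyNode("head_1", "blocks.0.attn", src_idx=1)
mlp = ToyNode("mlp_0", "blocks.0.mlp")
```

With scores {"blocks.0.mlp": [0.3, 0.8]}, ToyEdge(head, mlp).prune_score(...) returns 0.8.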
EdgeCounts
Bases: Enum
Special values for TestEdges that get computed at runtime.
Attributes
GROUPS
class-attribute
instance-attribute
Group edges by PruneScores and cumulatively add the number of edges in each group in descending order by score.
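A minimal sketch of the GROUPS rule over a flat list of edge scores. group_edge_counts is hypothetical. Whether the library compares raw scores or absolute values, and whether it also includes a zero count, is not stated here; this sketch uses raw values and no zero.

```python
from collections import Counter
from itertools import accumulate
from typing import List


def group_edge_counts(scores: List[float]) -> List[int]:
    # Count how many edges share each distinct score.
    group_sizes = Counter(scores)
    # Visit the groups from highest to lowest score and accumulate their sizes.
    ordered = [group_sizes[s] for s in sorted(group_sizes, reverse=True)]
    return list(accumulate(ordered))
```

For scores [3.0, 1.0, 3.0, 2.0, 1.0, 1.0] the groups have sizes 2, 1 and 3, so the edge counts to test are [2, 3, 6]; the last count always equals the total number of edges.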
Node
dataclass
Node(name: str, module_name: str, layer: int, head_idx: Optional[int] = None, head_dim: Optional[int] = None, weight: Optional[str] = None, weight_head_dim: Optional[int] = None)
A node in the computational graph of the model used for ablation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the node. | required |
module_name | str | The name of the PyTorch module in the model that the node is in. Modules can have multiple nodes; for example, the multi-head attention module in a transformer model has a node for each head. | required |
layer | int | The layer of the model that the node is in. Transformer blocks count as 2 layers (one for the attention layer and one for the MLP layer) because we want to connect nodes in the attention layer to nodes in the subsequent MLP layer. | required |
head_idx | Optional[int] | The index of the head in the multi-head attention module that the node is in. | None |
head_dim | Optional[int] | The dimension of the head in the multi-head attention layer that the node is in. | None |
weight | Optional[str] | The name of the weight in the module that corresponds to the node. Not currently used, but could be used by a circuit-finding algorithm. | None |
weight_head_dim | Optional[int] | The dimension of the head in the weight tensor that corresponds to the node. Not currently used, but could be used by a circuit-finding algorithm. | None |
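To illustrate the module_name and layer parameters, this sketch builds the nodes for one transformer block: one node per attention head, all sharing the attention module, plus one MLP node. NodeSketch mirrors the documented Node fields but is a local stand-in, not the library class. The numbering (attention in block b at layer 2*b, its MLP at 2*b + 1), the module names, and head_dim=2 are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class NodeSketch:
    # Mirrors the documented Node fields; a stand-in, not the library class.
    name: str
    module_name: str
    layer: int
    head_idx: Optional[int] = None
    head_dim: Optional[int] = None
    weight: Optional[str] = None
    weight_head_dim: Optional[int] = None


def block_nodes(block: int, n_heads: int) -> List[NodeSketch]:
    # Assumed numbering: each block spans two layers (attention, then MLP).
    attn_layer, mlp_layer = 2 * block, 2 * block + 1
    # One node per head, all inside the same attention module (names hypothetical).
    attn = [
        NodeSketch(
            f"A{block}.{h}", f"blocks.{block}.attn", attn_layer, head_idx=h, head_dim=2
        )
        for h in range(n_heads)
    ]
    mlp = NodeSketch(f"MLP {block}", f"blocks.{block}.mlp", mlp_layer)
    return attn + [mlp]
```

Because the MLP node sits one layer above the attention nodes of the same block, edges can run from each head to the subsequent MLP.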
Functions
module
module(model: Any) -> PatchWrapper
Get the PatchWrapper for this node.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Any | The model that the node is in. | required |
Returns:
Type | Description |
---|---|
PatchWrapper
|
The |