
Tasks

auto_circuit.tasks

Attributes

Classes

Task dataclass

Task(key: TaskKey, name: str, batch_size: int | Tuple[int, int], batch_count: int | Tuple[int, int], token_circuit: bool, _model_def: str | Module, _dataset_name: str, factorized: bool = True, separate_qkv: bool = True, _true_edge_func: Optional[Callable[..., Set[Edge]]] = None, slice_output: OutputSlice = 'last_seq', autoencoder_input: Optional[AutoencoderInput] = None, autoencoder_max_latents: Optional[int] = None, autoencoder_pythia_size: Optional[str] = None, autoencoder_prune_with_corrupt: Optional[bool] = None, dtype: dtype = t.float32, __init_complete__: bool = False)

A task to be used in the auto-circuit experiments.

Parameters:

key (TaskKey, required): A unique identifier for the task.

name (str, required): A human-readable name for the task, used in visualizations.

batch_size (int | Tuple[int, int], required): The batch size to use for training and testing.

batch_count (int | Tuple[int, int], required): The number of batches to use for training and testing.

token_circuit (bool, required): Whether to patch different token positions separately (True) or not (False).

_model_def (str | Module, required): The model to use for the task. If a string, the model will be loaded from the transformer_lens library with the correct config.

_dataset_name (str, required): The dataset name to use for the task. The file "datasets/{_dataset_name}.json" will be loaded using load_datasets_from_json.

factorized (bool, default True): Whether to use the factorized model and Edge Patching (True) or the residual model and Node Patching (False).

separate_qkv (bool, default True): Whether to have separate Q, K, and V input nodes. Outputs from attention heads are the same either way.

_true_edge_func (Optional[Callable[..., Set[Edge]]], default None): A function that returns the true edges for the task.

slice_output (OutputSlice, default 'last_seq'): Specifies the index/slice of the model's output to be considered for the task. For example, "last_seq" considers the last token's output in transformer models.

autoencoder_input (Optional[AutoencoderInput], default None): If not None, the model will patch in autoencoder reconstructions at each layer. This variable determines which activations are passed to the autoencoder (e.g. MLP output or residual stream).

autoencoder_max_latents (Optional[int], default None): When loading a model with autoencoders enabled (by setting autoencoder_input to a non-None value), _prune_latents_with_dataset is used to first prune the autoencoder latents: it runs a batch of data through the model and prunes any latents that are not activated. This dramatically reduces the number of latents in the autoencoder (and therefore edges in the model), which is generally required to make circuit discovery feasible. However, there can still be too many features remaining, so this parameter sets a cap such that only the top autoencoder_max_latents features by activation are kept.

autoencoder_pythia_size (Optional[str], default None): The Pythia size to use for the autoencoder.

autoencoder_prune_with_corrupt (Optional[bool], default None): Whether to prune the autoencoder with corrupt data.

dtype (dtype, default float32): Sets the data type with which to load transformer_lens models.

__init_complete__ (bool, default False): Whether the task has been initialized.
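For reference, here is a minimal sketch of constructing a Task with a model loaded by name through transformer_lens. The key, task name, dataset name, and batch settings below are illustrative placeholders rather than values shipped with the library, and the example assumes TaskKey accepts a plain string.

```python
from auto_circuit.tasks import Task

# Minimal sketch, assuming TaskKey is a string alias and that
# datasets/my_prompts.json exists (hypothetical placeholder file).
my_task = Task(
    key="my_prompts_task",       # unique identifier for the task
    name="My Prompts",           # human-readable name used in visualizations
    batch_size=32,               # or a Tuple[int, int], presumably (train, test)
    batch_count=4,               # number of batches for training and testing
    token_circuit=False,         # patch all token positions together
    _model_def="gpt2",           # a string is loaded via transformer_lens
    _dataset_name="my_prompts",  # loads datasets/my_prompts.json
)
```

The remaining parameters keep their defaults shown above (factorized edge patching, separate Q/K/V inputs, last-sequence output slice, no autoencoders, float32 weights).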

Functions