Tasks
auto_circuit.tasks
Attributes
Classes
Task
dataclass
Task(key: TaskKey, name: str, batch_size: int | Tuple[int, int], batch_count: int | Tuple[int, int], token_circuit: bool, _model_def: str | Module, _dataset_name: str, factorized: bool = True, separate_qkv: bool = True, _true_edge_func: Optional[Callable[..., Set[Edge]]] = None, slice_output: OutputSlice = 'last_seq', autoencoder_input: Optional[AutoencoderInput] = None, autoencoder_max_latents: Optional[int] = None, autoencoder_pythia_size: Optional[str] = None, autoencoder_prune_with_corrupt: Optional[bool] = None, dtype: dtype = t.float32, __init_complete__: bool = False)
A task to be used in the auto-circuit experiments.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `TaskKey` | A unique identifier for the task. | required |
| `name` | `str` | A human-readable name for the task, used in visualizations. | required |
| `batch_size` | `int \| Tuple[int, int]` | The batch size to use for training and testing. | required |
| `batch_count` | `int \| Tuple[int, int]` | The number of batches to use for training and testing. | required |
| `token_circuit` | `bool` | Whether to patch different token positions separately ( | required |
| `_model_def` | `str \| Module` | The model to use for the task. If a string, the model will be loaded from the | required |
| `_dataset_name` | `str` | The dataset name to use for the task. The file | required |
| `factorized` | `bool` | Whether to use the factorized model and Edge Patching ( | `True` |
| `separate_qkv` | `bool` | Whether to have separate Q, K, and V input nodes. Outputs from attention heads are the same either way. | `True` |
| `_true_edge_func` | `Optional[Callable[..., Set[Edge]]]` | A function that returns the true edges for the task. | `None` |
| `slice_output` | `OutputSlice` | Specifies the index/slice of the output of the model to be considered for the task. For example, | `'last_seq'` |
| `autoencoder_input` | `Optional[AutoencoderInput]` | If not | `None` |
| `autoencoder_max_latents` | `Optional[int]` | When loading a model with autoencoders enabled (by setting | `None` |
| `autoencoder_pythia_size` | `Optional[str]` | The Pythia size to use for the autoencoder. | `None` |
| `autoencoder_prune_with_corrupt` | `Optional[bool]` | Whether to prune the autoencoder with corrupt data. | `None` |
| `dtype` | `dtype` | Sets the data type with which to load | `float32` |
| `__init_complete__` | `bool` | Whether the task has been initialized. | `False` |
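
The `int | Tuple[int, int]` type on `batch_size` and `batch_count` can be illustrated with a small, self-contained sketch. It does not import `auto_circuit`. `TaskSketch`, `split_train_test`, and `examples_needed` are hypothetical names invented for illustration. The sketch assumes that a tuple means `(train, test)` and that a bare `int` applies to both, which is inferred from the "training and testing" wording above rather than confirmed by the library. It also uses `str` in place of `TaskKey` and `OutputSlice`.

```python
from dataclasses import dataclass
from typing import Tuple, Union

IntOrPair = Union[int, Tuple[int, int]]


def split_train_test(value: IntOrPair) -> Tuple[int, int]:
    """Return a (train, test) pair.

    Assumption: a tuple is ordered (train, test) and a bare int is
    shared by both splits. This is not taken from the library source.
    """
    if isinstance(value, tuple):
        train, test = value
        return train, test
    return value, value


@dataclass
class TaskSketch:
    """Hypothetical stand-in mirroring a few Task fields; not the real class."""

    key: str  # stands in for TaskKey
    name: str
    batch_size: IntOrPair
    batch_count: IntOrPair
    token_circuit: bool
    factorized: bool = True
    separate_qkv: bool = True
    slice_output: str = "last_seq"  # stands in for OutputSlice

    def examples_needed(self) -> Tuple[int, int]:
        """Number of (train, test) examples implied by size times count."""
        train_size, test_size = split_train_test(self.batch_size)
        train_count, test_count = split_train_test(self.batch_count)
        return train_size * train_count, test_size * test_count


sketch = TaskSketch(
    key="demo",
    name="Demo task",
    batch_size=(32, 16),
    batch_count=(4, 2),
    token_circuit=False,
)
print(sketch.examples_needed())  # (128, 32)
```

Under these assumptions, separate train and test values are given by passing tuples, while defaults such as `factorized=True` and `slice_output='last_seq'` can be left unset.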