Data
auto_circuit.data
Attributes
Classes
PromptDataLoader
PromptDataLoader(prompt_dataset: Any, seq_len: Optional[int], diverge_idx: int, kv_cache: Optional[HookedTransformerKeyValueCache] = None, seq_labels: Optional[List[str]] = None, word_idxs: Dict[str, int] = {}, **kwargs: Any)
Bases: DataLoader[PromptPairBatch]
A DataLoader for clean/corrupt prompt pairs with correct/incorrect answers.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prompt_dataset | Any | A `PromptDataset`. | required |
| seq_len | Optional[int] | The token length of the prompts (if fixed length). This prompt length can be passed to `patchable_model` to enable patching specific token positions. | required |
| diverge_idx | int | The index at which the clean and corrupt prompts diverge. (See `load_datasets_from_json` for more information.) | required |
| kv_cache | Optional[HookedTransformerKeyValueCache] | A cache of past key-value pairs for the transformer. Only used if `diverge_idx` is greater than 0. (See `load_datasets_from_json` for more information.) | None |
| seq_labels | Optional[List[str]] | A list of strings that label each token for fixed length prompts. Used by `draw_seq_graph` to label the circuit diagram. | None |
| word_idxs | Dict[str, int] | A dictionary with the token indexes of specific words. Used by official circuit functions. | {} |
| kwargs | Any | Additional arguments to pass to the parent `DataLoader` constructor. | {} |
Note

`drop_last=True` is always passed to the parent `DataLoader` constructor, so all batches are the same size. This simplifies the implementation of several functions. For example, the `kv_cache` only needs caches for a single batch size.
Source code in auto_circuit/data.py
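For illustration, here is a minimal sketch of consuming a `PromptDataLoader`. The `train_loader` variable is a hypothetical loader (for example, one returned by `load_datasets_from_json`, documented below); the batch fields used are the `PromptPairBatch` attributes documented on this page.

```python
# Sketch: iterating over a PromptDataLoader.
# `train_loader` is assumed to be a PromptDataLoader, e.g. returned by
# load_datasets_from_json (see below).
for batch in train_loader:
    clean_toks = batch.clean      # 2D tensor of 'clean' prompt tokens
    corrupt_toks = batch.corrupt  # 2D tensor of 'corrupt' prompt tokens
    answers = batch.answers       # 2D tensor or list of 1D answer-token tensors
    # Because drop_last=True, every batch has the same size, so a single
    # kv_cache (if set) works for every batch.
    assert clean_toks.shape == corrupt_toks.shape
```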
Attributes
diverge_idx instance-attribute
diverge_idx = diverge_idx
The index at which the clean and corrupt prompts diverge. (See `load_datasets_from_json` for more information.)
kv_cache instance-attribute
kv_cache = kv_cache
A cache of past key-value pairs for the transformer. Only used if `diverge_idx` is greater than 0. (See `load_datasets_from_json` for more information.)
seq_labels instance-attribute
seq_labels = seq_labels
A list of strings that label each token for fixed length prompts. Used by `draw_seq_graph` to label the circuit diagram.
seq_len instance-attribute
seq_len = seq_len
The token length of the prompts (if fixed length). This prompt length can be passed to `patchable_model` to enable patching specific token positions.
word_idxs instance-attribute
word_idxs = word_idxs
A dictionary with the token indexes of specific words. Used by official circuit functions.
Functions
PromptDataset
PromptDataset(clean_prompts: List[Tensor] | Tensor, corrupt_prompts: List[Tensor] | Tensor, answers: List[Tensor], wrong_answers: List[Tensor])
Bases: Dataset
A dataset of clean/corrupt prompt pairs with correct/incorrect answers.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| clean_prompts | List[Tensor] \| Tensor | The 'clean' prompts. These are typically examples of the behavior we want to isolate where the model performs well. If a list, each element is a 1D prompt tensor. If a tensor, it should be 2D with shape (n_prompts, prompt_length). | required |
| corrupt_prompts | List[Tensor] \| Tensor | The 'corrupt' prompts. These are typically similar to the 'clean' prompts, but with some crucial difference that changes the model output. If a list, each element is a 1D prompt tensor. If a tensor, it should be 2D with shape (n_prompts, prompt_length). | required |
| answers | List[Tensor] | A list of correct answers. Each element is a 1D tensor with the answer tokens. | required |
| wrong_answers | List[Tensor] | A list of incorrect answers. Each element is a 1D tensor with the wrong answer tokens. | required |
Source code in auto_circuit/data.py
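As a usage sketch (not taken from the library's own examples), the snippet below builds a `PromptDataset` from toy pre-tokenized tensors and wraps it in a `PromptDataLoader`, following the constructor signatures documented above. The token ids are invented, and `batch_size` is simply forwarded to the parent `DataLoader` via `**kwargs`.

```python
import torch
from auto_circuit.data import PromptDataset, PromptDataLoader

# Toy, pre-tokenized prompts (token ids are invented for illustration).
clean = torch.tensor([[1, 5, 9, 2], [1, 7, 9, 2]])    # (n_prompts, prompt_length)
corrupt = torch.tensor([[1, 5, 9, 3], [1, 7, 9, 3]])  # same shape as `clean`
answers = [torch.tensor([4]), torch.tensor([6])]       # one 1D tensor per prompt
wrong_answers = [torch.tensor([8]), torch.tensor([8])]

dataset = PromptDataset(clean, corrupt, answers, wrong_answers)
# seq_len enables token-position patching; diverge_idx=0 means no shared prefix.
loader = PromptDataLoader(dataset, seq_len=4, diverge_idx=0, batch_size=2)
```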
Functions
PromptPair
A pair of clean and corrupt prompts with correct and incorrect answers.
Parameters:

| Name | Description | Default |
|---|---|---|
| clean | The 'clean' prompt. This is typically an example of the behavior we want to isolate where the model performs well. | required |
| corrupt | The 'corrupt' prompt. This is typically similar to the 'clean' prompt, but with some crucial difference that changes the model output. | required |
| answers | The correct completions for the clean prompt. | required |
| wrong_answers | The incorrect completions for the clean prompt. | required |
PromptPairBatch
A batch of prompt pairs.
Parameters:

| Name | Description | Default |
|---|---|---|
| key | A unique integer that identifies the batch. | required |
| batch_diverge_idx | The minimum index over all prompts at which the clean and corrupt prompts diverge. This is used to automatically cache the key-value activations for the common prefix of the prompts. (See `load_datasets_from_json` for more information.) | required |
| clean | The 'clean' prompts in a 2D tensor. These are typically examples of the behavior we want to isolate where the model performs well. | required |
| corrupt | The 'corrupt' prompts in a 2D tensor. These are typically similar to the 'clean' prompts, but with some crucial difference that changes the model output. | required |
| answers | The correct answer completions for the clean prompts. If all prompts have the same number of answers, this is a 2D tensor. If each prompt has a different number of answers, this is a list of 1D tensors, which can affect some methods (such as …). | required |
| wrong_answers | The incorrect answers. If all prompts have the same number of wrong answers, this is a 2D tensor. If each prompt has a different number of wrong answers, this is a list of 1D tensors, which can affect some methods (such as …). | required |
Functions
load_datasets_from_json
load_datasets_from_json(model: Optional[Module], path: Path | List[Path], device: device, prepend_bos: bool = True, batch_size: int | Tuple[int, int] = 32, train_test_size: Tuple[int, int] = (128, 128), return_seq_length: bool = False, tail_divergence: bool = False, shuffle: bool = True, random_seed: int = 42, pad: bool = True) -> Tuple[PromptDataLoader, PromptDataLoader]
Load a dataset from a JSON file. The file should specify a list of prompt dictionaries, each with "clean" and "corrupt" prompts and their "answers" and "wrong_answers" (see the JSON data format below).
JSON data format:

```
{
    // Optional: used to label circuit visualization
    "seq_labels": [str, ...],
    // Optional: used by official circuit functions
    "word_idxs": {
        str: int,
        ...
    },
    // Required: the prompt pairs
    "prompts": [
        {
            "clean": str | [[int, ...], ...],
            "corrupt": str | [[int, ...], ...],
            "answers": [str, ...] | [int, ...],
            "wrong_answers": [str, ...] | [int, ...],
        },
        ...
    ]
}
```
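For illustration only, here is a small Python snippet that writes a toy dataset file in this format; the prompt text and answers are invented.

```python
import json
from pathlib import Path

# A toy dataset in the JSON format described above (all contents are invented).
toy_dataset = {
    "prompts": [
        {
            "clean": "When Mary and John went to the store, John gave a drink to",
            "corrupt": "When Mary and John went to the store, Mary gave a drink to",
            "answers": [" Mary"],
            "wrong_answers": [" John"],
        },
    ],
}

Path("toy_dataset.json").write_text(json.dumps(toy_dataset, indent=2))
```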
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Optional[Module] | Model to use for tokenization. If None, the data must be pre-tokenized (token ids rather than strings in the JSON file). | required |
| path | Path \| List[Path] | Path to the JSON file with the dataset. If a list of paths is passed, the first dataset is parsed in full and for the rest … | required |
| device | device | Device to load the data on. | required |
| prepend_bos | bool | If True, prepend the BOS token to each prompt. | True |
| batch_size | int \| Tuple[int, int] | The batch size for training and testing. If a single int is passed, the same batch size is used for both. | 32 |
| train_test_size | Tuple[int, int] | The number of prompts in the train and test datasets. | (128, 128) |
| return_seq_length | bool | If True, the fixed token length of the prompts is passed to the returned PromptDataLoaders as `seq_len`. | False |
| tail_divergence | bool | If all prompts share a common prefix, remove it and compute the keys and values for each attention head on the prefix. A `kv_cache` with these activations is passed to the returned PromptDataLoaders. | False |
| shuffle | bool | If True, shuffle the order of the prompts (see the Note below). | True |
| random_seed | int | Seed for the random number generator. | 42 |
| pad | bool | If True, pad the prompts to the same length. | True |
Note

`shuffle` only shuffles the order of the prompts once at the beginning. The order is then preserved in the train and test loaders (`shuffle=False` is always passed to the `PromptDataLoader` constructor). This makes it easier to ensure experiments are deterministic.
Source code in auto_circuit/data.py
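A minimal loading sketch, assuming a TransformerLens HookedTransformer as the model; the model name, batch size, and split sizes are arbitrary choices for illustration, and the import path follows this page (`auto_circuit.data`).

```python
import torch
from pathlib import Path
from transformer_lens import HookedTransformer
from auto_circuit.data import load_datasets_from_json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HookedTransformer.from_pretrained("gpt2", device=device)  # illustrative model

train_loader, test_loader = load_datasets_from_json(
    model=model,
    path=Path("toy_dataset.json"),  # a file in the JSON format above (with enough prompts)
    device=device,
    batch_size=8,
    train_test_size=(16, 16),
    shuffle=True,
    random_seed=42,
)
```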