Answer diff
auto_circuit.metrics.prune_metrics.answer_diff
Functions
measure_answer_diff
measure_answer_diff(model: PatchableModel, test_loader: PromptDataLoader, circuit_outs: CircuitOutputs, prob_func: Literal['log_softmax', 'softmax', 'logits'] = 'logits') -> Measurements
Computes the average difference between the logits (or a function of them, selected by prob_func) of the correct answers and those of the incorrect answers.
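The metric described above can be illustrated with a minimal, library-independent sketch. It is not the auto_circuit implementation: the helper names `apply_prob_func`, `answer_diff`, and `average_answer_diff` are hypothetical, plain Python lists stand in for tensors, and averaging over multiple correct or incorrect answers within a prompt is an assumption about how "average" is meant.

```python
import math
from typing import List, Sequence


def apply_prob_func(logits: Sequence[float], prob_func: str = "logits") -> List[float]:
    """Transform one prompt's logits according to prob_func (hypothetical helper)."""
    if prob_func == "logits":
        return list(logits)
    # Log of the softmax normalizer, with the max subtracted for numerical stability.
    top = max(logits)
    log_norm = top + math.log(sum(math.exp(x - top) for x in logits))
    if prob_func == "log_softmax":
        return [x - log_norm for x in logits]
    if prob_func == "softmax":
        return [math.exp(x - log_norm) for x in logits]
    raise ValueError(f"unknown prob_func: {prob_func!r}")


def answer_diff(
    logits: Sequence[float],
    correct_idxs: Sequence[int],
    wrong_idxs: Sequence[int],
    prob_func: str = "logits",
) -> float:
    """Mean value at the correct answers minus mean value at the incorrect answers."""
    values = apply_prob_func(logits, prob_func)
    correct = sum(values[i] for i in correct_idxs) / len(correct_idxs)
    wrong = sum(values[i] for i in wrong_idxs) / len(wrong_idxs)
    return correct - wrong


def average_answer_diff(
    batch_logits: Sequence[Sequence[float]],
    batch_correct: Sequence[Sequence[int]],
    batch_wrong: Sequence[Sequence[int]],
    prob_func: str = "logits",
) -> float:
    """Average the per-prompt answer difference over a batch of prompts."""
    diffs = [
        answer_diff(logits, correct, wrong, prob_func)
        for logits, correct, wrong in zip(batch_logits, batch_correct, batch_wrong)
    ]
    return sum(diffs) / len(diffs)
```

In this sketch, a prompt with one correct and one incorrect answer gives the same difference for 'logits' and 'log_softmax', because the log-normalizer cancels in the subtraction. 'softmax' instead gives a difference of probabilities, bounded between -1 and 1.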
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PatchableModel | Not used. | required |
test_loader | PromptDataLoader | The dataloader on which the | required |
circuit_outs | CircuitOutputs | The outputs of the ablated model for each circuit size. | required |
prob_func | Literal['log_softmax', 'softmax', 'logits'] | The function to apply to the logits before calculating the answer difference. | 'logits' |
Returns:
Type | Description |
---|---|
Measurements | A list of tuples, where the first element is the number of edges pruned and the second element is the average answer difference for that number of edges. |
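As a rough illustration of the return shape described above, the sketch below builds and consumes a list of (number of edges, average answer difference) tuples. The `Measurements` alias here is a local stand-in that mirrors the description, not an import from auto_circuit. The helpers `summarize` and `best_entry` and the sample values are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

# Local stand-in for the described return type (assumption, not the library's alias).
Measurements = List[Tuple[int, float]]


def summarize(per_prompt_diffs: Dict[int, Sequence[float]]) -> Measurements:
    """Average per-prompt answer differences for each edge count, sorted by edge count."""
    return [
        (n_edges, sum(diffs) / len(diffs))
        for n_edges, diffs in sorted(per_prompt_diffs.items())
    ]


def best_entry(measurements: Measurements) -> Tuple[int, float]:
    """Return the (edge count, average answer difference) pair with the largest difference."""
    return max(measurements, key=lambda pair: pair[1])


# Made-up per-prompt answer differences, keyed by edge count.
example = summarize({100: [0.5, -0.5], 0: [2.0, 4.0], 10: [1.0, 2.0]})
```

Keeping the result as a plain list of pairs makes it straightforward to plot the answer difference against the edge count or to compare it with the output of other metrics that share the same shape.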