Answer diff percent
auto_circuit.metrics.prune_metrics.answer_diff_percent
answer_diff_percent
answer_diff_percent(model: PatchableModel, test_loader: PromptDataLoader, circuit_outs: CircuitOutputs, prob_func: Literal['log_softmax', 'softmax', 'logits'] = 'logits', diff_of_means: bool = True) -> Tuple[Measurements, Measurements, List[Tuple[int, Tensor]]]
The average percentage of the full model's answer difference (the difference in the logits, or some function of them, between the correct and incorrect answers) that is recovered by the circuit.
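For one prompt, the quantity described above can be sketched as 100 × (circuit answer difference) / (full-model answer difference). The plain-Python sketch below is not the library's implementation. It assumes logits are plain lists rather than tensors, and that each prompt has exactly one correct and one incorrect answer index. The helper names (apply_prob_func, answer_diff, percent_recovered) are hypothetical.

```python
import math
from typing import List, Literal

ProbFunc = Literal["log_softmax", "softmax", "logits"]


def apply_prob_func(logits: List[float], prob_func: ProbFunc) -> List[float]:
    """Hypothetical helper: transform one prompt's output logits."""
    if prob_func == "logits":
        return list(logits)
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_z for x in logits]
    if prob_func == "log_softmax":
        return log_probs
    return [math.exp(lp) for lp in log_probs]  # "softmax"


def answer_diff(logits: List[float], correct: int, incorrect: int,
                prob_func: ProbFunc = "logits") -> float:
    """Score of the correct answer minus score of the incorrect answer."""
    scores = apply_prob_func(logits, prob_func)
    return scores[correct] - scores[incorrect]


def percent_recovered(full_logits: List[float], circuit_logits: List[float],
                      correct: int, incorrect: int,
                      prob_func: ProbFunc = "logits") -> float:
    """Percentage of the full model's answer difference kept by the circuit."""
    full = answer_diff(full_logits, correct, incorrect, prob_func)
    circ = answer_diff(circuit_logits, correct, incorrect, prob_func)
    return 100.0 * circ / full
```

For example, percent_recovered([2.0, 0.0, 1.0], [1.5, 0.5, 1.0], correct=0, incorrect=1) returns 50.0: the circuit keeps half of the full model's logit gap. Log-softmax only subtracts a per-prompt constant, so with "log_softmax" the answer difference equals the logit difference (up to rounding), while "softmax" generally gives a different value.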
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PatchableModel | The model on which … | required |
test_loader | PromptDataLoader | The dataloader on which the … | required |
circuit_outs | CircuitOutputs | The outputs of the ablated model for each circuit size. | required |
prob_func | Literal['log_softmax', 'softmax', 'logits'] | The function to apply to the logits before calculating the answer difference. | 'logits' |
diff_of_means | bool | Whether to calculate the difference of means … | True |
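The description of diff_of_means is cut off above. The sketch below rests on an assumption, for illustration only. It assumes True divides the mean circuit answer difference by the mean full-model answer difference, and False averages per-prompt percentages. Under that reading, the two options can give different numbers. The function names are hypothetical and the inputs are per-prompt answer differences.

```python
from statistics import mean
from typing import List


def percent_diff_of_means(full_diffs: List[float], circ_diffs: List[float]) -> float:
    """Assumed True branch: a single ratio of the two mean answer differences."""
    return 100.0 * mean(circ_diffs) / mean(full_diffs)


def percent_mean_of_ratios(full_diffs: List[float], circ_diffs: List[float]) -> float:
    """Assumed False branch: the average of per-prompt percentages."""
    return mean(100.0 * c / f for c, f in zip(circ_diffs, full_diffs))


# Illustrative per-prompt answer differences for two prompts.
full = [2.0, 4.0]
circuit = [1.0, 4.0]
print(percent_diff_of_means(full, circuit))
print(percent_mean_of_ratios(full, circuit))
```

With these numbers the first call gives about 83.3 and the second gives 75.0. The mean of ratios gives every prompt equal weight, so a prompt whose full-model answer difference is close to zero can dominate it.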
Returns:
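Type | Description |
---|---|
Tuple[Measurements, Measurements, List[Tuple[int, Tensor]]] | A tuple of three elements; the first is the average answer difference percentage (the value returned by measure_answer_diff_percent). |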
measure_answer_diff_percent
measure_answer_diff_percent(model: PatchableModel, test_loader: PromptDataLoader, circuit_outs: CircuitOutputs, prob_func: Literal['log_softmax', 'softmax', 'logits'] = 'logits', diff_of_means: bool = True) -> Measurements
Wrapper of answer_diff_percent that returns only the average answer difference percentage (the first element of the tuple).
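The wrapper pattern amounts to calling the full metric and indexing the first element. The sketch below substitutes a hypothetical stand-in for answer_diff_percent so it runs without a model or dataloader. The Measurements alias here, a list of (circuit size, value) pairs, is an assumption for illustration. The returned numbers are arbitrary placeholders, not real results.

```python
from typing import List, Tuple

# Assumed, illustrative alias: (circuit size, value) pairs.
Measurements = List[Tuple[int, float]]


def answer_diff_percent_stub() -> Tuple[Measurements, Measurements, List[Tuple[int, list]]]:
    """Hypothetical stand-in for answer_diff_percent; values are placeholders."""
    first: Measurements = [(0, 1.0), (10, 2.0)]
    second: Measurements = [(0, 3.0), (10, 4.0)]
    third: List[Tuple[int, list]] = [(0, []), (10, [])]  # stands in for per-size tensors
    return first, second, third


def measure_answer_diff_percent_stub() -> Measurements:
    # Keep only the first element (the average percentages) and drop the rest.
    return answer_diff_percent_stub()[0]
```

Returning one Measurements value lets callers that only need the headline metric skip unpacking the tuple.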