Correct answer percent
auto_circuit.metrics.prune_metrics.correct_answer_percent
Functions
measure_correct_ans_percent
measure_correct_ans_percent(model: PatchableModel, dataloader: PromptDataLoader, pruned_outs: CircuitOutputs, out_of_correct_and_incorrect_answers: bool = False) -> Measurements
Percentage of outputs for which the correct answer has the highest logit.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | `PatchableModel` | Not used. | required |
dataloader | `PromptDataLoader` | The dataloader on which the | required |
pruned_outs | `CircuitOutputs` | The outputs of the ablated model for each circuit size. | required |
out_of_correct_and_incorrect_answers | `bool` | Whether to calculate the proportion of prompts for which the correct answer has a higher logit than the incorrect answers. This is useful when you are particularly interested in the counterfactual comparison to the corrupt prompts. For example, in the Sports Player post, Rajamanoharan et al. (2023) look at the proportion of prompts for which the correct sport has a greater logit than the two other sports. | `False` |
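To make the two comparison modes concrete, here is a minimal pure-Python sketch. It is not the library's implementation: the helper `correct_answer_percent`, its list-based arguments, and the toy data are all hypothetical, and it scores a single set of logits rather than returning `Measurements` for each circuit size. It also assumes the result is reported on a 0 to 100 scale.

```python
from typing import Sequence


def correct_answer_percent(
    logits: Sequence[Sequence[float]],
    correct_idx: Sequence[int],
    incorrect_idxs: Sequence[Sequence[int]],
    out_of_correct_and_incorrect_answers: bool = False,
) -> float:
    """Hypothetical helper: percentage of prompts where the correct token wins."""
    n_correct = 0
    for row, ans, wrong in zip(logits, correct_idx, incorrect_idxs):
        if out_of_correct_and_incorrect_answers:
            # Compare only against the listed incorrect answers.
            rivals = list(wrong)
        else:
            # Compare against every other token in the vocabulary.
            rivals = [i for i in range(len(row)) if i != ans]
        if all(row[ans] > row[i] for i in rivals):
            n_correct += 1
    return 100.0 * n_correct / len(logits)


# Toy data (assumed): 2 prompts, vocabulary of 4 tokens.
logits = [
    [2.0, 1.0, 0.5, 3.0],  # token 3 has the top logit but is not a listed answer
    [0.1, 4.0, 0.2, 0.3],  # token 1 (the correct answer) has the top logit
]
correct_idx = [0, 1]
incorrect_idxs = [[1, 2], [0, 2]]

over_vocab = correct_answer_percent(logits, correct_idx, incorrect_idxs)
over_answers = correct_answer_percent(logits, correct_idx, incorrect_idxs, True)
print(over_vocab, over_answers)  # 50.0 100.0
```

In the first prompt an unrelated token outranks the correct answer, so the prompt only counts as correct when the comparison is restricted to the listed incorrect answers.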
Note
This function assumes that each prompt in `dataloader` has only one correct answer. If not, an error will be raised.
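If you want to check this assumption before computing the metric, a pre-check along the following lines may help. This is a sketch only: `check_single_correct_answer` is a hypothetical helper, the nested-list representation of answers is an assumption, and `ValueError` is an arbitrary choice, since the kind of error the library raises is not stated here.

```python
from typing import Sequence


def check_single_correct_answer(answers: Sequence[Sequence[int]]) -> None:
    """Hypothetical pre-check: every prompt must have exactly one correct answer."""
    for i, ans in enumerate(answers):
        if len(ans) != 1:
            raise ValueError(
                f"Prompt {i} has {len(ans)} correct answers; expected exactly 1."
            )


# One correct answer token per prompt: passes silently.
check_single_correct_answer([[5], [7]])

# The second prompt lists two correct answers, so the check fails.
try:
    check_single_correct_answer([[5], [7, 9]])
except ValueError as err:
    message = str(err)
    print(message)
```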