Answer value
auto_circuit.metrics.prune_metrics.answer_value
Functions
measure_answer_val
measure_answer_val(model: PatchableModel, test_loader: PromptDataLoader, circuit_outs: CircuitOutputs, prob_func: Literal['log_softmax', 'softmax', 'logits'] = 'logits', wrong_answer: bool = False) -> Measurements
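The average value of the logits (or a function of them, selected by prob_func) for the correct answers (or for the wrong answers, if wrong_answer is set), computed for each circuit size in circuit_outs.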
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PatchableModel | Not used. | required |
test_loader | PromptDataLoader | The dataloader on which the circuit_outs were computed. | required |
circuit_outs | CircuitOutputs | The outputs of the ablated model for each circuit size. | required |
prob_func | Literal['log_softmax', 'softmax', 'logits'] | The function to apply to the logits before calculating the answer value. | 'logits' |
wrong_answer | bool | Whether to calculate the value for the wrong answers instead of the correct answers. | False |
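For reference, the three prob_func options correspond to the standard definitions below, for a logit vector $z$ over the vocabulary and an answer token $a$. These are the usual textbook formulas, not taken from the library source; the final line (a plain mean over $N$ prompts) is an assumption about how "average" is computed.

$$
\begin{aligned}
\texttt{'logits'}:&\quad v_a = z_a \\
\texttt{'log\_softmax'}:&\quad v_a = z_a - \log \sum_j e^{z_j} \\
\texttt{'softmax'}:&\quad v_a = \frac{e^{z_a}}{\sum_j e^{z_j}} \\
\text{answer value}:&\quad \bar{v} = \frac{1}{N} \sum_{n=1}^{N} v^{(n)}_{a_n}
\end{aligned}
$$

One practical consequence: adding the same constant to every logit changes the 'logits' value but leaves 'log_softmax' and 'softmax' unchanged, so the latter two are easier to compare across circuit sizes whose outputs may be shifted overall.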
Returns:
Type | Description |
---|---|
Measurements | A list of tuples, where the first element is the number of edges pruned and the second element is the average answer value for that number of edges. |
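Because the result is a list of (number of edges, answer value) pairs, it can be sorted and split into two series for plotting. The sketch below assumes that shape, as described in the table above; the `Measurements` alias here is a local stand-in and `split_for_plotting` is a hypothetical helper, not part of auto_circuit.

```python
from typing import List, Tuple

# Assumed shape: (number of edges, average answer value) pairs.
Measurements = List[Tuple[int, float]]


def split_for_plotting(measurements: Measurements) -> Tuple[List[int], List[float]]:
    """Sort by edge count and split into x (edges) and y (answer value) lists."""
    ordered = sorted(measurements, key=lambda m: m[0])
    xs = [n_edges for n_edges, _ in ordered]
    ys = [value for _, value in ordered]
    return xs, ys
```

With the made-up input `[(10, 0.2), (0, 1.5), (5, 0.9)]`, this returns `([0, 5, 10], [1.5, 0.9, 0.2])`, ready to pass to any line-plotting function.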