KL div
auto_circuit.metrics.prune_metrics.kl_div
Functions
measure_kl_div
measure_kl_div(model: PatchableModel, dataloader: PromptDataLoader, circuit_outs: CircuitOutputs, compare_to_clean: bool = True) -> Measurements
Average KL divergence between the full model and the circuits.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PatchableModel | The model on which | required |
dataloader | PromptDataLoader | The dataloader on which the | required |
circuit_outs | CircuitOutputs | The outputs of the ablated model for each circuit size. | required |
compare_to_clean | bool | Whether to compare the circuit output to the full model on the clean ( | True |
Returns:
Type | Description |
---|---|
Measurements | A list of tuples, where the first element is the number of edges pruned and the second element is the average KL divergence for that number of edges. |
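
To make the shape of the return value concrete, here is a minimal pure-Python sketch of the idea: for each circuit size, compute the KL divergence between the full model's output distribution and the circuit's output distribution on every prompt, then average. It is not the library's implementation. The helper names (`kl_divergence`, `average_kl_per_circuit`), the plain-list probability inputs, the dictionary keyed by edge count, and the direction KL(full || circuit) are all assumptions made for illustration; the real function operates on a `PatchableModel`, a `PromptDataLoader` and `CircuitOutputs`.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical stand-in for the library's Measurements type:
# a list of (number of edges, average KL divergence) tuples.
Measurements = List[Tuple[int, float]]


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) for two discrete distributions over the same vocabulary.

    Assumes q is strictly positive wherever p is positive.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def average_kl_per_circuit(
    full_probs: List[List[float]],
    circuit_probs: Dict[int, List[List[float]]],
) -> Measurements:
    """Average KL(full || circuit) over prompts, for each circuit size.

    full_probs holds one output distribution per prompt for the full model.
    circuit_probs maps an edge count to the per-prompt distributions of the
    ablated model at that circuit size (assumed layout, for illustration).
    """
    measurements: Measurements = []
    for n_edges, probs in sorted(circuit_probs.items()):
        kls = [kl_divergence(p, q) for p, q in zip(full_probs, probs)]
        measurements.append((n_edges, sum(kls) / len(kls)))
    return measurements


# Two prompts, two-token vocabulary (toy numbers).
full = [[0.5, 0.5], [0.9, 0.1]]
circuits = {
    0: [[0.5, 0.5], [0.9, 0.1]],    # identical to the full model
    2: [[0.25, 0.75], [0.5, 0.5]],  # outputs drift from the full model
}
results = average_kl_per_circuit(full, circuits)
for n_edges, avg_kl in results:
    print(n_edges, round(avg_kl, 4))
```

A circuit whose outputs match the full model exactly scores 0, and the score grows as the circuit's distribution drifts away, so lower values indicate a more faithful circuit.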