
KL div

auto_circuit.metrics.prune_metrics.kl_div

Functions

measure_kl_div

measure_kl_div(model: PatchableModel, dataloader: PromptDataLoader, circuit_outs: CircuitOutputs, compare_to_clean: bool = True) -> Measurements

Average KL divergence between the full model and the circuits. Lower values indicate that a circuit's output distribution more closely matches the full model's.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `PatchableModel` | The model on which `circuit_outs` was calculated. | required |
| `dataloader` | `PromptDataLoader` | The dataloader on which `circuit_outs` was calculated. | required |
| `circuit_outs` | `CircuitOutputs` | The outputs of the ablated model for each circuit size. | required |
| `compare_to_clean` | `bool` | Whether to compare the circuit output to the full model on the clean (`True`) or corrupt (`False`) prompt. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `Measurements` | A list of tuples, where the first element is the number of edges pruned and the second element is the average KL divergence for that number of edges. |
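
A hedged usage sketch: `model`, `dataloader`, and `circuit_outs` are assumed to have already been built with auto_circuit's patching and pruning utilities (that setup is elided here); the returned `Measurements` list can be consumed directly, for example for printing or plotting:

from auto_circuit.metrics.prune_metrics.kl_div import measure_kl_div

measurements = measure_kl_div(model, dataloader, circuit_outs)

# Each entry pairs an edge count with the average KL divergence at that size.
for edge_count, kl in measurements:
    print(f"{edge_count} edges: avg KL divergence = {kl:.4f}")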

Source code in auto_circuit/metrics/prune_metrics/kl_div.py
def measure_kl_div(
    model: PatchableModel,
    dataloader: PromptDataLoader,
    circuit_outs: CircuitOutputs,
    compare_to_clean: bool = True,
) -> Measurements:
    """
    Average KL divergence between the full model and the circuits.

    Args:
        model: The model on which `circuit_outs` was calculated.
        dataloader: The dataloader on which `circuit_outs` was calculated.
        circuit_outs: The outputs of the ablated model for each circuit size.
        compare_to_clean: Whether to compare the circuit output to the full model on the
            clean (`True`) or corrupt (`False`) prompt.

    Returns:
        A list of tuples, where the first element is the number of edges pruned and the
            second element is the average KL divergence for that number of edges.
    """
    circuit_kl_divs: Measurements = []
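    # Cache the full model's log-probs for every batch up front; they are
    # compared against each circuit size below.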
    default_logprobs: Dict[BatchKey, t.Tensor] = {}
    with t.inference_mode():
        for batch in dataloader:
            default_batch = batch.clean if compare_to_clean else batch.corrupt
            logits = model(default_batch)[model.out_slice]
            default_logprobs[batch.key] = log_softmax(logits, dim=-1)

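    # For each circuit size, concatenate the per-batch log-probs and compute
    # one KL divergence over the entire dataset.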
    for edge_count, circuit_out in (pruned_out_pbar := tqdm(circuit_outs.items())):
        pruned_out_pbar.set_description_str(f"KL Div for {edge_count} edges")
        circuit_logprob_list: List[t.Tensor] = []
        default_logprob_list: List[t.Tensor] = []
        for batch in dataloader:
            circuit_logprob_list.append(log_softmax(circuit_out[batch.key], dim=-1))
            default_logprob_list.append(default_logprobs[batch.key])
        kl = multibatch_kl_div(t.cat(circuit_logprob_list), t.cat(default_logprob_list))

        # Numerical errors can cause tiny negative values in KL divergence
        circuit_kl_divs.append((edge_count, max(kl.item(), 0)))
    return circuit_kl_divs
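
`multibatch_kl_div` is defined elsewhere in the library and is not shown on this page. Below is a minimal sketch of the quantity it presumably computes, assuming both arguments are log-probabilities with the vocabulary on the last dimension; the library's actual implementation may differ:

import torch as t
import torch.nn.functional as F

def multibatch_kl_div_sketch(
    circuit_logprobs: t.Tensor, target_logprobs: t.Tensor
) -> t.Tensor:
    # Pointwise KL in log space: exp(target) * (target - input), summed over
    # the vocabulary and averaged over all batch and sequence positions.
    kl = F.kl_div(
        circuit_logprobs, target_logprobs, reduction="none", log_target=True
    )
    return kl.sum(dim=-1).mean()

Under this reading, `multibatch_kl_div(circuit_logprobs, default_logprobs)` measures how far each circuit's output distribution is from the full model's, and the floating-point error that motivates the `max(kl.item(), 0)` clamp above can arise from the pointwise log-space subtraction.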