Measure prune metrics
auto_circuit.metrics.prune_metrics.measure_prune_metrics
Functions
measure_prune_metrics
measure_prune_metrics(ablation_types: List[AblationType], metrics: List[PruneMetric], task_prune_scores: TaskPruneScores, patch_type: PatchType, reverse_clean_corrupt: bool = False, test_edge_counts: Optional[List[int]] = None) -> AblationMeasurements
Measure a set of circuit metrics for each Task and each PruneAlgo in the given task_prune_scores.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ablation_types | List[AblationType] | The types of ablation to test. | required |
metrics | List[PruneMetric] | The metrics to measure. | required |
task_prune_scores | TaskPruneScores | The edge scores for each task and each algorithm. | required |
patch_type | PatchType | Whether to ablate the circuit or the complement. | required |
reverse_clean_corrupt | bool | Reverse clean and corrupt (for input and patches). | False |
test_edge_counts | Optional[List[int]] | The set of [number of edges to prune] for each task and algorithm. | None |
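As an illustration of what a list passed as test_edge_counts might look like, the sketch below builds a log-spaced list of edge counts. The helper name log_spaced_edge_counts and the log spacing are assumptions made for this example; they are not part of auto_circuit, and the library's own default when test_edge_counts is None is not described here.

```python
from typing import List


def log_spaced_edge_counts(n_edges: int, n_points: int = 10) -> List[int]:
    """Hypothetical helper (not part of auto_circuit).

    Returns up to n_points distinct, sorted edge counts between 1 and
    n_edges, spaced evenly on a log scale.
    """
    if n_edges < 1:
        raise ValueError("n_edges must be at least 1")
    if n_points < 2:
        raise ValueError("n_points must be at least 2")
    # n_edges ** 0 == 1 and n_edges ** 1 == n_edges, so both ends are included.
    counts = {round(n_edges ** (i / (n_points - 1))) for i in range(n_points)}
    return sorted(counts)


edge_counts = log_spaced_edge_counts(1000, n_points=4)
```

A list like this could then be supplied as the test_edge_counts argument, assuming every count is no larger than the number of edges in the model.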
Returns:
Type | Description |
---|---|
AblationMeasurements | A nested dictionary of measurements for each ablation type, metric, task, and algorithm (in that order). |
Source code in auto_circuit/metrics/prune_metrics/measure_prune_metrics.py
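The nesting order of the returned dictionary (ablation type, then metric, then task, then algorithm) can be walked as in the minimal sketch below. It uses plain strings as stand-in keys and assumes each leaf is a list of (edge count, metric value) points; the real key types and leaf type are defined by auto_circuit and may differ.

```python
from typing import Dict, Iterator, List, Tuple

# Assumed leaf type: a list of (edge count, metric value) points.
Points = List[Tuple[int, float]]
# Stand-in for AblationMeasurements, keyed by strings for illustration only.
Nested = Dict[str, Dict[str, Dict[str, Dict[str, Points]]]]


def iter_measurements(
    measurements: Nested,
) -> Iterator[Tuple[str, str, str, str, Points]]:
    """Yield one row per (ablation type, metric, task, algorithm) leaf."""
    for ablation_type, by_metric in measurements.items():
        for metric, by_task in by_metric.items():
            for task, by_algo in by_task.items():
                for algo, points in by_algo.items():
                    yield ablation_type, metric, task, algo, points


# Made-up data with the same nesting order as the documented return value.
example: Nested = {
    "resample": {
        "kl_div": {
            "task_a": {
                "algo_x": [(0, 0.0), (10, 0.4)],
                "algo_y": [(0, 0.0), (10, 0.7)],
            }
        }
    }
}

rows = list(iter_measurements(example))
```

Flattening the structure this way makes it straightforward to load the results into a table for inspection.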
measurement_figs
measurement_figs(measurements: AblationMeasurements, auc_plots: bool = False) -> Tuple[Figure, ...]
Plot the measurements from measure_prune_metrics as a set of Plotly figures (one for each ablation type and metric). Optionally include average Area Under the Curve (AUC) plots for each metric.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
measurements | AblationMeasurements | The measurements to plot. | required |
auc_plots | bool | Whether to include the average AUC plots. | False |
Returns:
Type | Description |
---|---|
Tuple[Figure, ...] | A tuple of Plotly figures. |
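To make the idea of an average AUC concrete, here is a minimal sketch that computes a trapezoidal area under each (edge count, metric value) curve and averages it per algorithm across tasks. It is an illustration under stated assumptions, not the library's implementation: the function names are hypothetical, and auto_circuit may scale the axes or normalize the area differently.

```python
from typing import Dict, List, Sequence, Tuple


def trapezoid_auc(points: Sequence[Tuple[float, float]]) -> float:
    """Area under a piecewise-linear curve through the (x, y) points."""
    pts = sorted(points)  # order by x so adjacent pairs form trapezoids
    return sum(
        (x1 - x0) * (y0 + y1) / 2
        for (x0, y0), (x1, y1) in zip(pts, pts[1:])
    )


def average_auc_per_algo(
    by_task: Dict[str, Dict[str, List[Tuple[float, float]]]],
) -> Dict[str, float]:
    """Average each algorithm's AUC over all tasks that include it."""
    totals: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    for by_algo in by_task.values():
        for algo, points in by_algo.items():
            totals[algo] = totals.get(algo, 0.0) + trapezoid_auc(points)
            counts[algo] = counts.get(algo, 0) + 1
    return {algo: totals[algo] / counts[algo] for algo in totals}


# Made-up measurements for one ablation type and one metric.
by_task = {
    "task_a": {"algo_x": [(0, 0.0), (10, 1.0)]},
    "task_b": {"algo_x": [(0, 0.0), (10, 3.0)]},
}
avg_auc = average_auc_per_algo(by_task)
```

Here the two curves have areas 5.0 and 15.0, so the average for algo_x is 10.0.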