Circuit probing
auto_circuit.prune_algos.circuit_probing
Functions
circuit_probing_prune_scores
circuit_probing_prune_scores(model: PatchableModel, dataloader: PromptDataLoader, official_edges: Optional[Set[Edge]], learning_rate: float = 0.1, epochs: int = 20, regularize_lambda: float = 10, mask_fn: MaskFn = 'hard_concrete', dropout_p: float = 0.0, init_val: float = -init_mask_val, show_train_graph: bool = False, circuit_sizes: List[int | Literal['true_size']] = ['true_size'], tree_optimisation: bool = False, avoid_edges: Optional[Set[Edge]] = None, avoid_lambda: float = 1.0, faithfulness_target: SP_FAITHFULNESS_TARGET = 'kl_div', validation_dataloader: Optional[PromptDataLoader] = None) -> PruneScores
Wrapper of Subnetwork Probing that searches for circuits of different sizes and assigns each edge a score according to the size of the smallest circuit that it is part of. Edges in smaller circuits receive higher scores because smaller circuits contain more important edges. Edges not in any circuit are assigned a score of 0.
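The scoring rule described above can be sketched in plain Python. This is only an illustration, not the library's implementation: the helper `score_edges_by_smallest_circuit`, the string edge names, and the `1 / size` scoring rule are all assumptions chosen so that smaller circuits give higher scores and edges outside every circuit get 0. The actual score values produced by `circuit_probing_prune_scores` may differ.

```python
from typing import Dict, Hashable, Iterable, Set


def score_edges_by_smallest_circuit(
    circuits: Dict[int, Set[Hashable]],
    all_edges: Iterable[Hashable],
) -> Dict[Hashable, float]:
    """Hypothetical helper: score each edge by the smallest circuit containing it.

    `circuits` maps a (positive) circuit size to the set of edges found at that size.
    """
    # Edges that never appear in a circuit keep a score of 0.
    scores: Dict[Hashable, float] = {edge: 0.0 for edge in all_edges}
    # Visit circuits from largest to smallest, so the score written last for an
    # edge comes from the smallest circuit that contains it.
    for size in sorted(circuits, reverse=True):
        for edge in circuits[size]:
            # Assumed scoring rule: any function that decreases with size would do.
            scores[edge] = 1.0 / size
    return scores


# Toy example with made-up edges: a 2-edge circuit nested inside a 4-edge circuit.
circuits = {
    2: {"a->b", "b->c"},
    4: {"a->b", "b->c", "c->d", "a->d"},
}
all_edges = ["a->b", "b->c", "c->d", "a->d", "d->e"]
scores = score_edges_by_smallest_circuit(circuits, all_edges)
```

In this toy example the edges in the 2-edge circuit end up with the highest score, the edges that only appear in the 4-edge circuit score lower, and `"d->e"` stays at 0.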
Parameters:
Name | Type | Description | Default |
---|---|---|---|
circuit_sizes | `List[int \| Literal['true_size']]` | List of circuit sizes to probe. If | `['true_size']` |