Mask gradient
auto_circuit.prune_algos.mask_gradient
Functions
mask_gradient_prune_scores
mask_gradient_prune_scores(model: PatchableModel, dataloader: PromptDataLoader, official_edges: Optional[Set[Edge]], grad_function: Literal['logit', 'prob', 'logprob', 'logit_exp'], answer_function: Literal['avg_diff', 'avg_val', 'mse'], mask_val: Optional[float] = None, integrated_grad_samples: Optional[int] = None, ablation_type: AblationType = AblationType.RESAMPLE, clean_corrupt: Optional[Literal['clean', 'corrupt']] = 'corrupt') -> PruneScores
Prune scores are equal to the gradient with respect to the mask values that interpolate the edges between the clean activations and the ablated activations.
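As a rough illustration of this idea (not the library's implementation), the sketch below uses a toy scalar "model" with two edges. Each edge activation is a linear interpolation between its clean and ablated value, controlled by a mask, and the score is the derivative of the output with respect to that mask. The names `interpolate`, `toy_output`, and `mask_gradient`, the linear interpolation formula, and the finite-difference gradient are all assumptions made for this example; the real function operates on a PatchableModel and would use automatic differentiation.

```python
from typing import Callable, Dict


def interpolate(clean: float, ablated: float, mask: float) -> float:
    """Hypothetical edge interpolation: mask=0 gives clean, mask=1 gives ablated."""
    return (1.0 - mask) * clean + mask * ablated


def toy_output(acts: Dict[str, float]) -> float:
    """Toy stand-in for a model output that depends on two edge activations."""
    return 3.0 * acts["a"] + 0.5 * acts["b"] ** 2


def mask_gradient(
    output_fn: Callable[[Dict[str, float]], float],
    clean: Dict[str, float],
    ablated: Dict[str, float],
    mask_val: float = 0.0,
    eps: float = 1e-6,
) -> Dict[str, float]:
    """Central finite-difference gradient of the output w.r.t. each edge mask."""
    scores: Dict[str, float] = {}
    for edge in clean:

        def out_at(m: float, edge: str = edge) -> float:
            # All edges sit at mask_val, except the one being perturbed.
            acts = {e: interpolate(clean[e], ablated[e], mask_val) for e in clean}
            acts[edge] = interpolate(clean[edge], ablated[edge], m)
            return output_fn(acts)

        scores[edge] = (out_at(mask_val + eps) - out_at(mask_val - eps)) / (2 * eps)
    return scores


clean = {"a": 1.0, "b": 2.0}
ablated = {"a": 0.0, "b": 1.0}
scores = mask_gradient(toy_output, clean, ablated, mask_val=0.0)
print(scores)  # approximately {'a': -3.0, 'b': -2.0}
```

In this toy setting, edge "a" receives the larger absolute score because moving it toward its ablated value changes the output more.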
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | PatchableModel | The model to find the circuit for. | required |
dataloader | PromptDataLoader | The dataloader to use for input. | required |
official_edges | Optional[Set[Edge]] | Not used. | required |
grad_function | Literal['logit', 'prob', 'logprob', 'logit_exp'] | Function to apply to the logits before taking the gradient. | required |
answer_function | Literal['avg_diff', 'avg_val', 'mse'] | Loss function of the model output which the gradient is taken with respect to. | required |
mask_val | Optional[float] | Value of the mask to use for the forward pass. Cannot be used if | None |
integrated_grad_samples | Optional[int] | If not | None |
ablation_type | AblationType | The type of ablation to perform. | RESAMPLE |
clean_corrupt | Optional[Literal['clean', 'corrupt']] | Whether to use the clean or corrupt inputs to calculate the ablations. | 'corrupt' |
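The name `integrated_grad_samples` suggests an integrated-gradients style estimate, where the mask gradient is averaged over several mask values instead of being taken at a single `mask_val`. The sketch below shows that general technique on a one-edge toy function. It is an assumption-laden illustration: the helper names, the midpoint sampling of mask values in [0, 1], and the finite-difference gradient are all hypothetical and are not taken from the library.

```python
from typing import Callable


def mask_grad(f: Callable[[float], float], m: float, eps: float = 1e-6) -> float:
    """Central finite-difference derivative of f at mask value m."""
    return (f(m + eps) - f(m - eps)) / (2 * eps)


def integrated_mask_grad(f: Callable[[float], float], n_samples: int) -> float:
    """Average the mask gradient over n_samples mask values in [0, 1].

    Assumption: mask values are the midpoints of n_samples equal sub-intervals.
    """
    mask_vals = [(k + 0.5) / n_samples for k in range(n_samples)]
    return sum(mask_grad(f, m) for m in mask_vals) / n_samples


CLEAN, ABLATED = 2.0, 1.0


def output_at_mask(m: float) -> float:
    """Toy output as a function of a single edge mask (hypothetical model)."""
    act = (1.0 - m) * CLEAN + m * ABLATED
    return 0.5 * act**2


single_point = mask_grad(output_at_mask, 0.0)
averaged = integrated_mask_grad(output_at_mask, n_samples=10)
total_change = output_at_mask(1.0) - output_at_mask(0.0)
print(single_point, averaged, total_change)
```

For this toy function the single-point gradient at mask 0 is about -2.0, while the averaged gradient is about -1.5, which matches the actual change in output when the edge is fully ablated. That gap is the usual motivation for averaging over the interpolation path when the output is nonlinear in the mask.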
Returns:
Type | Description |
---|---|
PruneScores | An ordering of the edges by importance to the task. Importance is equal to the absolute value of the score assigned to the edge. |
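To make the "importance is the absolute value of the score" rule concrete, here is a minimal sketch that ranks edges by absolute score. The flat `Dict[str, float]` and the edge names are hypothetical stand-ins chosen for readability; they are not the actual structure of the PruneScores type.

```python
from typing import Dict, List, Tuple


def rank_edges(scores: Dict[str, float]) -> List[Tuple[str, float]]:
    """Sort edges from most to least important, using |score| as importance."""
    return sorted(scores.items(), key=lambda kv: abs(kv[1]), reverse=True)


# Hypothetical per-edge scores; the sign is ignored when ranking.
toy_scores = {"edge_a": 0.2, "edge_b": -0.9, "edge_c": 0.05}
ranking = rank_edges(toy_scores)
print([name for name, _ in ranking])  # ['edge_b', 'edge_a', 'edge_c']
```

Note that `edge_b` ranks first even though its score is negative, since only the magnitude counts.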
Note
When grad_function="logit" and mask_val=0, this function is exactly equivalent to edge_attribution_patching_prune_scores.
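One way to see why this equivalence is plausible is a short chain-rule argument. The derivation below is a sketch under the assumption that the mask $m_e$ on edge $e$ linearly interpolates between the clean activation $a_e^{\text{clean}}$ and the ablated activation $a_e^{\text{abl}}$; the symbols $\tilde{a}_e$ and $\mathcal{L}$ (the quantity being differentiated) are notation introduced here, not taken from the library.

$$
\tilde{a}_e(m_e) = (1 - m_e)\, a_e^{\text{clean}} + m_e\, a_e^{\text{abl}}
$$

$$
\left.\frac{\partial \mathcal{L}}{\partial m_e}\right|_{m = 0}
= \left(a_e^{\text{abl}} - a_e^{\text{clean}}\right)^{\top}
\left.\frac{\partial \mathcal{L}}{\partial \tilde{a}_e}\right|_{\tilde{a}_e = a_e^{\text{clean}}}
$$

The right-hand side is the activation difference multiplied by the gradient evaluated on the clean forward pass, which is the first-order estimate of the effect of patching edge $e$ that edge attribution patching uses.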
Source code in auto_circuit/prune_algos/mask_gradient.py