Tensor ops
auto_circuit.utils.tensor_ops
Functions
batch_answer_diff_percents
batch_answer_diff_percents(pred_vals: Tensor, target_vals: Tensor, batch: PromptPairBatch) -> Tensor
Find the percentage difference between the predicted logit differences and the target logit differences.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pred_vals | Tensor | The predicted logit values or some tensor of the same shape. | required |
target_vals | Tensor | The target logit values or some tensor of the same shape. | required |
batch | PromptPairBatch | The batch of prompts and answers. | required |
Returns:
Type | Description |
---|---|
Tensor | The percentage difference between the predicted logit differences and the target logit differences. |
Source code in auto_circuit/utils/tensor_ops.py
batch_answer_diffs
batch_answer_diffs(vals: Tensor, batch: PromptPairBatch) -> Tensor
Find the difference between the average value of the correct answers and the average value of the wrong answers for each prompt in the batch.
If the batch answers are a List, rather than a Tensor, the function will be much slower.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
vals | Tensor | The logit values or some tensor of the same shape. | required |
batch | PromptPairBatch | The batch of prompts and answers. | required |
Returns:
Type | Description |
---|---|
Tensor | The difference between the average value of the correct answers and the average value of the wrong answers for each prompt in the batch. |
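To make the computation concrete, here is a plain-Python sketch of the per-prompt difference described above. It is not the library's implementation: `answer_diffs` and its list-based arguments are hypothetical stand-ins for the tensor input and for the answer fields of a PromptPairBatch, whose exact layout is not shown on this page.

```python
from statistics import mean


def answer_diffs(vals, answers, wrong_answers):
    """Per-prompt mean(correct answer values) minus mean(wrong answer values).

    vals: one list of per-token values (e.g. logits) per prompt.
    answers / wrong_answers: one list of token indices per prompt (assumed layout).
    """
    diffs = []
    for row, correct, wrong in zip(vals, answers, wrong_answers):
        correct_avg = mean(row[i] for i in correct)
        wrong_avg = mean(row[i] for i in wrong)
        diffs.append(correct_avg - wrong_avg)
    return diffs


vals = [[2.0, 0.5, 1.0, -1.0], [0.0, 3.0, 1.0, 1.0]]
answers = [[0], [1]]
wrong_answers = [[1, 2], [2, 3]]
diffs = answer_diffs(vals, answers, wrong_answers)  # one difference per prompt
```

Taking the mean of `diffs` gives the single number that a wrapper such as batch_avg_answer_diff is described as returning.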
batch_avg_answer_diff
batch_avg_answer_diff(vals: Tensor, batch: PromptPairBatch) -> Tensor
Wrapper of batch_answer_diffs that returns the mean of the differences.
batch_avg_answer_val
batch_avg_answer_val(vals: Tensor, batch: PromptPairBatch, wrong_answer: bool = False) -> Tensor
Get the average value of the logits (or some function of them) for the correct answers in the batch.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
vals | Tensor | The logit values or some tensor of the same shape. | required |
batch | PromptPairBatch | The batch of prompts and answers. | required |
wrong_answer | bool | Whether to get the average value of the wrong answers instead of the correct answers. | False |
Returns:
Type | Description |
---|---|
Tensor | The average value of the logits for the correct answers in the batch (or for the wrong answers if wrong_answer is True). |
correct_answer_greater_than_incorrect_proportion
correct_answer_greater_than_incorrect_proportion(logits: Tensor, batch: PromptPairBatch) -> Tensor
What proportion of the logits have the correct answer with a greater value than all the wrong answers?
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | Tensor | The logit values or some tensor of the same shape. | required |
batch | PromptPairBatch | The batch of prompts and answers. | required |
Returns:
Type | Description |
---|---|
Tensor | The proportion of the logits that have the correct answer with a greater value than all the wrong answers. |
correct_answer_proportion
correct_answer_proportion(logits: Tensor, batch: PromptPairBatch) -> Tensor
What proportion of the logits have the correct answer as the maximum?
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logits | Tensor | The logit values or some tensor of the same shape. | required |
batch | PromptPairBatch | The batch of prompts and answers. | required |
Returns:
Type | Description |
---|---|
Tensor | The proportion of the logits that have the correct answer as the maximum. |
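The two proportion metrics differ in what the correct answer is compared against: the maximum over the whole vocabulary, or only the listed wrong answers. The plain-Python sketch below illustrates that difference. It assumes one correct token index per prompt; the helper names and the list-based inputs are hypothetical and are not the library's API.

```python
def correct_is_argmax(row, correct):
    """True if the correct token has the highest value over the whole row."""
    return max(range(len(row)), key=row.__getitem__) == correct


def correct_beats_wrong(row, correct, wrong):
    """True if the correct token's value exceeds every listed wrong answer's value."""
    return all(row[correct] > row[w] for w in wrong)


def proportion(flags):
    """Fraction of True values in a list of booleans."""
    return sum(flags) / len(flags)


logits = [[2.0, 0.5, 1.0, 3.0], [0.0, 3.0, 1.0, 1.0]]
answers = [0, 1]
wrong_answers = [[1, 2], [2, 3]]

argmax_prop = proportion(
    [correct_is_argmax(r, a) for r, a in zip(logits, answers)]
)
beats_prop = proportion(
    [correct_beats_wrong(r, a, w) for r, a, w in zip(logits, answers, wrong_answers)]
)
```

In the first prompt, token 3 has the largest logit even though it is neither the correct answer nor a listed wrong answer, so that prompt counts for the greater-than-incorrect metric but not for the argmax metric.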
desc_prune_scores
desc_prune_scores(prune_scores: PruneScores) -> Tensor
Flatten the prune scores into a single, 1-dimensional tensor and sort them in descending order.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prune_scores | PruneScores | The prune scores to flatten and sort. | required |
Returns:
Type | Description |
---|---|
Tensor | The flattened and sorted prune scores. |
flat_prune_scores
flat_prune_scores(prune_scores: PruneScores) -> Tensor
Flatten the prune scores into a single, 1-dimensional tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prune_scores | PruneScores | The prune scores to flatten. | required |
Returns:
Type | Description |
---|---|
Tensor | The flattened prune scores. |
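As an illustration of what flattening means here, the sketch below assumes that prune scores are a mapping from a module name to an array of scores of any shape. That layout, the keys "A" and "B", and the helper names are assumptions made for this example; the PruneScores type is not defined on this page. Nested lists stand in for tensors.

```python
def flatten(nested):
    """Recursively flatten arbitrarily nested lists into a flat list of numbers."""
    if isinstance(nested, list):
        return [x for item in nested for x in flatten(item)]
    return [nested]


def flat_scores(prune_scores):
    """Concatenate every module's scores into one 1-dimensional list."""
    return [x for scores in prune_scores.values() for x in flatten(scores)]


# Hypothetical prune scores: a 2x2 array for module "A", a length-2 array for "B".
prune_scores = {
    "A": [[0.1, -0.7], [0.3, 0.0]],
    "B": [0.9, -0.2],
}
flat = flat_scores(prune_scores)
```

The result has one entry per score regardless of which module it came from, which is what makes a single global ranking of scores possible.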
multibatch_kl_div
multibatch_kl_div(input_logprobs: Tensor, target_logprobs: Tensor) -> Tensor
Compute the average KL divergence between two sets of log probabilities.
Assumes the last dimension of input_logprobs and target_logprobs is the log probability of each class. The other dimensions are batch dimensions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input_logprobs | Tensor | The input log probabilities. | required |
target_logprobs | Tensor | The target log probabilities. | required |
Returns:
Type | Description |
---|---|
Tensor | The average KL divergence between the input and target log probabilities. |
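The sketch below shows one way to read "average KL divergence" over log probabilities, in plain Python. It assumes the direction KL(target || input), which is the convention of PyTorch's `kl_div` loss; the direction used by the library is not stated on this page. It also treats the batch dimensions as already collapsed into a list of rows. The function names are hypothetical.

```python
import math


def avg_kl_div(input_logprobs, target_logprobs):
    """Average KL(target || input) over rows of log probabilities.

    Each row is one distribution over classes, given as log probabilities.
    """
    total = 0.0
    for inp_row, tgt_row in zip(input_logprobs, target_logprobs):
        # sum_c p_target(c) * (log p_target(c) - log p_input(c))
        total += sum(math.exp(t) * (t - i) for i, t in zip(inp_row, tgt_row))
    return total / len(input_logprobs)


def log_probs(probs):
    """Convert a list of probabilities into log probabilities."""
    return [math.log(p) for p in probs]


target = [log_probs([0.5, 0.5]), log_probs([0.9, 0.1])]
uniform = [log_probs([0.5, 0.5]), log_probs([0.5, 0.5])]

same = avg_kl_div(target, target)        # identical distributions
different = avg_kl_div(uniform, target)  # second row disagrees with the target
```

Identical inputs give a divergence of zero, and only the rows where the two distributions disagree contribute to the average.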
prune_scores_threshold
prune_scores_threshold(prune_scores: PruneScores | Tensor, edge_count: int) -> Tensor
Return the minimum absolute value of the top edge_count prune scores.
Supports passing in a pre-sorted tensor of prune scores to avoid re-sorting.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prune_scores | PruneScores | Tensor | The prune scores to threshold. | required |
edge_count | int | The number of edges that should be above the threshold. | required |
Returns:
Type | Description |
---|---|
Tensor | The threshold value. |
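A minimal plain-Python sketch of the thresholding idea follows: rank the scores by magnitude and return the smallest magnitude among the top edge_count. The function name and the flat-list input are hypothetical. How the library treats edge cases such as an edge_count of zero is not described on this page, so the sketch simply rejects out-of-range values.

```python
def scores_threshold(flat_scores, edge_count):
    """Smallest magnitude among the edge_count largest-magnitude scores."""
    if not 1 <= edge_count <= len(flat_scores):
        raise ValueError("edge_count must be between 1 and the number of scores")
    by_magnitude = sorted((abs(s) for s in flat_scores), reverse=True)
    return by_magnitude[edge_count - 1]


scores = [0.1, -0.7, 0.3, 0.0, 0.9, -0.2]
threshold = scores_threshold(scores, 3)

# Edges whose magnitude reaches the threshold are the ones that would be kept.
kept = [s for s in scores if abs(s) >= threshold]
```

If several scores tie at the threshold magnitude, a comparison like the one used for `kept` can select more than edge_count edges.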
sample_hard_concrete
sample_hard_concrete(mask: Tensor, batch_size: int, mask_expanded: bool = False) -> Tensor
Sample from the hard concrete distribution (Louizos et al., 2017).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mask | Tensor | The mask whose values parameterize the distribution. | required |
batch_size | int | The number of samples to draw. | required |
mask_expanded | bool | Whether the mask has a batch dimension at the start. | False |
Returns:
Type | Description |
---|---|
Tensor | A sample for each element in the mask for each batch element. The returned tensor has shape |
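For readers unfamiliar with the hard concrete distribution, the sketch below follows the sampling recipe from Louizos et al. in plain Python: draw uniform noise, pass it through a tempered sigmoid shifted by the mask value, stretch the result slightly beyond [0, 1], then clamp. It assumes a 1-dimensional mask whose entries act as the log-alpha location parameters, and uses the constants suggested in the paper; the constants and mask handling used by the library are not shown on this page. The function name is hypothetical.

```python
import math
import random

# Stretch limits and temperature suggested by Louizos et al. (assumed here;
# not necessarily the values used by auto_circuit).
GAMMA, ZETA, BETA = -0.1, 1.1, 2.0 / 3.0


def sample_hard_concrete_sketch(log_alpha, batch_size, rng):
    """Draw batch_size hard concrete samples for each entry of log_alpha."""
    samples = []
    for _ in range(batch_size):
        row = []
        for la in log_alpha:
            u = rng.uniform(1e-6, 1.0 - 1e-6)  # keep u away from 0 and 1 to avoid log(0)
            logit = (math.log(u) - math.log(1.0 - u) + la) / BETA
            s = 1.0 / (1.0 + math.exp(-logit))  # binary concrete sample in (0, 1)
            stretched = s * (ZETA - GAMMA) + GAMMA  # stretch to (GAMMA, ZETA)
            row.append(min(1.0, max(0.0, stretched)))  # hard clamp to [0, 1]
        samples.append(row)
    return samples


rng = random.Random(0)
samples = sample_hard_concrete_sketch([-3.0, 0.0, 3.0], batch_size=4, rng=rng)
```

Because of the stretch-then-clamp step, samples can be exactly 0 or exactly 1 with non-zero probability, which is what lets the distribution model a mask that is truly off or on.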