Skip to content

Task projector

auto_circuit.model_utils.task_projectors.task_projector

This was a weird idea I had: learn a projection at each layer of a transformer that tries to remove as many directions as possible.

Attributes

Classes

TaskProjector

TaskProjector(wrapped_hook: HookPoint, n_inputs: int, seq_idxs: Optional[List[int]] = None, mask_fn: MaskFn = None, layernorm: bool = False)

Bases: Module

Task Projector

Implements

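latents = ReLU(encoder(x - bias) + latent_bias)
recons = decoder(latents) + bias

(Note: this formula appears to be inherited from the sparse-autoencoder docs; the `forward` method below actually computes `linear @ x + bias`, optionally only at the positions in `seq_idxs`.)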

:param wrapped_hook: the wrapped transformer_lens hook that caches the SAE input
:param n_inputs: dimensionality of the input (e.g. residual stream, MLP neurons)

Source code in auto_circuit/model_utils/task_projectors/task_projector.py
def __init__(
    self,
    wrapped_hook: HookPoint,
    n_inputs: int,
    seq_idxs: Optional[List[int]] = None,
    mask_fn: MaskFn = None,
    layernorm: bool = False,
) -> None:
    """
    Set up the projector around a transformer_lens hook point.

    :param wrapped_hook: the wrapped transformer_lens ``HookPoint``; stored as
        ``self.wrapped_hook`` (the ``forward`` method as written does not call it).
    :param n_inputs: dimensionality of the input (e.g. residual stream, MLP
        neurons); passed to ``init_params``.
    :param seq_idxs: optional list of sequence positions; passed to
        ``init_params``. Presumably stored as ``self.seq_idxs``, which
        ``forward`` uses to project only those positions — confirm.
    :param mask_fn: how ``encode`` turns ``self.dim_weights`` into a mask:
        ``"hard_concrete"``, ``"sigmoid"``, or ``None`` (weights used as-is).
    :param layernorm: if True, ``forward`` applies a layer norm without affine
        weights over the last dimension before projecting.
    """
    # Module.__init__ has to run before any submodules/parameters are assigned.
    super().__init__()
    self.wrapped_hook: HookPoint = wrapped_hook
    # NOTE(review): init_params is defined elsewhere; presumably it creates the
    # learned tensors (linear / bias / rotation / dim_weights) — confirm.
    # mask_fn and layernorm are assigned *after* this call, so init_params
    # cannot rely on them being set.
    self.init_params(n_inputs, seq_idxs)
    self.mask_fn: MaskFn = mask_fn
    self.layernorm: bool = layernorm
Functions
decode
decode(x: Tensor) -> Tensor

:param x: rotated data (shape: [..., [seq], n_inputs])
:return: unrotated data (shape: [..., [seq], n_inputs])

Source code in auto_circuit/model_utils/task_projectors/task_projector.py
def decode(self, x: t.Tensor) -> t.Tensor:
    """
    Map rotated data back into the original basis by applying the inverse of
    ``self.rotation``.

    :param x: rotated data (shape: [..., [seq], n_inputs])
    :return: unrotated data (shape: [..., [seq], n_inputs])
    """
    undo_rotation = self.rotation.inverse
    unrotated = undo_rotation(x)
    return unrotated
encode
encode(x: Tensor) -> Tensor

:param x: input data (shape: [..., [seq], n_inputs])
:return: projected rotated data (shape: [..., [seq], n_inputs])

Source code in auto_circuit/model_utils/task_projectors/task_projector.py
def encode(self, x: t.Tensor) -> t.Tensor:
    """
    Rotate ``x``, scale each rotated dimension by a mask derived from
    ``self.dim_weights`` (according to ``self.mask_fn``), then add ``self.bias``.

    :param x: input data (shape: [..., [seq], n_inputs])
    :return: projected rotated data (shape: [..., [seq], n_inputs])
    """
    rotated = self.rotation(x)
    # Leading axes the mask carries in addition to its final "d" axis.
    mask_axes = []

    mode = self.mask_fn
    if mode == "hard_concrete":
        # The sampler is given the batch size, and its output is treated as
        # having a leading "batch" axis in the einsum pattern below.
        mask = sample_hard_concrete(self.dim_weights, x.size(0))
        mask_axes.append("batch")
    elif mode == "sigmoid":
        mask = t.sigmoid(self.dim_weights)
    else:
        assert mode is None
        mask = self.dim_weights

    # NOTE(review): self.seq_len is not set in __init__ directly; presumably
    # init_params defines it, and the mask then has a per-position axis.
    if self.seq_len is not None:
        mask_axes.append("seq")

    # NOTE(review): the output side hard-codes "batch seq d", so `rotated` must
    # be 3-D here even though the docstring marks [seq] as optional.
    pattern = f"{' '.join(mask_axes)} d, batch seq d -> batch seq d"
    return einsum(mask, rotated, pattern) + self.bias
forward
forward(x: Tensor) -> Tensor

:param x: input data (shape: [..., n_inputs])
:return: projected data (shape: [..., n_inputs])

Source code in auto_circuit/model_utils/task_projectors/task_projector.py
def forward(self, x: t.Tensor) -> t.Tensor:
    """
    Apply the learned affine projection ``linear @ x + bias``.

    If ``self.layernorm`` is True, the input is layer-normed (no affine
    weights) over its last dimension before projecting. If ``self.seq_idxs``
    is set, only those positions along dim 1 are projected. Every other
    position is passed through unchanged; in particular, it is not
    layer-normed.

    :param x: input data (shape: [..., n_inputs]). For a 4-D input, dim 2 is
        presumably the head axis: only index 0 of it is read, and the result
        is repeated across that axis.
    :return: projected data (shape: [..., n_inputs]; a 4-D input keeps its
        dim-2 size)
    """
    head_size = None
    has_head_dim = x.ndim == 4
    if has_head_dim:
        # NOTE(review): reading only head 0 and copying the result to every
        # head assumes all heads receive the same input — confirm.
        head_size = x.shape[2]
        x = x[:, :, 0]

    # Normalise into a separate tensor so pass-through positions keep the
    # original activations.
    normed = t.nn.functional.layer_norm(x, x.shape[-1:]) if self.layernorm else x

    if self.seq_idxs is not None:
        projected = (self.linear @ normed[:, self.seq_idxs].unsqueeze(-1)).squeeze(-1)
        out = x.clone()
        out[:, self.seq_idxs] = projected + self.bias
    else:
        out = (self.linear @ normed.unsqueeze(-1)).squeeze(-1) + self.bias

    if has_head_dim:
        assert head_size is not None
        out = out.unsqueeze(2).repeat(1, 1, head_size, 1)
    return out

Functions