Skip to content

Sparse autoencoder

auto_circuit.model_utils.sparse_autoencoders.sparse_autoencoder

Attributes

Classes

SparseAutoencoder

SparseAutoencoder(wrapped_hook: HookPoint, n_latents: int, n_inputs: int)

Bases: Module

A Sparse Autoencoder wrapper module.

Takes some input, first passes it to the wrapped hook (which caches the SAE input), then runs it through the autoencoder and returns the reconstructed input.

Implements

$$\text{latents} = \mathrm{ReLU}\big(\text{encoder}(x - \text{bias}) + \text{latent\_bias}\big)$$
$$\text{recons} = \text{decoder}(\text{latents}) + \text{bias}$$

:param wrapped_hook: the wrapped transformer_lens hook that caches the SAE input
:param n_latents: dimension of the autoencoder latent
:param n_inputs: dimensionality of the input (e.g. residual stream, MLP neurons)

Source code in auto_circuit/model_utils/sparse_autoencoders/sparse_autoencoder.py
def __init__(self, wrapped_hook: HookPoint, n_latents: int, n_inputs: int) -> None:
    """
    Wrap an existing hook point with a sparse autoencoder.

    :param wrapped_hook: the wrapped transformer_lens hook that caches the SAE input
    :param n_latents: dimension of the autoencoder latent
    :param n_inputs: dimensionality of the input (e.g. residual stream, MLP neurons)
    """
    super().__init__()
    # The original hook; forward() calls it on the raw input before encoding.
    self.wrapped_hook: HookPoint = wrapped_hook
    # Hook applied in decode() to the per-latent output vectors before they are summed.
    self.latent_outs: HookPoint = HookPoint()
    # Weights start the same at each position. They're only different after pruning.
    # NOTE(review): init_params is defined elsewhere; presumably it creates
    # encode_weight, decode_weight, bias and latent_bias used by encode/decode — confirm.
    self.init_params(n_latents, n_inputs)
    # NOTE(review): presumably (re)initialises self.latent_total_act, which forward()
    # accumulates into — confirm against reset_activated_latents.
    self.reset_activated_latents()
Functions
decode
decode(x: Tensor) -> Tensor

:param x: autoencoder latents (shape: [..., [seq], n_latents])
:return: reconstructed data (shape: [..., [seq], n_inputs])

Source code in auto_circuit/model_utils/sparse_autoencoders/sparse_autoencoder.py
def decode(self, x: t.Tensor) -> t.Tensor:
    """
    Reconstruct inputs from latent activations.

    The contribution of every latent is materialised as its own output vector and
    passed through the ``latent_outs`` hook before the contributions are summed,
    so individual latents can be observed or edited.

    :param x: autoencoder latents (shape: [..., [seq], n_latents])
    :return: reconstructed data (shape: [..., [seq], n_inputs])
    """
    # One d-dimensional output vector per latent l -> shape [..., l, d].
    per_latent_out = einsum(x, self.decode_weight, "... l, ... d l -> ... l d")
    hooked_out = self.latent_outs(per_latent_out)
    # Collapsing the latent axis (-2) gives the ordinary linear decode, plus the bias.
    summed = hooked_out.sum(dim=-2)
    return summed + self.bias
encode
encode(x: Tensor) -> Tensor

:param x: input data (shape: [..., [seq], n_inputs])
:return: autoencoder latents (shape: [..., [seq], n_latents])

Source code in auto_circuit/model_utils/sparse_autoencoders/sparse_autoencoder.py
def encode(self, x: t.Tensor) -> t.Tensor:
    """
    Map inputs to non-negative latent activations.

    Subtracts the shared ``bias`` (the same bias decode() adds back), projects
    onto ``encode_weight``, adds ``latent_bias`` and applies a ReLU.

    :param x: input data (shape: [..., [seq], n_inputs])
    :return: autoencoder latents (shape: [..., [seq], n_latents])
    """
    centered = x - self.bias
    # Contract the input dimension d against encode_weight, keeping latent dim l.
    projected = einsum(centered, self.encode_weight, "... d, ... l d -> ... l")
    pre_activation = projected + self.latent_bias
    return t.nn.functional.relu(pre_activation)
forward
forward(x: Tensor) -> Tensor

:param x: input data (shape: [..., n_inputs])
:return: reconstructed data (shape: [..., n_inputs])

Source code in auto_circuit/model_utils/sparse_autoencoders/sparse_autoencoder.py
def forward(self, x: t.Tensor) -> t.Tensor:
    """
    Pass the input through the wrapped hook, then encode and decode it.

    The input is first given to the wrapped hook (so it can cache the SAE input),
    and the SAE reconstruction is returned. As a side effect, the summed latent
    activations are accumulated into ``self.latent_total_act``.

    :param x: input data (shape: [..., n_inputs])
    :return: reconstructed data (shape: [..., n_inputs])
    """
    x = self.wrapped_hook(x)
    latents = self.encode(x)
    # Accumulate statistics from a detached copy. Adding a grad-requiring tensor in
    # place would give the running total a grad_fn, chaining every previous forward
    # pass's autograd graph onto it and leaking memory across calls.
    self.latent_total_act += latents.detach().sum_to_size(self.latent_total_act.shape)
    recons = self.decode(latents)
    return recons

Functions