Sparse autoencoder
auto_circuit.model_utils.sparse_autoencoders.sparse_autoencoder
Classes
SparseAutoencoder
Bases: Module
A Sparse Autoencoder wrapper module.
Takes an input, passes it through the autoencoder, and passes the reconstructed input on to the wrapped hook.
Implements
latents = ReLU(encoder(x - bias) + latent_bias)
recons = decoder(latents) + bias
:param wrapped_hook: the wrapped transformer_lens hook that caches the SAE input
:param n_latents: dimension of the autoencoder latent
:param n_inputs: dimensionality of the input (e.g. residual stream, MLP neurons)
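The two equations above can be sketched in a few lines of NumPy. This is an illustration only, not the module's PyTorch code: the weight names `W_enc` and `W_dec`, the random initialisation, and the choice of a bias-free linear map for `encoder` and `decoder` are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n_inputs, n_latents = 4, 8  # illustrative sizes, not defaults from the library

# Hypothetical parameters standing in for the module's learned weights.
W_enc = rng.normal(size=(n_inputs, n_latents))   # assumed linear encoder
W_dec = rng.normal(size=(n_latents, n_inputs))   # assumed linear decoder
bias = rng.normal(size=n_inputs)                 # shared input/output bias
latent_bias = rng.normal(size=n_latents)


def encode(x):
    # latents = ReLU(encoder(x - bias) + latent_bias)
    return np.maximum((x - bias) @ W_enc + latent_bias, 0.0)


def decode(latents):
    # recons = decoder(latents) + bias
    return latents @ W_dec + bias


x = rng.normal(size=(3, n_inputs))
latents = encode(x)
recons = decode(latents)
```

Note that the same `bias` is subtracted before encoding and added back after decoding, as in the equations; `latent_bias` applies only inside the ReLU.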
Source code in auto_circuit/model_utils/sparse_autoencoders/sparse_autoencoder.py
Functions
decode
:param x: autoencoder latents (shape: [..., [seq], n_latents])
:return: reconstructed data (shape: [..., [seq], n_inputs])
encode
:param x: input data (shape: [..., [seq], n_inputs])
:return: autoencoder latents (shape: [..., [seq], n_latents])
forward
:param x: input data (shape: [..., n_inputs])
:return: reconstructed data (shape: [..., n_inputs])
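To make the documented shapes concrete, the following NumPy sketch runs a batched input with a sequence dimension through encode, decode, and forward. It is a stand-in under stated assumptions, not the module itself: the weights are hypothetical, the encoder and decoder are assumed to be plain linear maps, and the step of passing the result to the wrapped hook is omitted.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, seq, n_inputs, n_latents = 2, 5, 4, 16  # illustrative sizes

# Hypothetical weights; the real module learns these.
W_enc = rng.normal(size=(n_inputs, n_latents))
W_dec = rng.normal(size=(n_latents, n_inputs))
bias = np.zeros(n_inputs)
latent_bias = np.zeros(n_latents)


def encode(x):
    # [..., [seq], n_inputs] -> [..., [seq], n_latents]
    return np.maximum((x - bias) @ W_enc + latent_bias, 0.0)


def decode(latents):
    # [..., [seq], n_latents] -> [..., [seq], n_inputs]
    return latents @ W_dec + bias


def forward(x):
    # Reconstruction only; the hook call is left out of this sketch.
    return decode(encode(x))


x = rng.normal(size=(batch, seq, n_inputs))
z = encode(x)
out = forward(x)
print(z.shape)    # (2, 5, 16)
print(out.shape)  # (2, 5, 4)
```

Only the last dimension changes between input and latents; any leading batch or sequence dimensions pass through unchanged.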