Task projector
auto_circuit.model_utils.task_projectors.task_projector
This was an experimental idea I had: learn a projection at each layer of a transformer that tries to remove as many directions as possible.
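Removing a set of directions can be expressed as applying the projection P = I - VVᵀ for an orthonormal matrix of directions V. A minimal numpy sketch of this idea, with an illustrative direction `v` (not a learned parameter from the library):

```python
import numpy as np

d = 4
# One unit-norm direction to remove; illustrative only.
v = np.zeros(d)
v[0] = 1.0

# Projection that zeroes the component along v: P = I - v v^T
P = np.eye(d) - np.outer(v, v)

x = np.array([3.0, 1.0, -2.0, 0.5])
x_proj = P @ x

# The component along v is removed; the rest is untouched.
assert np.isclose(x_proj @ v, 0.0)
```

The learned projector generalizes this by optimizing which directions to remove, rather than fixing them in advance.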
Attributes
Classes
TaskProjector
TaskProjector(wrapped_hook: HookPoint, n_inputs: int, seq_idxs: Optional[List[int]] = None, mask_fn: MaskFn = None, layernorm: bool = False)
Bases: Module
Task Projector
Implements
latents = ReLU(encoder(x - bias) + latent_bias)
recons = decoder(latents) + bias
:param wrapped_hook: the wrapped transformer_lens hook that caches the SAE input
:param n_inputs: dimensionality of the input (e.g. residual stream, MLP neurons)
Source code in auto_circuit/model_utils/task_projectors/task_projector.py
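The two equations above can be sketched in plain numpy rather than the actual torch module. The parameter names `W_enc`, `W_dec`, `bias`, and `latent_bias` are placeholders for the learned weights, initialized randomly here for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
n_inputs, n_latents = 8, 8

# Placeholder learned parameters (random for illustration).
W_enc = rng.normal(size=(n_latents, n_inputs))
W_dec = rng.normal(size=(n_inputs, n_latents))
bias = rng.normal(size=n_inputs)
latent_bias = rng.normal(size=n_latents)

def encode(x):
    # latents = ReLU(encoder(x - bias) + latent_bias)
    return np.maximum(0.0, W_enc @ (x - bias) + latent_bias)

def decode(latents):
    # recons = decoder(latents) + bias
    return W_dec @ latents + bias

def forward(x):
    # The projector output: decode(encode(x)), same shape as the input.
    return decode(encode(x))

x = rng.normal(size=n_inputs)
recons = forward(x)
assert recons.shape == x.shape
```

In the real module these operations run on torch tensors inside a wrapped hook, so the projection is applied to activations during the forward pass.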
Functions
decode
:param x: rotated data (shape: [..., [seq], n_inputs])
:return: unrotated data (shape: [..., [seq], n_inputs])
encode
:param x: input data (shape: [..., [seq], n_inputs])
:return: projected rotated data (shape: [..., [seq], n_inputs])
Source code in auto_circuit/model_utils/task_projectors/task_projector.py
forward
:param x: input data (shape: [..., n_inputs])
:return: projected data (shape: [..., n_inputs])