Autoencoder transformer

auto_circuit.model_utils.sparse_autoencoders.autoencoder_transformer

A transformer model that patches sparse autoencoder reconstructions into each layer. This is a work in progress; error nodes are not yet implemented.

Attributes

Classes

AutoencoderTransformer

AutoencoderTransformer(wrapped_model: Module, saes: List[SparseAutoencoder])

Bases: Module

Source code in auto_circuit/model_utils/sparse_autoencoders/autoencoder_transformer.py
def __init__(self, wrapped_model: t.nn.Module, saes: List[SparseAutoencoder]):
    super().__init__()
    self.sparse_autoencoders = saes

    if isinstance(wrapped_model, PatchableModel):
        self.wrapped_model = wrapped_model.wrapped_model
    else:
        self.wrapped_model = wrapped_model
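
For illustration, a minimal construction sketch, assuming `model` is a TransformerLens `HookedTransformer` (or a `PatchableModel` wrapping one) and `saes` is a list of already-built `SparseAutoencoder` instances. In practice `sae_model` below builds the autoencoders and injects them into the model for you:

# Hypothetical sketch: the variables `model` and `saes` are assumptions, not
# part of this module. `sae_model` (documented below) is the intended entry
# point and also replaces the relevant hook points with the SAEs.
from auto_circuit.model_utils.sparse_autoencoders.autoencoder_transformer import (
    AutoencoderTransformer,
)

sae_wrapped = AutoencoderTransformer(wrapped_model=model, saes=saes)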

Functions

factorized_dest_nodes

factorized_dest_nodes(model: AutoencoderTransformer, separate_qkv: bool) -> Set[DestNode]

Get the destination part of each edge in the factorized graph, grouped by layer. Graph is factorized following the Mathematical Framework paper.

Source code in auto_circuit/model_utils/sparse_autoencoders/autoencoder_transformer.py
def factorized_dest_nodes(
    model: AutoencoderTransformer, separate_qkv: bool
) -> Set[DestNode]:
    """Get the destination part of each edge in the factorized graph, grouped by layer.
    Graph is factorized following the Mathematical Framework paper."""
    if separate_qkv:
        assert model.cfg.use_split_qkv_input  # Separate Q, K, V input for each head
    else:
        assert model.cfg.use_attn_in
    if not model.cfg.attn_only:
        assert model.cfg.use_hook_mlp_in  # Get MLP input BEFORE layernorm
    layers = count(1)
    nodes = set()
    for block_idx in range(model.cfg.n_layers):
        layer = next(layers)
        for head_idx in range(model.cfg.n_heads):
            if separate_qkv:
                for letter in ["Q", "K", "V"]:
                    nodes.add(
                        DestNode(
                            name=f"A{block_idx}.{head_idx}.{letter}",
                            module_name=f"blocks.{block_idx}.hook_{letter.lower()}_input",
                            layer=layer,
                            head_dim=2,
                            head_idx=head_idx,
                            weight=f"blocks.{block_idx}.attn.W_{letter}",
                            weight_head_dim=0,
                        )
                    )
            else:
                nodes.add(
                    DestNode(
                        name=f"A{block_idx}.{head_idx}",
                        module_name=f"blocks.{block_idx}.hook_attn_in",
                        layer=layer,
                        head_dim=2,
                        head_idx=head_idx,
                        weight=f"blocks.{block_idx}.attn.W_QKV",
                        weight_head_dim=0,
                    )
                )
        nodes.add(
            DestNode(
                name=f"MLP {block_idx}",
                module_name=f"blocks.{block_idx}.hook_mlp_in",
                layer=layer if model.cfg.parallel_attn_mlp else next(layers),
                weight=f"blocks.{block_idx}.mlp.W_in",
            )
        )
    nodes.add(
        DestNode(
            name="Resid End",
            module_name=f"blocks.{model.cfg.n_layers - 1}.hook_resid_post",
            layer=next(layers),
            weight="unembed.W_U",
        )
    )
    return nodes
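
A usage sketch, assuming `sae_wrapped` is an `AutoencoderTransformer` whose underlying config has `use_split_qkv_input` (or `use_attn_in` when `separate_qkv=False`) and `use_hook_mlp_in` enabled, as the asserts above require:

# Sketch only: `sae_wrapped` is assumed to come from `sae_model(...)` with the
# required TransformerLens hook flags enabled on the wrapped model's config.
dest_nodes = factorized_dest_nodes(sae_wrapped, separate_qkv=True)
# Expect one DestNode per Q/K/V input of every attention head, one per MLP
# block, plus the final "Resid End" node.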

factorized_src_nodes

factorized_src_nodes(model: AutoencoderTransformer) -> Set[SrcNode]

Get the source part of each edge in the factorized graph, grouped by layer. Graph is factorized following the Mathematical Framework paper.

Source code in auto_circuit/model_utils/sparse_autoencoders/autoencoder_transformer.py
def factorized_src_nodes(model: AutoencoderTransformer) -> Set[SrcNode]:
    """Get the source part of each edge in the factorized graph, grouped by layer.
    Graph is factorized following the Mathematical Framework paper."""
    assert model.cfg.use_attn_result  # Get attention head outputs separately
    assert model.cfg.use_attn_in  # Get attention head inputs separately
    assert model.cfg.use_split_qkv_input  # Separate Q, K, V input for each head
    if not model.cfg.attn_only:
        assert model.cfg.use_hook_mlp_in  # Get MLP input BEFORE layernorm
    assert not model.cfg.attn_only

    layers, idxs = count(), count()
    nodes = set()
    nodes.add(
        SrcNode(
            name="Resid Start",
            module_name="blocks.0.hook_resid_pre",
            layer=next(layers),
            src_idx=next(idxs),
            weight="embed.W_E",
        )
    )

    for block_idx in range(model.cfg.n_layers):
        layer = next(layers)
        for head_idx in range(model.cfg.n_heads):
            nodes.add(
                SrcNode(
                    name=f"A{block_idx}.{head_idx}",
                    module_name=f"blocks.{block_idx}.attn.hook_result",
                    layer=layer,
                    src_idx=next(idxs),
                    head_dim=2,
                    head_idx=head_idx,
                    weight=f"blocks.{block_idx}.attn.W_O",
                    weight_head_dim=0,
                )
            )
        layer = layer if model.cfg.parallel_attn_mlp else next(layers)
        for latent_idx in range(model.blocks[block_idx].hook_mlp_out.n_latents):
            nodes.add(
                SrcNode(
                    name=f"MLP {block_idx} Latent {latent_idx}",
                    module_name=f"blocks.{block_idx}.hook_mlp_out.latent_outs",
                    layer=layer,
                    src_idx=next(idxs),
                    head_dim=2,
                    head_idx=latent_idx,
                    weight=f"blocks.{block_idx}.hook_mlp_out.decoder.weight",
                    weight_head_dim=0,
                )
            )
    return nodes
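
The source side can be gathered in the same way. Unlike the plain transformer graph, each MLP contributes one `SrcNode` per autoencoder latent rather than a single node. A sketch under the same assumptions as above:

# Sketch only: same assumptions about `sae_wrapped` as above.
src_nodes = factorized_src_nodes(sae_wrapped)
# "Resid Start", one node per attention head output, and one node per SAE
# latent of every MLP block.
mlp_latent_srcs = [n for n in src_nodes if n.name.startswith("MLP")]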

sae_model

sae_model(model: HookedTransformer, sae_input: AutoencoderInput, load_pretrained: bool, n_latents: Optional[int] = None, pythia_size: Optional[str] = None, new_instance: bool = True) -> AutoencoderTransformer

Inject SparseAutoencoder wrappers into a transformer model.

Source code in auto_circuit/model_utils/sparse_autoencoders/autoencoder_transformer.py
def sae_model(
    model: HookedTransformer,
    sae_input: AutoencoderInput,
    load_pretrained: bool,
    n_latents: Optional[int] = None,
    pythia_size: Optional[str] = None,
    new_instance: bool = True,
) -> AutoencoderTransformer:
    """
    Inject
    [`SparseAutoencoder`][auto_circuit.model_utils.sparse_autoencoders.sparse_autoencoder.SparseAutoencoder]
    wrappers into a transformer model.
    """
    if new_instance:
        model = deepcopy(model)
    sparse_autoencoders: List[SparseAutoencoder] = []
    for layer_idx in range(model.cfg.n_layers):
        if sae_input == "mlp_post_act":
            hook_point = model.blocks[layer_idx].mlp.hook_post
            hook_module = model.blocks[layer_idx].mlp
            hook_name = "hook_post"
        else:
            assert sae_input == "resid_delta_mlp"
            hook_point = model.blocks[layer_idx].hook_mlp_out
            hook_module = model.blocks[layer_idx]
            hook_name = "hook_mlp_out"
        if load_pretrained:
            assert pythia_size is not None
            sae = load_autoencoder(hook_point, model, layer_idx, sae_input, pythia_size)
        else:
            assert n_latents is not None
            sae = SparseAutoencoder(hook_point, n_latents, model.cfg.d_model)
        sae.to(model.cfg.device)
        setattr(hook_module, hook_name, sae)
        sparse_autoencoders.append(sae)
    return AutoencoderTransformer(model, sparse_autoencoders)
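
A minimal end-to-end sketch, assuming a Pythia checkpoint loaded through TransformerLens. The model name, `n_latents`, and the choice of `load_pretrained=False` are illustrative placeholders, and the hook flags are only needed if you later build the factorized graph with `factorized_src_nodes` / `factorized_dest_nodes`:

# Sketch only: the model name and hyperparameters below are placeholders.
from transformer_lens import HookedTransformer

from auto_circuit.model_utils.sparse_autoencoders.autoencoder_transformer import (
    sae_model,
)

model = HookedTransformer.from_pretrained("pythia-70m-deduped")
model.set_use_attn_result(True)       # needed by factorized_src_nodes
model.set_use_attn_in(True)           # needed by factorized_src_nodes
model.set_use_split_qkv_input(True)   # needed when separate_qkv=True
model.set_use_hook_mlp_in(True)       # MLP input before layernorm

sae_wrapped = sae_model(
    model,
    sae_input="resid_delta_mlp",   # wrap the MLP output into the residual stream
    load_pretrained=False,         # train-from-scratch path; no pythia_size needed
    n_latents=4096,
)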