Tracr model utils

auto_circuit.model_utils.tracr_model_utils

Attributes

TRACR_TASK_KEY module-attribute

TRACR_TASK_KEY = Literal['reverse', 'xproportion']

Identifier of a Tracr model. Currently supported models:

  • "reverse"
  • "xproportion"

Functions

get_tracr_model

get_tracr_model(tracr_task_key: TRACR_TASK_KEY, device: str) -> Tuple[HookedTransformer, AssembledTransformerModel]

Load the weights of a Tracr model and convert it to a HookedTransformer model.

Adapted from Neel Nanda's TransformerLens port of tracr.

Parameters:

  • tracr_task_key (TRACR_TASK_KEY): Identifier of the Tracr model. Required.
  • device (str): Device to load the model on. Required.

Returns:

  • Tuple[HookedTransformer, AssembledTransformerModel]: A tuple of the HookedTransformer model and the original AssembledTransformerModel.
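
For example, loading the reverse model on the CPU (a minimal sketch; device accepts any torch device string):

from auto_circuit.model_utils.tracr_model_utils import get_tracr_model

tl_model, tracr_model = get_tracr_model("reverse", device="cpu")
# tl_model is a TransformerLens HookedTransformer with the Tracr weights loaded;
# tracr_model is the original compiled AssembledTransformerModel.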

Source code in auto_circuit/model_utils/tracr_model_utils.py
def get_tracr_model(
    tracr_task_key: TRACR_TASK_KEY, device: str
) -> Tuple[HookedTransformer, AssembledTransformerModel]:
    """
    Load the weights of a Tracr model and convert it to a HookedTransformer model.

    Adapted from Neel Nanda's TransformerLens port of tracr.

    Args:
        tracr_task_key: Identifier of the Tracr model.
        device: Device to load the model on.

    Returns:
        A tuple of the HookedTransformer model and the original
            AssembledTransformerModel.
    """

    def make_length():
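        # Select every (query, key) pair; SelectorWidth then counts the matches
        # at each position, which always equals the sequence length.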
        all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
        return rasp.SelectorWidth(all_true_selector)

    if tracr_task_key == "reverse":
        length = make_length()  # `length` is not a primitive in our implementation.
        opp_index = length - rasp.indices - 1
        flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
        reverse = rasp.Aggregate(flip, rasp.tokens)
        model = compiling.compile_rasp_to_model(
            reverse,
            vocab=REVERSE_VOCAB,
            max_seq_len=MAX_SEQ_LEN,
            compiler_bos=BOS,
        )
    elif tracr_task_key == "xproportion":
        model = compiling.compile_rasp_to_model(
            make_frac_prevs(rasp.tokens == "x"),
            vocab=XPROPORTION_VOCAB,
            max_seq_len=MAX_SEQ_LEN,
            compiler_bos=BOS,
        )
    else:
        raise ValueError(f"Unknown task {tracr_task_key}")

    # Extract the model config from the Tracr model, and create a blank
    # HookedTransformer object

    n_heads = model.model_config.num_heads
    n_layers = model.model_config.num_layers
    d_head = model.model_config.key_size
    d_mlp = model.model_config.mlp_hidden_size
    act_fn = "relu"
    normalization_type = "LN" if model.model_config.layer_norm else None
    attention_type = "causal" if model.model_config.causal else "bidirectional"

    n_ctx = model.params["pos_embed"]["embeddings"].shape[0]
    # Equivalent to length of vocab, with BOS and PAD at the end
    d_vocab = model.params["token_embed"]["embeddings"].shape[0]
    # Residual stream width. There is no easy way to infer it from the config
    # above, so read it off the embedding matrix instead.
    d_model = model.params["token_embed"]["embeddings"].shape[1]

    if tracr_task_key == "reverse":
        # Equivalent to length of vocab, WITHOUT BOS and PAD at the end because we never
        # care about these outputs
        d_vocab_out = model.params["token_embed"]["embeddings"].shape[0] - 2
    elif tracr_task_key == "xproportion":
        # This task outputs a real number, so we only need the first residual dimension
        d_vocab_out = 1

    cfg = HookedTransformerConfig(
        model_name=f"tracr-{tracr_task_key}",
        n_layers=n_layers,
        d_model=d_model,
        d_head=d_head,
        n_ctx=n_ctx,
        d_vocab=d_vocab,
        d_vocab_out=d_vocab_out,
        d_mlp=d_mlp,
        n_heads=n_heads,
        act_fn=act_fn,
        attention_dir=attention_type,
        normalization_type=normalization_type,
        use_attn_result=True,
        use_split_qkv_input=True,
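        # The two flags above expose per-head result and per-head q/k/v input
        # hook points, so individual heads can be patched separately.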
        device=device,
    )
    tl_model = HookedTransformer(cfg)
    if "use_hook_mlp_in" in tl_model.cfg.to_dict():
        tl_model.set_use_hook_mlp_in(True)

    # Extract the state dict, and do some reshaping so that everything has a n_heads
    # dimension
    sd = {}
    sd["pos_embed.W_pos"] = model.params["pos_embed"]["embeddings"]
    sd["embed.W_E"] = model.params["token_embed"]["embeddings"]
    # Equivalent to max_seq_len plus one, for the BOS

    # The unembed is just a projection onto the first few elements of the
    # residual stream; these dimensions store the output tokens.
    # This one is a NumPy array while the rest are JAX arrays; both are
    # converted to torch tensors below.
    sd["unembed.W_U"] = np.eye(d_model, d_vocab_out)

    for lyr in range(n_layers):
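        # Tracr flattens all heads into a single weight matrix; split the
        # (n_heads * d_head) axis so TransformerLens sees per-head parameters.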
        sd[f"blocks.{lyr}.attn.W_K"] = einops.rearrange(
            model.params[f"transformer/layer_{lyr}/attn/key"]["w"],
            "d_model (n_heads d_head) -> n_heads d_model d_head",
            d_head=d_head,
            n_heads=n_heads,
        )
        sd[f"blocks.{lyr}.attn.b_K"] = einops.rearrange(
            model.params[f"transformer/layer_{lyr}/attn/key"]["b"],
            "(n_heads d_head) -> n_heads d_head",
            d_head=d_head,
            n_heads=n_heads,
        )
        sd[f"blocks.{lyr}.attn.W_Q"] = einops.rearrange(
            model.params[f"transformer/layer_{lyr}/attn/query"]["w"],
            "d_model (n_heads d_head) -> n_heads d_model d_head",
            d_head=d_head,
            n_heads=n_heads,
        )
        sd[f"blocks.{lyr}.attn.b_Q"] = einops.rearrange(
            model.params[f"transformer/layer_{lyr}/attn/query"]["b"],
            "(n_heads d_head) -> n_heads d_head",
            d_head=d_head,
            n_heads=n_heads,
        )
        sd[f"blocks.{lyr}.attn.W_V"] = einops.rearrange(
            model.params[f"transformer/layer_{lyr}/attn/value"]["w"],
            "d_model (n_heads d_head) -> n_heads d_model d_head",
            d_head=d_head,
            n_heads=n_heads,
        )
        sd[f"blocks.{lyr}.attn.b_V"] = einops.rearrange(
            model.params[f"transformer/layer_{lyr}/attn/value"]["b"],
            "(n_heads d_head) -> n_heads d_head",
            d_head=d_head,
            n_heads=n_heads,
        )
        sd[f"blocks.{lyr}.attn.W_O"] = einops.rearrange(
            model.params[f"transformer/layer_{lyr}/attn/linear"]["w"],
            "(n_heads d_head) d_model -> n_heads d_head d_model",
            d_head=d_head,
            n_heads=n_heads,
        )
        sd[f"blocks.{lyr}.attn.b_O"] = model.params[
            f"transformer/layer_{lyr}/attn/linear"
        ]["b"]

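        # The MLP weights already match TransformerLens's expected layout, so
        # they are copied over without reshaping.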
        sd[f"blocks.{lyr}.mlp.W_in"] = model.params[
            f"transformer/layer_{lyr}/mlp/linear_1"
        ]["w"]
        sd[f"blocks.{lyr}.mlp.b_in"] = model.params[
            f"transformer/layer_{lyr}/mlp/linear_1"
        ]["b"]
        sd[f"blocks.{lyr}.mlp.W_out"] = model.params[
            f"transformer/layer_{lyr}/mlp/linear_2"
        ]["w"]
        sd[f"blocks.{lyr}.mlp.b_out"] = model.params[
            f"transformer/layer_{lyr}/mlp/linear_2"
        ]["b"]

    # Convert weights to tensors and load into the tl_model

    for k, v in sd.items():
        # np.array() materializes a JAX array as a NumPy array, which
        # torch.tensor can then convert.
        sd[k] = torch.tensor(np.array(v))

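    # strict=False because sd omits some parameters that HookedTransformer
    # defines (e.g. the unembed bias), which keep their default initialization.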
    tl_model.load_state_dict(sd, strict=False)
    return tl_model, model