Mistral-7b-0.2 / README.md
---
library_name: transformers
tags: []
---

This is a conversion of the model shared during the hackathon to the transformers weight naming. The conversion has not been verified, so it may not be correct.

Conversion map:

```python
# Top-level weights: original checkpoint name -> transformers name.
state_dict_mapping = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "output.weight": "lm_head.weight",
}


def map_layer(i):
    """Return the name mapping for the weights of transformer layer `i`."""
    return {
        f"layers.{i}.attention.wq.weight": f"model.layers.{i}.self_attn.q_proj.weight",
        f"layers.{i}.attention.wk.weight": f"model.layers.{i}.self_attn.k_proj.weight",
        f"layers.{i}.attention.wv.weight": f"model.layers.{i}.self_attn.v_proj.weight",
        f"layers.{i}.attention.wo.weight": f"model.layers.{i}.self_attn.o_proj.weight",
        f"layers.{i}.feed_forward.w1.weight": f"model.layers.{i}.mlp.gate_proj.weight",
        f"layers.{i}.feed_forward.w2.weight": f"model.layers.{i}.mlp.down_proj.weight",
        f"layers.{i}.feed_forward.w3.weight": f"model.layers.{i}.mlp.up_proj.weight",
        f"layers.{i}.attention_norm.weight": f"model.layers.{i}.input_layernorm.weight",
        f"layers.{i}.ffn_norm.weight": f"model.layers.{i}.post_attention_layernorm.weight",
    }


# Add the per-layer entries for all 32 transformer layers.
for i in range(32):
    state_dict_mapping.update(map_layer(i))
```
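
The map above only lists the key renames. The step that applies it to a checkpoint is not shown here. The following is a minimal sketch of how such a map could be applied, not the script actually used for this conversion. The helper name `rename_state_dict` and the toy data are hypothetical, and plain lists stand in for weight tensors so the example runs without any dependencies.

```python
def rename_state_dict(state_dict, mapping):
    """Return a new dict whose keys are renamed according to `mapping`.

    Hypothetical helper: raises KeyError if a checkpoint key has no entry
    in the mapping, so unmapped weights are not silently dropped.
    """
    missing = [key for key in state_dict if key not in mapping]
    if missing:
        raise KeyError(f"No mapping for keys: {missing}")
    return {mapping[key]: value for key, value in state_dict.items()}


# Toy stand-ins: a real checkpoint would hold tensors, not lists.
toy_mapping = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "layers.0.attention.wq.weight": "model.layers.0.self_attn.q_proj.weight",
}
toy_checkpoint = {
    "tok_embeddings.weight": [0.1, 0.2],
    "layers.0.attention.wq.weight": [0.3, 0.4],
}

converted = rename_state_dict(toy_checkpoint, toy_mapping)
print(sorted(converted))
```

With a real checkpoint, the same call would take the loaded original state dict and `state_dict_mapping` in place of the toy inputs. Failing on unmapped keys is a deliberate choice in this sketch, since the conversion is unverified and a missing weight is easier to catch at rename time than at load time.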