---
license: apache-2.0
language:
- ko
tags:
- rwkv
- KoRWKV
---

# KoRWKV

[RWKV-Runner](https://github.com/josStorer/RWKV-Runner)에서 사용하기 위해 변환한 모델 파일

- [beomi/KoAlpaca-KoRWKV-6B](https://huggingface.co/beomi/KoAlpaca-KoRWKV-6B)
- [beomi/KoRWKV-6B](https://huggingface.co/beomi/KoRWKV-6B)

```py
import re

import torch
from transformers import RwkvForCausalLM


def convert_state_dict(state_dict):
    """Rename original RWKV checkpoint keys to Hugging Face ``RwkvForCausalLM`` keys.

    Only the keys change; every tensor value is carried over untouched (no
    reshaping). The given dict is modified in place -- each key is popped and
    re-inserted under its new name -- and the same dict is returned.
    """
    state_dict_keys = list(state_dict.keys())
    for name in state_dict_keys:
        weight = state_dict.pop(name)
        # emb -> embeddings
        if name.startswith("emb."):
            name = name.replace("emb.", "embeddings.")
        # ln0 -> pre_ln (only present at block 0)
        if name.startswith("blocks.0.ln0"):
            name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
        # att -> attention
        name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
        # ffn -> feed_forward
        name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
        # time_mix_k -> time_mix_key (rename only)
        if name.endswith(".time_mix_k"):
            name = name.replace(".time_mix_k", ".time_mix_key")
        # time_mix_v -> time_mix_value (rename only)
        if name.endswith(".time_mix_v"):
            name = name.replace(".time_mix_v", ".time_mix_value")
        # time_mix_r -> time_mix_receptance (rename only)
        if name.endswith(".time_mix_r"):
            name = name.replace(".time_mix_r", ".time_mix_receptance")
        # Every key except the LM head gets the "rwkv." prefix.
        if name != "head.weight":
            name = "rwkv." + name
        state_dict[name] = weight
    return state_dict


def revert_state_dict(state_dict):
    """Rename Hugging Face ``RwkvForCausalLM`` keys back to original RWKV keys.

    Undoes the renames performed by ``convert_state_dict`` so the result can be
    saved as an original-format RWKV checkpoint. Tensor values are not touched.
    The given dict is modified in place -- each key is popped and re-inserted
    under its new name -- and the same dict is returned.
    """
    state_dict_keys = list(state_dict.keys())
    for name in state_dict_keys:
        weight = state_dict.pop(name)
        # Strip the "rwkv." prefix; "head.weight" never had one and is kept as-is.
        name = name.removeprefix("rwkv.")
        # embeddings -> emb
        if name.startswith("embeddings."):
            name = name.replace("embeddings.", "emb.")
        # pre_ln -> ln0 (only present at block 0)
        if name.startswith("blocks.0.pre_ln"):
            name = name.replace("blocks.0.pre_ln", "blocks.0.ln0")
        # attention -> att
        name = re.sub(r"blocks\.(\d+)\.attention", r"blocks.\1.att", name)
        # feed_forward -> ffn
        name = re.sub(r"blocks\.(\d+)\.feed_forward", r"blocks.\1.ffn", name)
        # time_mix_key -> time_mix_k
        if name.endswith(".time_mix_key"):
            name = name.replace(".time_mix_key", ".time_mix_k")
        # time_mix_value -> time_mix_v
        if name.endswith(".time_mix_value"):
            name = name.replace(".time_mix_value", ".time_mix_v")
        # time_mix_receptance -> time_mix_r
        if name.endswith(".time_mix_receptance"):
            name = name.replace(".time_mix_receptance", ".time_mix_r")
        state_dict[name] = weight
    return state_dict


if __name__ == "__main__":
    # Set repo to "beomi/KoRWKV-6B" to convert the base model instead.
    repo = "beomi/KoAlpaca-KoRWKV-6B"
    model = RwkvForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16)
    state_dict = model.state_dict()
    converted = revert_state_dict(state_dict)
    # e.g. "KoAlpaca-KoRWKV-6B.bf16.pth" in the current directory
    output_path = repo.split("/")[-1] + ".bf16.pth"
    torch.save(converted, output_path)
```