Bingsu commited on
Commit
7134ebe
โ€ข
1 Parent(s): 0e3f536

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +95 -0
README.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - ko
5
+ tags:
6
+ - rwkv
7
+ - KoRWKV
8
+ ---
9
+
10
+ # KoRWKV
11
+
12
+ [RWKV-Runner](https://github.com/josStorer/RWKV-Runner)์—์„œ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด ๋ณ€ํ™˜ํ•œ ๋ชจ๋ธ ํŒŒ์ผ
13
+
14
+ - [beomi/KoAlpaca-KoRWKV-6B](https://huggingface.co/beomi/KoAlpaca-KoRWKV-6B)
15
+ - [beomi/KoRWKV-6B](https://huggingface.co/beomi/KoRWKV-6B)
16
+
17
+ ```py
18
+ import re
19
+
20
+ import torch
21
+
22
+ from transformers import RwkvForCausalLM
23
+
24
+ def convert_state_dict(state_dict):
25
+ state_dict_keys = list(state_dict.keys())
26
+ for name in state_dict_keys:
27
+ weight = state_dict.pop(name)
28
+ # emb -> embedding
29
+ if name.startswith("emb."):
30
+ name = name.replace("emb.", "embeddings.")
31
+ # ln_0 -> pre_ln (only present at block 0)
32
+ if name.startswith("blocks.0.ln0"):
33
+ name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
34
+ # att -> attention
35
+ name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
36
+ # ffn -> feed_forward
37
+ name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
38
+ # time_mix_k -> time_mix_key and reshape
39
+ if name.endswith(".time_mix_k"):
40
+ name = name.replace(".time_mix_k", ".time_mix_key")
41
+ # time_mix_v -> time_mix_value and reshape
42
+ if name.endswith(".time_mix_v"):
43
+ name = name.replace(".time_mix_v", ".time_mix_value")
44
+ # time_mix_r -> time_mix_key and reshape
45
+ if name.endswith(".time_mix_r"):
46
+ name = name.replace(".time_mix_r", ".time_mix_receptance")
47
+
48
+ if name != "head.weight":
49
+ name = "rwkv." + name
50
+
51
+ state_dict[name] = weight
52
+ return state_dict
53
+
54
+
55
+ def revert_state_dict(state_dict):
56
+ state_dict_keys = list(state_dict.keys())
57
+ for name in state_dict_keys:
58
+ weight = state_dict.pop(name)
59
+ name = name.removeprefix("rwkv.")
60
+
61
+ # emb -> embedding
62
+ if name.startswith("embeddings."):
63
+ name = name.replace("embeddings.", "emb.")
64
+ # ln_0 -> pre_ln (only present at block 0)
65
+ if name.startswith("blocks.0.pre_ln"):
66
+ name = name.replace("blocks.0.pre_ln", "blocks.0.ln0")
67
+ # att -> attention
68
+ name = re.sub(r"blocks\.(\d+)\.attention", r"blocks.\1.att", name)
69
+ # ffn -> feed_forward
70
+ name = re.sub(r"blocks\.(\d+)\.feed_forward", r"blocks.\1.ffn", name)
71
+ # time_mix_k -> time_mix_key and reshape
72
+ if name.endswith(".time_mix_key"):
73
+ name = name.replace(".time_mix_key", ".time_mix_k")
74
+ # time_mix_v -> time_mix_value and reshape
75
+ if name.endswith(".time_mix_value"):
76
+ name = name.replace(".time_mix_value", ".time_mix_v")
77
+ # time_mix_r -> time_mix_key and reshape
78
+ if name.endswith(".time_mix_receptance"):
79
+ name = name.replace(".time_mix_receptance", ".time_mix_r")
80
+
81
+ state_dict[name] = weight
82
+ return state_dict
83
+
84
+
85
+ if __name__ == "__main__":
86
+ # repo = "beomi/KoRWKV-6B"
87
+ repo = "beomi/KoAlpaca-KoRWKV-6B"
88
+ model = RwkvForCausalLM.from_pretrained(repo, torch_dtype=torch.bfloat16)
89
+
90
+ state_dict = model.state_dict()
91
+ converted = revert_state_dict(state_dict)
92
+ name = repo.split("/")[-1] + ".bf16.pth"
93
+
94
+ torch.save(converted, name)
95
+ ```