StarRing2022 commited on
Commit
c2117ba
1 Parent(s): 5806b76

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.json +16 -0
  2. hello_hf.py +51 -0
  3. ringrwkv.rar +3 -0
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_hidden_size": 2048,
3
+ "bos_token_id": 0,
4
+ "context_length": 1024,
5
+ "eos_token_id": 0,
6
+ "hidden_size": 2048,
7
+ "intermediate_size": 8192,
8
+ "layer_norm_epsilon": 1e-05,
9
+ "model_type": "rwkv",
10
+ "num_hidden_layers": 24,
11
+ "rescale_every": 6,
12
+ "tie_word_embeddings": false,
13
+ "transformers_version": "4.29.0",
14
+ "use_cache": true,
15
+ "vocab_size": 65536
16
+ }
hello_hf.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ringrwkv.configuration_rwkv_world import RwkvConfig
3
+ from ringrwkv.rwkv_tokenizer import TRIE_TOKENIZER
4
+ from ringrwkv.modehf_world import RwkvForCausalLM
5
+
6
+
7
+ model = RwkvForCausalLM.from_pretrained("RWKV-4-World-1.5B")
8
+ tokenizer = TRIE_TOKENIZER('./ringrwkv/rwkv_vocab_v20230424.txt')
9
+
10
+ text = "你叫什么名字?"
11
+
12
+ question = f'Question: {text.strip()}\n\nAnswer:'
13
+
14
+ input_ids = tokenizer.encode(question)
15
+ #print(tokenizer.decode(input_ids))
16
+ input_ids = torch.tensor(input_ids).unsqueeze(0)
17
+
18
+ out = model.generate(input_ids,max_new_tokens=40)
19
+
20
+ #print(out[0])
21
+
22
+ outlist = out[0].tolist()
23
+
24
+ for i in outlist:
25
+ if i==0:
26
+ outlist.remove(i)
27
+
28
+ #print(outlist)
29
+ answer = tokenizer.decode(outlist)
30
+
31
+ # answer = tokenizer.decode([10464, 11685, 19126, 12605, 11021, 10399, 12176, 10464, 16533, 10722,
32
+ # 10250, 10349, 17728, 18025, 10080, 16738, 17728, 10464, 17879, 16503])
33
+ # answer = tokenizer.decode([53648, 59, 33, 10464, 11017, 10373, 10303, 11043, 11860, 19156,
34
+ # 261, 40301, 59, 33, 12605, 13091, 10250, 10283, 10370, 12137,
35
+ # 13133, 15752, 16728, 16537, 13499, 11496, 19137, 13734, 13191, 11043,
36
+ # 11860, 10080])
37
+ print(answer)
38
+
39
+ #print(input_ids.shape)
40
+ #rwkvoutput = model.forward(input_ids=input_ids,labels=input_ids) #loss,logits,state,hidden_states,attentions
41
+ # print("loss:")
42
+ # print(rwkvoutput.loss)
43
+ # print("logits:")
44
+ # print(rwkvoutput.logits)
45
+ # print("state:")
46
+ # print(rwkvoutput.state)
47
+ #print("last_hidden_state:")
48
+ # print(rwkvoutput.last_hidden_state)
49
+ # print("attentions:")
50
+ # print(rwkvoutput.attentions)
51
+
ringrwkv.rar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d0f6c7c7e365ad46b5969eeec655ce733514ca06e1485c6453766156f456032
3
+ size 261848