StarRing2022 commited on
Commit
2927278
1 Parent(s): 60fe38d

Upload 3 files

Browse files
Files changed (3) hide show
  1. generate_hf.py +72 -0
  2. hello_hf.py +51 -0
  3. ringrwkv.rar +3 -0
generate_hf.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from peft import PeftModel
3
+ import transformers
4
+ import gradio as gr
5
+
6
+ from ringrwkv.configuration_rwkv_world import RwkvConfig
7
+ from ringrwkv.rwkv_tokenizer import TRIE_TOKENIZER
8
+ from ringrwkv.modehf_world import RwkvForCausalLM
9
+
10
+
11
+ if torch.cuda.is_available():
12
+ device = "cuda"
13
+ else:
14
+ device = "cpu"
15
+
16
+ #放在本地工程根目录文件夹
17
+
18
+
19
+ model = RwkvForCausalLM.from_pretrained("RWKV-4-World-7B")
20
+ tokenizer = TRIE_TOKENIZER('./ringrwkv/rwkv_vocab_v20230424.txt')
21
+
22
+
23
+ #model= PeftModel.from_pretrained(model, "./lora-out")
24
+ model = model.to(device)
25
+
26
+
27
+ def evaluate(
28
+ instruction,
29
+ temperature=1,
30
+ top_p=0.7,
31
+ top_k = 0.1,
32
+ penalty_alpha = 0.1,
33
+ max_new_tokens=128,
34
+ ):
35
+
36
+ prompt = f'Question: {instruction.strip()}\n\nAnswer:'
37
+ input_ids = tokenizer.encode(prompt)
38
+ input_ids = torch.tensor(input_ids).unsqueeze(0)
39
+ #out = model.generate(input_ids=input_ids.to(device),max_new_tokens=40)
40
+ out = model.generate(input_ids=input_ids.to(device),temperature=temperature,top_p=top_p,top_k=top_k,penalty_alpha=penalty_alpha,max_new_tokens=max_new_tokens)
41
+ outlist = out[0].tolist()
42
+ for i in outlist:
43
+ if i==0:
44
+ outlist.remove(i)
45
+ answer = tokenizer.decode(outlist)
46
+ return answer.strip()
47
+ #return answer.split("### Response:")[1].strip()
48
+
49
+
50
+ gr.Interface(
51
+ fn=evaluate,#接口函数
52
+ inputs=[
53
+ gr.components.Textbox(
54
+ lines=2, label="Instruction", placeholder="Tell me about alpacas."
55
+ ),
56
+ gr.components.Slider(minimum=0, maximum=2, value=1, label="Temperature"),
57
+ gr.components.Slider(minimum=0, maximum=1, value=0.7, label="Top p"),
58
+ gr.components.Slider(minimum=0, maximum=1, step=1, value=0.1, label="top_k"),
59
+ gr.components.Slider(minimum=0, maximum=1, step=1, value=0.1, label="penalty_alpha"),
60
+ gr.components.Slider(
61
+ minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
62
+ ),
63
+ ],
64
+ outputs=[
65
+ gr.inputs.Textbox(
66
+ lines=5,
67
+ label="Output",
68
+ )
69
+ ],
70
+ title="RWKV-World-Alpaca",
71
+ description="RWKV,Easy In HF.",
72
+ ).launch()
hello_hf.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from ringrwkv.configuration_rwkv_world import RwkvConfig
3
+ from ringrwkv.rwkv_tokenizer import TRIE_TOKENIZER
4
+ from ringrwkv.modehf_world import RwkvForCausalLM
5
+
6
+
7
+ model = RwkvForCausalLM.from_pretrained("RWKV-4-World-7B")
8
+ tokenizer = TRIE_TOKENIZER('./ringrwkv/rwkv_vocab_v20230424.txt')
9
+
10
+ text = "你叫什么名字?"
11
+
12
+ question = f'Question: {text.strip()}\n\nAnswer:'
13
+
14
+ input_ids = tokenizer.encode(question)
15
+ #print(tokenizer.decode(input_ids))
16
+ input_ids = torch.tensor(input_ids).unsqueeze(0)
17
+
18
+ out = model.generate(input_ids,max_new_tokens=40)
19
+
20
+ #print(out[0])
21
+
22
+ outlist = out[0].tolist()
23
+
24
+ for i in outlist:
25
+ if i==0:
26
+ outlist.remove(i)
27
+
28
+ #print(outlist)
29
+ answer = tokenizer.decode(outlist)
30
+
31
+ # answer = tokenizer.decode([10464, 11685, 19126, 12605, 11021, 10399, 12176, 10464, 16533, 10722,
32
+ # 10250, 10349, 17728, 18025, 10080, 16738, 17728, 10464, 17879, 16503])
33
+ # answer = tokenizer.decode([53648, 59, 33, 10464, 11017, 10373, 10303, 11043, 11860, 19156,
34
+ # 261, 40301, 59, 33, 12605, 13091, 10250, 10283, 10370, 12137,
35
+ # 13133, 15752, 16728, 16537, 13499, 11496, 19137, 13734, 13191, 11043,
36
+ # 11860, 10080])
37
+ print(answer)
38
+
39
+ #print(input_ids.shape)
40
+ #rwkvoutput = model.forward(input_ids=input_ids,labels=input_ids) #loss,logits,state,hidden_states,attentions
41
+ # print("loss:")
42
+ # print(rwkvoutput.loss)
43
+ # print("logits:")
44
+ # print(rwkvoutput.logits)
45
+ # print("state:")
46
+ # print(rwkvoutput.state)
47
+ #print("last_hidden_state:")
48
+ # print(rwkvoutput.last_hidden_state)
49
+ # print("attentions:")
50
+ # print(rwkvoutput.attentions)
51
+
ringrwkv.rar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d0f6c7c7e365ad46b5969eeec655ce733514ca06e1485c6453766156f456032
3
+ size 261848