FredZhang7 committed
Commit f41c6d3
1 Parent(s): bbffb85

Update app.py

Files changed (1)
  1. app.py +60 -13
app.py CHANGED
@@ -11,6 +11,7 @@ model_path = hf_hub_download(repo_id="BlinkDL/rwkv-5-world", filename=f"{title}.
 model = RWKV(model=model_path, strategy='cpu bf16')
 pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
 
+
 def generate_prompt(instruction, input=None, history=None):
     # parse the chat history into a string of user and assistant messages
     history_str = ""
@@ -32,6 +33,7 @@ Response:"""
 
 Assistant:"""
 
+
 examples = [
     ["東京で訪れるべき素晴らしい場所とその紹介をいくつか挙げてください。", "", 300, 1.2, 0.5, 0.5, 0.5],
     ["Écrivez un programme Python pour miner 1 Bitcoin, avec des commentaires.", "", 300, 1.2, 0.5, 0.5, 0.5],
@@ -42,7 +44,7 @@ examples = [
     ["You have $100, and your goal is to turn that into as much money as possible with AI and Machine Learning. Please respond with detailed plan.", "", 300, 1.2, 0.5, 0.5, 0.5],
 ]
 
-def evaluate(
+def respond(
     instruction,
     input=None,
     token_count=333,
@@ -59,9 +61,6 @@ def evaluate(
                          token_stop = [0]) # stop generation whenever you see any token here
 
     instruction = re.sub(r'\n{2,}', '\n', instruction).strip().replace('\r\n','\n')
-    no_history = (history is None)
-    if no_history:
-        input = re.sub(r'\n{2,}', '\n', input).strip().replace('\r\n','\n')
     ctx = generate_prompt(instruction, input, history)
     print(ctx + "\n")
 
@@ -89,8 +88,6 @@ def evaluate(
         tmp = pipeline.decode(all_tokens[out_last:])
         if '\ufffd' not in tmp:
             out_str += tmp
-            if no_history:
-                yield out_str.strip()
             out_last = i + 1
         if '\n\n' in out_str:
             break
@@ -98,11 +95,61 @@ def evaluate(
     del out
     del state
     gc.collect()
-    if no_history:
-        yield out_str.strip()
-    else:
-        history.append((instruction, out_str.strip()))
-        return history
+    return out_str.strip()
+
+def generator(
+    instruction,
+    input=None,
+    token_count=333,
+    temperature=1.0,
+    top_p=0.5,
+    presencePenalty = 0.5,
+    countPenalty = 0.5
+):
+    args = PIPELINE_ARGS(temperature = max(0.2, float(temperature)), top_p = float(top_p),
+                         alpha_frequency = countPenalty,
+                         alpha_presence = presencePenalty,
+                         token_ban = [], # ban the generation of some tokens
+                         token_stop = [0]) # stop generation whenever you see any token here
+
+    instruction = re.sub(r'\n{2,}', '\n', instruction).strip().replace('\r\n','\n')
+    input = re.sub(r'\n{2,}', '\n', input).strip().replace('\r\n','\n')
+    ctx = generate_prompt(instruction, input)
+    print(ctx + "\n")
+
+    all_tokens = []
+    out_last = 0
+    out_str = ''
+    occurrence = {}
+    state = None
+    for i in range(int(token_count)):
+        out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
+        for n in occurrence:
+            out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
+
+        token = pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
+        if token in args.token_stop:
+            break
+        all_tokens += [token]
+        for xxx in occurrence:
+            occurrence[xxx] *= 0.996
+        if token not in occurrence:
+            occurrence[token] = 1
+        else:
+            occurrence[token] += 1
+
+        tmp = pipeline.decode(all_tokens[out_last:])
+        if '\ufffd' not in tmp:
+            out_str += tmp
+            yield out_str.strip()
+            out_last = i + 1
+        if '\n\n' in out_str:
+            break
+
+    del out
+    del state
+    gc.collect()
+    yield out_str.strip()
 
 def user(message, chatbot):
     chatbot = chatbot or []
@@ -153,7 +200,7 @@ with gr.Blocks(title=title) as demo:
         presence_penalty = presence_penalty_chat.value
         count_penalty = count_penalty_chat.value
 
-        response = evaluate(instruction, None, token_count, temperature, top_p, presence_penalty, count_penalty, history)
+        response = respond(instruction, None, token_count, temperature, top_p, presence_penalty, count_penalty, history)
 
         history[-1][1] = response
         return history
@@ -179,7 +226,7 @@ with gr.Blocks(title=title) as demo:
         clear = gr.Button("Clear", variant="secondary")
        output = gr.Textbox(label="Output", lines=5)
        data = gr.Dataset(components=[instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct], samples=examples, label="Example Instructions", headers=["Instruction", "Input", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
-        submit.click(evaluate, [instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct], [output])
+        submit.click(generator, [instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct], [output])
        clear.click(lambda: None, [], [output])
        data.click(lambda x: x, [data], [instruction, input_instruct, token_count_instruct, temperature_instruct, top_p_instruct, presence_penalty_instruct, count_penalty_instruct])
 
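Note: evaluate() previously served both tabs, yielding partial output when called without chat history and appending to the history list otherwise. The commit splits it in two: respond() returns one finished string for the chat tab, while generator() yields partial strings so the instruct tab can stream tokens as they decode. A minimal sketch of the streaming pattern, assuming a recent Gradio version; echo_stream and its wiring are illustrative stand-ins, not part of app.py:

import time
import gradio as gr

def echo_stream(prompt):
    # Gradio treats a generator function as a stream: each yield
    # repaints the bound output component with the partial value,
    # which is how generator() above feeds its Textbox.
    out = ""
    for ch in prompt:
        out += ch
        time.sleep(0.05)  # stand-in for one model.forward step
        yield out

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")
    gr.Button("Submit").click(echo_stream, [prompt], [output])

demo.launch()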
 
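Both respond() and generator() keep the same decayed repetition penalty: every token already generated loses alpha_presence + count * alpha_frequency from its logit, and all counts shrink by a factor of 0.996 each step so old repetitions fade. A standalone sketch of that bookkeeping with illustrative values; penalize and record are hypothetical helper names:

alpha_presence, alpha_frequency, decay = 0.5, 0.5, 0.996

def penalize(logits, occurrence):
    # flat penalty for having appeared at all, plus a per-count term
    for tok, count in occurrence.items():
        logits[tok] -= alpha_presence + count * alpha_frequency
    return logits

def record(token, occurrence):
    # geometric decay: a repeat long ago costs less than a recent one
    for tok in occurrence:
        occurrence[tok] *= decay
    occurrence[token] = occurrence.get(token, 0) + 1

# toy run: repeating token 7 drives its logit down step by step
logits = {7: 1.0, 8: 0.9}
occ = {}
record(7, occ)
print(penalize(dict(logits), occ))  # {7: 0.0, 8: 0.9}
record(7, occ)                      # count: 1 * 0.996 + 1 = 1.996
print(penalize(dict(logits), occ))  # 7: 1.0 - (0.5 + 1.996 * 0.5) = -0.498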
 
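The '\ufffd' not in tmp guard in both functions exists because a sampled token can stop partway through a multi-byte UTF-8 character; decoding the incomplete tail yields the U+FFFD replacement character, so the loop withholds output until the bytes decode cleanly. An illustrative snippet:

text = "東京".encode("utf-8")  # 6 bytes, 3 per character
print(text[:4].decode("utf-8", errors="replace"))  # 東� -- second character cut short
print(text.decode("utf-8"))                        # 東京 -- complete bytes decode cleanly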