aka7774 committed on
Commit 0be9b92
1 Parent(s): 7b90ef7

Upload 2 files

Files changed (2)
  1. fn.py +35 -37
  2. main.py +1 -1
fn.py CHANGED
@@ -156,6 +156,19 @@ def chatinterface_to_messages(message, history):
 
     return messages
 
+def apply_template(messages):
+    global tokenizer, cfg
+
+    if cfg['chat_template']:
+        tokenizer.chat_template = cfg['chat_template']
+
+    if type(messages) is str:
+        if cfg['inst_template']:
+            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
+        return cfg['instruction'].format(input=messages)
+    if type(messages) is list:
+        return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
+
 def chat(message, history = [], instruction = None, args = {}):
     global tokenizer, model, cfg
 
@@ -168,20 +181,17 @@ def chat(message, history = [], instruction = None, args = {}):
 
     model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
 
-    if 'fastapi' not in args or 'stream' in args and args['stream']:
-        streamer = TextIteratorStreamer(
-            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
-        )
+    streamer = TextIteratorStreamer(
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True,
+    )
 
     generate_kwargs = dict(
         model_inputs,
         do_sample=True,
+        streamer=streamer,
+        num_beams=1,
     )
 
-    if 'fastapi' not in args or 'stream' in args and args['stream']:
-        generate_kwargs['streamer'] = streamer
-        generate_kwargs['num_beams'] = 1
-
     for k in [
         'max_new_tokens',
         'temperature',
@@ -192,33 +202,21 @@ def chat(message, history = [], instruction = None, args = {}):
         if cfg[k]:
             generate_kwargs[k] = cfg[k]
 
-    if 'fastapi' not in args or 'stream' in args and args['stream']:
-        t = Thread(target=model.generate, kwargs=generate_kwargs)
-        t.start()
-
-        model_output = ""
-        for new_text in streamer:
-            model_output += new_text
-            if 'fastapi' in args:
-                # fastapi should receive only the newly generated delta
-                yield new_text
-            else:
-                # gradio should always receive the full text so far
-                yield model_output
-
-    outputs = model.generate(**generate_kwargs)
-    content = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    model_output = ""
+    for new_text in streamer:
+        model_output += new_text
+        if 'fastapi' in args:
+            # fastapi should receive only the newly generated delta
+            yield new_text
+        else:
+            # gradio should always receive the full text so far
+            yield model_output
+
+def infer(message, history = [], instruction = None, args = {}):
+    content = ''
+    for s in chat(message, history, instruction, args):
+        content += s
    return content
-
-def apply_template(messages):
-    global tokenizer, cfg
-
-    if cfg['chat_template']:
-        tokenizer.chat_template = cfg['chat_template']
-
-    if type(messages) is str:
-        if cfg['inst_template']:
-            return cfg['inst_template'].format(instruction=cfg['instruction'], input=messages)
-        return cfg['instruction'].format(input=messages)
-    if type(messages) is list:
-        return tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)
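
After this change, chat() always generates on a background Thread and yields from a TextIteratorStreamer, which is why num_beams=1 is forced (transformers streamers do not support beam search). Note that even the old version contained yield, so calling fn.chat() always returned a generator, never a string; the new infer() drains that generator for non-streaming callers. A minimal consumption sketch follows; the prompt text and args values are illustrative assumptions, not part of the commit:

# Consumption sketch for the reworked fn.py (assumes tokenizer, model and
# cfg have already been loaded by the repo's setup code).
import fn

# SSE-style caller: with 'fastapi' in args, chat() yields only the text
# generated since the previous iteration.
for delta in fn.chat("Hello", [], None, {'fastapi': True}):
    print(delta, end="", flush=True)

# Gradio-style caller: without 'fastapi', each yield is the cumulative
# output so far, so the UI can redraw the whole message at every step.
for partial in fn.chat("Hello"):
    pass

# Non-streaming caller: infer() concatenates the yields, which equals the
# full generation only when chat() is in delta mode ('fastapi' in args).
content = fn.infer("Hello", [], None, {'fastapi': True})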
main.py CHANGED
@@ -40,5 +40,5 @@ async def api_infer(args: dict):
             media_type="text/event-stream",
         )
     else:
-        content = fn.chat(args['input'], [], args['instruct'], args)
+        content = fn.infer(args['input'], [], args['instruct'], args)
         return {'content': content}
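
For context, a hedged reconstruction of how the whole api_infer endpoint plausibly reads after this commit: only the hunk lines above come from the diff; the route path, the stream-flag check, and the StreamingResponse wiring are assumptions inferred from the visible media_type="text/event-stream" and from how fn.chat() behaves.

# Hypothetical endpoint shape; everything outside the hunk is assumed.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

import fn

app = FastAPI()

@app.post("/infer")  # route path is an assumption
async def api_infer(args: dict):
    if args.get('stream'):
        # chat() is a sync generator; Starlette iterates it in a
        # threadpool, so it can back an SSE response directly.
        return StreamingResponse(
            fn.chat(args['input'], [], args['instruct'], args),
            media_type="text/event-stream",
        )
    else:
        # This commit's change: drain the generator into one string
        # instead of returning the generator object itself.
        content = fn.infer(args['input'], [], args['instruct'], args)
        return {'content': content}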