WuChengyue committed on
Commit
6e3152a
•
1 Parent(s): 99061a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -51,14 +51,10 @@ def convert_history(chat_history, max_input_length=1024):
51
  @torch.inference_mode()
52
  def instruct(instruction, max_token_output=1024):
53
  input_text = instruction
54
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
55
- input_ids = tokenizer(input_text, return_tensors='pt', truncation=False)
56
- input_ids["input_ids"] = input_ids["input_ids"].cuda()
57
- input_ids["attention_mask"] = input_ids["attention_mask"].cuda()
58
- generation_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=max_token_output, do_sample=False)
59
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
60
- thread.start()
61
- return streamer
62
 
63
 
64
  with gr.Blocks() as demo:
 
51
  @torch.inference_mode()
52
  def instruct(instruction, max_token_output=1024):
53
  input_text = instruction
54
+ input_ids = tokenizer.encode(input_text, return_tensors='pt').cuda()
55
+ output_ids = model.generate(input_ids, max_length=max_token_output)[0]
56
+ output_str = tokenizer.decode(output_ids[input_ids.shape[-1]:])
57
+ return output_str.strip()
 
 
 
 
58
 
59
 
60
  with gr.Blocks() as demo: