alfredplpl committed
Commit e2f46ea
Parent: be71825

Update app.py

Files changed (1):
  1. app.py +18 -7
app.py CHANGED
@@ -73,16 +73,27 @@ def chat_llm_jp_v2(message: str,
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
-    output = model.generate(
-        input_ids,
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
-        top_p=0.95,
         temperature=temperature,
-        repetition_penalty=1.05,
-    )[0]
-
-    return tokenizer.decode(output)
+        top_p=0.95,
+        repetition_penalty=1.1,
+    )
+    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+    if temperature == 0:
+        generate_kwargs['do_sample'] = False
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        print(outputs)
+        yield "".join(outputs)
 
 
     # Gradio block
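For readers following along, below is a minimal, self-contained sketch of the streaming pattern this commit switches to: model.generate() blocks until decoding finishes, so it runs on a worker thread while TextIteratorStreamer hands decoded text chunks back to the chat function, which re-yields a growing partial reply. The model name ("gpt2"), the stream_reply function, the prompt handling, and the default parameter values are illustrative placeholders, not taken from this Space's app.py.

# Sketch only: "gpt2" stands in for the Space's actual model and tokenizer.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt, max_new_tokens=64, temperature=0.7):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # skip_prompt drops the echoed input; skip_special_tokens is passed
    # through to the decode step.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.95,
        repetition_penalty=1.1,
    )
    if temperature == 0:
        # transformers rejects temperature=0 when sampling, so fall back
        # to greedy decoding instead of crashing.
        generate_kwargs["do_sample"] = False

    # generate() blocks until done, so run it on a worker thread; the
    # streamer feeds decoded chunks back to this thread as they arrive.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)  # growing partial reply for the UI

# Usage: each iteration prints a longer prefix of the reply.
for partial in stream_reply("Hello, world"):
    print(partial)

Yielding the accumulated string rather than each individual chunk matches how Gradio's ChatInterface renders streams: it replaces the displayed message with the latest yielded value on every step.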