alfredplpl committed
Commit • e2f46ea
1 Parent(s): be71825
Update app.py
app.py CHANGED
@@ -73,16 +73,27 @@ def chat_llm_jp_v2(message: str,
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
-
-        input_ids,
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
-        top_p=0.95,
         temperature=temperature,
-
-
-
-
+        top_p=0.95,
+        repetition_penalty=1.1,
+    )
+    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+    if temperature == 0:
+        generate_kwargs['do_sample'] = False
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        print(outputs)
+        yield "".join(outputs)
 
 
 # Gradio block
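For context, the commit replaces a direct, blocking call to model.generate with the standard transformers streaming pattern: generation runs on a background thread while a TextIteratorStreamer yields decoded text pieces, so the chat handler can stream partial responses to the UI. The sketch below reconstructs that pattern as a self-contained function; the model id, the function name stream_chat, the prompt handling, and the default parameter values are placeholders for illustration, not taken from this Space.

# A minimal sketch of the streaming pattern this commit adopts. Only the
# generate_kwargs / temperature guard / Thread / streamer structure mirrors
# app.py; the model id and the defaults below are placeholder assumptions.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "some-org/some-causal-lm"  # placeholder, not the Space's actual model

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map="auto"
)

def stream_chat(message: str, max_new_tokens: int = 512, temperature: float = 0.7):
    input_ids = tokenizer(message, return_tensors="pt").input_ids.to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=0.95,
        repetition_penalty=1.1,
    )
    # Sampling with temperature == 0 raises an error, so fall back to greedy decoding.
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    # model.generate blocks until generation finishes, so it runs on a background
    # thread; the streamer hands decoded chunks to this generator as they appear.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

Yielding the accumulated string rather than each new chunk matches what a Gradio ChatInterface streaming handler expects: each yield replaces the message shown so far, so the UI always displays the full partial response.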