indiejoseph committed
Commit • 35f8f29 • 1 Parent(s): 9021fd5
Update app.py
app.py CHANGED
@@ -5,7 +5,7 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 MAX_MAX_NEW_TOKENS = 4096
 DEFAULT_MAX_NEW_TOKENS = 2048
@@ -39,7 +39,7 @@ def generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> Iterator[str]:
+) -> str:
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -52,26 +52,20 @@ def generate(
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
+    output_ids = model.generate(
+        input_ids,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
         temperature=temperature,
         num_beams=1,
-        repetition_penalty=repetition_penalty
+        repetition_penalty=repetition_penalty
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
 
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    return response
+
 
 
 chat_interface = gr.ChatInterface(
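The substance of the commit: generation switches from streaming to blocking. The previous version drove model.generate on a background Thread through transformers' TextIteratorStreamer and yielded partial strings as tokens arrived; the new version calls model.generate synchronously, decodes the result, and returns one final string, which is why the return annotation changes from Iterator[str] to str. Below is a minimal self-contained sketch of the new pattern; the model ID and sampling values are placeholder assumptions, not taken from this Space.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "your-org/your-model"  # placeholder; the Space's actual checkpoint is not shown in the diff

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map="auto"
)

conversation = [{"role": "user", "content": "Hello!"}]
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Blocking call: nothing reaches the UI until the whole reply is generated.
output_ids = model.generate(
    input_ids,
    max_new_tokens=256,
    do_sample=True,
    top_p=0.9,
    top_k=50,
    temperature=0.7,
    repetition_penalty=1.2,
)

# model.generate returns the prompt tokens followed by the new tokens, so
# decoding output_ids[0] directly (as the committed code does) includes the
# prompt text; slicing off the input length keeps only the model's reply.
response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(response)

One trade-off worth noting: gr.ChatInterface accepts both generators and plain returns, and a generator that yields growing partial strings renders token by token in the chat UI, while a plain str renders only once at the end. That incremental display is what this change trades away for simpler, single-shot generation code.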