joaogante committed
Commit fdb003a
1 Parent(s): d509568

revert to simpler textbox

Files changed (1)
  1. app.py +36 -67
app.py CHANGED
@@ -14,28 +14,9 @@ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, devic
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
-def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, use_history, history):
-    if history is None:
-        history = []
-    history.append([user_text, ""])
-
-    # Get the model and tokenizer, and tokenize the user text. If `use_history` is True, we use the chatbot history
-    if use_history:
-        user_name, assistant_name, sep = "User: ", "Assistant: ", "\n"
-        past = []
-        for data in history:
-            user_data, model_data = data
-
-            if not user_data.startswith(user_name):
-                user_data = user_name + user_data
-            if not model_data.startswith(sep + assistant_name):
-                model_data = sep + assistant_name + model_data
-
-            past.append(user_data + model_data.rstrip() + sep)
-        text_input = "".join(past)
-    else:
-        text_input = user_text
-    model_inputs = tokenizer([text_input], return_tensors="pt").to(torch_device)
+def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
+    # Get the model and tokenizer, and tokenize the user text.
+    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)
 
     # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
     # in the main thread.
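
(Old lines 42-51, unchanged and therefore elided between the two hunks, are where the streamer and generate_kwargs used below get set up. A minimal sketch of what those lines plausibly contain, assuming only the transformers TextIteratorStreamer API and the parameters run_generation receives; the exact values are guesses, not part of this diff:)

    # Plausible reconstruction of the elided, unchanged lines (an assumption,
    # not part of this diff): a streamer that yields decoded text as tokens
    # arrive, plus the kwargs consumed by model.generate on the thread below.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=(temperature > 0),  # the temperature slider labels 0 as greedy decoding
        top_p=top_p,
        temperature=float(temperature),
        top_k=top_k,
    )
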
@@ -52,69 +33,57 @@ def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, use_his
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
-    # Pull the generated text from the streamer, and update the chatbot.
+    # Pull the generated text from the streamer, and update the model output.
+    model_output = ""
     for new_text in streamer:
-        history[-1][1] += new_text
-        yield history
-    return history
+        model_output += new_text
+        yield model_output
+    return model_output
 
 
 def reset_textbox():
     return gr.update(value='')
 
 
-with gr.Blocks(
-    css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
-           #chatbot {height: 520px; overflow: auto;}"""
-) as demo:
-    with gr.Column(elem_id="col_container"):
-        duplicate_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming?duplicate=true"
-        gr.Markdown(
-            "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
-            "This demo showcases the use of the "
-            "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
-            "of 🤗 Transformers with Gradio to generate text in real-time, as a chatbot. It uses "
-            f"[{model_id}](https://huggingface.co/{model_id}), "
-            "loaded in 8-bit quantized form.\n\n"
-            f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
-            "template! 💛"
-        )
-
-        chatbot = gr.Chatbot(elem_id='chatbot', label="Chat history")
-        user_text = gr.Textbox(
-            placeholder="Write an email about an alpaca that likes flan",
-            label="Type an input and press Enter"
-        )
-
-        with gr.Row():
+with gr.Blocks() as demo:
+    duplicate_link = "https://huggingface.co/spaces/joaogante/transformers_streaming?duplicate=true"
+    gr.Markdown(
+        "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
+        "This demo showcases the use of the "
+        "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
+        "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
+        f"[{model_id}](https://huggingface.co/{model_id}), "
+        "loaded in 8-bit quantized form.\n\n"
+        f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
+        "template! 💛"
+    )
+
+    with gr.Row():
+        with gr.Column(scale=4):
+            user_text = gr.Textbox(
+                placeholder="Write an email about an alpaca that likes flan",
+                label="User input"
+            )
+            model_output = gr.Textbox(
+                label="Model output", lines=10, read_only=True
+            )
             button_submit = gr.Button(value="Submit")
-            button_clear = gr.Button(value="Clear chat history")
 
-        with gr.Accordion("Generation Parameters", open=False):
-            use_history = gr.Checkbox(value=False, label="Use chat history as prompt")
+        with gr.Column(scale=1):
             max_new_tokens = gr.Slider(
                 minimum=1, maximum=1000, value=250, step=1, interactive=True, label="Max New Tokens",
             )
             top_p = gr.Slider(
                 minimum=0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
             )
-            temperature = gr.Slider(
-                minimum=0, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature (set to 0 for Greedy Decoding)",
-            )
             top_k = gr.Slider(
                 minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
             )
+            temperature = gr.Slider(
+                minimum=0, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature (0 = Greedy Decoding)",
            )
 
-    user_text.submit(
-        run_generation,
-        [user_text, top_p, temperature, top_k, max_new_tokens, use_history, chatbot],
-        chatbot
-    )
-    button_submit.click(
-        run_generation,
-        [user_text, top_p, temperature, top_k, max_new_tokens, use_history, chatbot],
-        chatbot
-    )
-    button_clear.click(reset_textbox, [], [chatbot])
+    user_text.submit(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
+    button_submit.click(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
 
 demo.queue(max_size=32).launch(enable_queue=True)
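
After the revert, the Space boils down to one reusable pattern: model.generate runs on a worker thread, a Python generator drains the TextIteratorStreamer, and Gradio re-renders the output Textbox on every yield. A condensed, self-contained sketch of that pattern; distilgpt2 is an illustrative stand-in for the Space's 8-bit model, and the generation parameters are trimmed for brevity:

from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")  # stand-in model
model = AutoModelForCausalLM.from_pretrained("distilgpt2")


def run_generation(user_text):
    model_inputs = tokenizer([user_text], return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=50)
    # generate() runs on a worker thread so the UI stays responsive; each partial
    # string yielded here replaces the output Textbox value in the browser.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output


with gr.Blocks() as demo:
    user_text = gr.Textbox(label="User input")
    model_output = gr.Textbox(label="Model output", lines=10)
    user_text.submit(run_generation, user_text, model_output)

demo.queue().launch()

Streaming generator outputs only works with the queue enabled, which is why the Space calls demo.queue(max_size=32) before launch().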