joaogante HF staff committed on
Commit
c81e522
1 Parent(s): 1be532a

add chatbot functionality. add dependencies. update model

Browse files
Files changed (2) hide show
  1. app.py +47 -20
  2. requirements.txt +1 -0
app.py CHANGED
@@ -2,25 +2,40 @@ from threading import Thread
2
 
3
  import torch
4
  import gradio as gr
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
6
-
7
 
 
8
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
9
  print("Running on device:", torch_device)
10
  print("CPU threads:", torch.get_num_threads())
11
 
12
 
13
- model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-6.9b-deduped", load_in_8bit=True, device_map="auto")
14
- tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-6.9b-deduped")
15
 
16
 
17
- def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, history):
18
  if history is None:
19
  history = []
20
  history.append([user_text, ""])
21
 
22
- # Get the model and tokenizer, and tokenize the user text.
23
- model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
26
  # in the main thread.
@@ -55,29 +70,36 @@ with gr.Blocks(
55
  with gr.Column(elem_id="col_container"):
56
  duplicate_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming?duplicate=true"
57
  gr.Markdown(
58
- "# 🤗 Transformers Gradio 🔥Streaming🔥\n"
59
  "This demo showcases the use of the "
60
  "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
61
- "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
62
- "[EleutherAI/pythia-6.9b-deduped](https://huggingface.co/EleutherAI/pythia-6.9b-deduped), "
63
- "a 6.9B parameter GPT-NeoX model by EleutherAI, loaded in 8-bit quantized form.\n\n"
64
- f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or to use this space as a "
65
  "template! 💛"
66
  )
67
 
68
- chatbot = gr.Chatbot(elem_id='chatbot', label="Message history")
69
- user_text = gr.Textbox(placeholder="Is pineapple a pizza topping?", label="Type an input and press Enter")
70
- button = gr.Button(value="Clear message history")
 
 
 
 
 
 
71
 
72
  with gr.Accordion("Generation Parameters", open=False):
 
73
  max_new_tokens = gr.Slider(
74
- minimum=1, maximum=1000, value=100, step=1, interactive=True, label="Max New Tokens",
75
  )
76
  top_p = gr.Slider(
77
- minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
78
  )
79
  temperature = gr.Slider(
80
- minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature (set to 0 for Greedy Decoding)",
81
  )
82
  top_k = gr.Slider(
83
  minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
@@ -85,9 +107,14 @@ with gr.Blocks(
85
 
86
  user_text.submit(
87
  run_generation,
88
- [user_text, top_p, temperature, top_k, max_new_tokens, chatbot],
 
 
 
 
 
89
  chatbot
90
  )
91
- button.click(reset_textbox, [], [user_text])
92
 
93
  demo.queue(max_size=32).launch()
 
2
 
3
  import torch
4
  import gradio as gr
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TextIteratorStreamer
 
6
 
7
+ model_id = "declare-lab/flan-alpaca-xl"
8
  torch_device = "cuda" if torch.cuda.is_available() else "cpu"
9
  print("Running on device:", torch_device)
10
  print("CPU threads:", torch.get_num_threads())
11
 
12
 
13
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
14
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
15
 
16
 
17
+ def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, use_history, history):
18
  if history is None:
19
  history = []
20
  history.append([user_text, ""])
21
 
22
+ # Get the model and tokenizer, and tokenize the user text. If `use_history` is True, we use the chatbot history as the prompt.
23
+ if use_history:
24
+ user_name, assistant_name, sep = "User: ", "Assistant: ", "\n"
25
+ past = []
26
+ for data in history:
27
+ user_data, model_data = data
28
+
29
+ if not user_data.startswith(user_name):
30
+ user_data = user_name + user_data
31
+ if not model_data.startswith(sep + assistant_name):
32
+ model_data = sep + assistant_name + model_data
33
+
34
+ past.append(user_data + model_data.rstrip() + sep)
35
+ text_input = "".join(past)
36
+ else:
37
+ text_input = user_text
38
+ model_inputs = tokenizer([text_input], return_tensors="pt").to(torch_device)
39
 
40
  # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
41
  # in the main thread.
 
70
  with gr.Column(elem_id="col_container"):
71
  duplicate_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming?duplicate=true"
72
  gr.Markdown(
73
+ "# 🤗 Transformers 🔥Streaming🔥 on Gradio\n"
74
  "This demo showcases the use of the "
75
  "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
76
+ "of 🤗 Transformers with Gradio to generate text in real-time, as a chatbot. It uses "
77
+ f"[{model_id}](https://huggingface.co/{model_id}), "
78
+ "loaded in 8-bit quantized form.\n\n"
79
+ f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or use this space as a "
80
  "template! 💛"
81
  )
82
 
83
+ chatbot = gr.Chatbot(elem_id='chatbot', label="Chat history")
84
+ user_text = gr.Textbox(
85
+ placeholder="Write an email about an alpaca that likes flan",
86
+ label="Type an input and press Enter"
87
+ )
88
+
89
+ with gr.Row():
90
+ button_submit = gr.Button(value="Submit")
91
+ button_clear = gr.Button(value="Clear chat history")
92
 
93
  with gr.Accordion("Generation Parameters", open=False):
94
+ use_history = gr.Checkbox(value=False, label="Use chat history as prompt")
95
  max_new_tokens = gr.Slider(
96
+ minimum=1, maximum=1000, value=250, step=1, interactive=True, label="Max New Tokens",
97
  )
98
  top_p = gr.Slider(
99
+ minimum=0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
100
  )
101
  temperature = gr.Slider(
102
+ minimum=0, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature (set to 0 for Greedy Decoding)",
103
  )
104
  top_k = gr.Slider(
105
  minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
 
107
 
108
  user_text.submit(
109
  run_generation,
110
+ [user_text, top_p, temperature, top_k, max_new_tokens, use_history, chatbot],
111
+ chatbot
112
+ )
113
+ button_submit.click(
114
+ run_generation,
115
+ [user_text, top_p, temperature, top_k, max_new_tokens, use_history, chatbot],
116
  chatbot
117
  )
118
+ button_clear.click(reset_textbox, [], [chatbot])
119
 
120
  demo.queue(max_size=32).launch()
requirements.txt CHANGED
@@ -1,2 +1,3 @@
 
1
  torch
2
  git+https://github.com/huggingface/transformers.git # transformers from main (TextIteratorStreamer will be added in v4.28)
 
1
+ accelerate
2
  torch
3
  git+https://github.com/huggingface/transformers.git # transformers from main (TextIteratorStreamer will be added in v4.28)