joaogante HF staff committed on
Commit
8445393
1 Parent(s): 5c81752

Updated chatbot

Browse files
Files changed (2) hide show
  1. app.py +89 -34
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,46 +1,101 @@
1
- import gradio as gr
2
  from threading import Thread
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, IteratorStreamer
 
 
 
 
4
 
 
 
 
 
 
 
 
5
 
6
- # Global variable loading
7
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
8
- print("Loading the model...")
9
- model = AutoModelForCausalLM.from_pretrained("gpt2")
10
- print("Done!")
11
 
12
 
13
- # Gradio app
14
- with gr.Blocks() as demo:
15
- def user(user_message, history):
16
- return "", history + [[user_message, None]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- chatbot = gr.Chatbot()
19
- msg = gr.Textbox()
20
- clear = gr.Button("Clear")
 
 
21
 
22
- def update_chatbot(history):
23
- user_query = history[-1][0]
24
- history[-1][1] = ""
25
- model_inputs = tokenizer([user_query], return_tensors="pt")
26
 
27
- # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
28
- # in the main thread.
29
- streamer = IteratorStreamer(tokenizer)
30
- generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=200, do_sample=True)
31
- t = Thread(target=model.generate, kwargs=generate_kwargs)
32
- t.start()
33
 
34
- # Pull the generated text from the streamer, and update the chatbot.
35
- for new_text in streamer:
36
- history[-1][1] += new_text
37
- yield history
38
- return history
39
 
40
- msg.submit(user, [msg, chatbot], [msg, chatbot]).then(
41
- update_chatbot, chatbot, chatbot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
43
- clear.click(lambda: None, None, chatbot)
 
44
 
45
- demo.queue()
46
- demo.launch()
 
 
1
  from threading import Thread
2
+ from functools import lru_cache
3
+
4
+ import gradio as gr
5
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer
6
+
7
 
8
@lru_cache(maxsize=1)  # keep only the most recently requested model in memory
def get_model_and_tokenizer(model_id):
    """Download (or reuse the cached) model and tokenizer for a Hub repo id.

    Encoder-decoder checkpoints and decoder-only checkpoints require different
    auto classes, so the config is inspected before instantiating the model.
    """
    config = AutoConfig.from_pretrained(model_id)
    model_cls = AutoModelForSeq2SeqLM if config.is_encoder_decoder else AutoModelForCausalLM
    model = model_cls.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer
 
 
 
18
 
19
 
20
def run_generation(model_id, user_text, top_p, temperature, top_k, max_new_tokens, chat_counter, history):
    """Generate a streamed model reply to `user_text`, yielding the chat history as it grows.

    Parameters mirror the positional `inputs` lists wired in the Gradio handlers:
    [model_id, user_text, top_p, temperature, top_k, max_new_tokens, chatbot, chatbot].
    FIX: the original signature placed `chat_counter` BEFORE `max_new_tokens`, so
    `max_new_tokens` received the chat-history list from the `chatbot` component and
    `model.generate` would fail. `chat_counter` is unused; it is kept (now after
    `max_new_tokens`) only so the handlers' 8-element input lists still bind.
    """
    if history is None:
        history = []
    # FIX: original read `history.append[[user_text, ""]]` — subscription instead of a
    # call, which raises TypeError at runtime.
    history.append([user_text, ""])

    # Get the model and tokenizer, and tokenize the user text.
    model, tokenizer = get_model_and_tokenizer(model_id)
    model_inputs = tokenizer([user_text], return_tensors="pt")

    # Start generation on a separate thread, so that we don't block the UI. The text is
    # pulled from the streamer in the main thread.
    # NOTE(review): TextIteratorStreamer also emits the prompt tokens, so the reply will
    # echo the user text; pass skip_prompt=True if that is not the desired UX — confirm.
    streamer = TextIteratorStreamer(tokenizer)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        top_k=top_k,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer and update the chatbot.
    # FIX: the handlers declare outputs=[chatbot, chatbot]; with multiple outputs Gradio
    # unpacks the yielded value one element per component, so a bare `history` list would
    # be split apart. Yield one value per declared output instead.
    for new_text in streamer:
        history[-1][1] += new_text
        yield history, history
49
 
 
 
 
 
50
 
51
def reset_textbox():
    """Return a component update that clears the user-input textbox."""
    cleared = gr.update(value='')
    return cleared
 
 
 
 
53
 
 
 
 
 
 
54
 
55
# Page header shown at the top of the demo.
title = """<h1 align="center">🔥Transformers + Gradio 🚀Streaming🚀</h1>"""


with gr.Blocks(
    css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
    #chatbot {height: 520px; overflow: auto;}"""
) as demo:
    gr.HTML(title)

    # "Duplicate this Space" banner.
    space_url = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming"
    badge_src = "https://bit.ly/3gLdBN6"
    badge_caption = "Duplicate the Space to bypass queues, add hardware resources, or to use this demo as a template!"
    gr.HTML(f'''<center><a href="{space_url}?duplicate=true"><img src="{badge_src}" alt="Duplicate Space"></a>{badge_caption}</center>''')

    with gr.Column(elem_id="col_container"):
        # Main chat widgets.
        model_id = gr.Textbox(value='EleutherAI/pythia-410m', label="🤗 Hub Model repo")
        chatbot = gr.Chatbot(elem_id='chatbot')
        user_text = gr.Textbox(placeholder="Is pineapple a pizza topping?", label="Type an input and press Enter")
        button = gr.Button()

        # Sampling controls, collapsed by default.
        with gr.Accordion("Parameters", open=False):
            top_p = gr.Slider(
                minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
            )
            temperature = gr.Slider(
                minimum=0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature (set to 0 for Greedy Decoding)",
            )
            top_k = gr.Slider(
                minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
            )
            max_new_tokens = gr.Slider(
                minimum=1, maximum=1000, value=100, step=1, interactive=True, label="Max New Tokens",
            )

    # Both pressing Enter in the textbox and clicking the button trigger generation,
    # then clear the textbox; wire them identically.
    generation_inputs = [model_id, user_text, top_p, temperature, top_k, max_new_tokens, chatbot, chatbot]
    for trigger in (user_text.submit, button.click):
        trigger(run_generation, generation_inputs, [chatbot, chatbot])
        trigger(reset_textbox, [], [user_text])

demo.queue().launch()
 
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
  torch
2
- git+https://github.com/gante/transformers.git@streamer_iterator # transformers from dev branch
 
1
  torch
2
+ git+https://github.com/huggingface/transformers.git # transformers from main (TextIteratorStreamer will be added in v4.28)