joaogante HF staff committed on
Commit
183e675
1 Parent(s): 2503b95

haha tokens go brrr

Browse files
Files changed (3) hide show
  1. __pycache__/app.cpython-310.pyc +0 -0
  2. app.py +33 -11
  3. requirements.txt +1 -1
__pycache__/app.cpython-310.pyc ADDED
Binary file (1.38 kB). View file
 
app.py CHANGED
@@ -1,24 +1,46 @@
1
  import gradio as gr
2
- import random
3
- import time
4
 
 
 
 
 
 
 
 
 
 
5
  with gr.Blocks() as demo:
 
 
 
6
  chatbot = gr.Chatbot()
7
  msg = gr.Textbox()
8
  clear = gr.Button("Clear")
9
 
10
def user(user_message, history):
    """Clear the textbox and append the new message as a pending turn.

    Returns an empty string (to reset the textbox) and a new history list
    with ``[user_message, None]`` appended — ``None`` marks the reply as
    not yet generated.
    """
    return "", [*history, [user_message, None]]
 
 
 
 
 
 
 
 
 
12
 
13
def bot(history):
    """Fill in a canned yes/no reply for the latest turn, then pause.

    Mutates ``history`` in place: the last turn's reply slot gets a random
    "Yes" or "No". The one-second sleep simulates generation latency.
    """
    history[-1][1] = random.choice(["Yes", "No"])
    time.sleep(1)
    return history
18
 
19
- msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
20
- bot, chatbot, chatbot
21
  )
22
- clear.click(lambda: None, None, chatbot, queue=False)
23
 
 
24
  demo.launch()
 
1
  import gradio as gr
2
+ from threading import Thread
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, IteratorStreamer
4
 
5
+
6
+ # Global variable loading
7
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
8
+ print("Loading the model...")
9
+ model = AutoModelForCausalLM.from_pretrained("gpt2")
10
+ print("Done!")
11
+
12
+
13
+ # Gradio app
14
  with gr.Blocks() as demo:
15
def user(user_message, history):
    """Move the textbox contents into the chat history as a pending turn.

    Returns ``""`` to clear the textbox, plus a fresh history list ending
    in ``[user_message, None]`` (``None`` = reply not generated yet).
    """
    pending_turn = [user_message, None]
    return "", history + [pending_turn]
17
+
18
  chatbot = gr.Chatbot()
19
  msg = gr.Textbox()
20
  clear = gr.Button("Clear")
21
 
22
def update_chatbot(history):
    """Stream a generated reply for the newest user message into the chat.

    Args:
        history: Chatbot history — a list of ``[user_text, bot_text]``
            pairs; the last pair is the pending one (``bot_text`` is None).

    Yields:
        The history list after each newly generated text chunk, so Gradio
        repaints the chatbot incrementally.
    """
    user_query = history[-1][0]
    history[-1][1] = ""  # start with an empty reply and grow it chunk by chunk
    model_inputs = tokenizer([user_query], return_tensors="pt")

    # Run generation on a worker thread so this (main) thread stays free to
    # pull decoded text out of the streamer as it is produced.
    # NOTE(review): the pinned dev branch exposes `IteratorStreamer`; in the
    # released transformers API this class is named `TextIteratorStreamer`
    # — confirm before switching requirements.txt back to upstream.
    streamer = IteratorStreamer(tokenizer)
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=200, do_sample=True)
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    # Drain the streamer, appending each chunk to the pending bot message.
    for new_text in streamer:
        history[-1][1] += new_text
        yield history
    # The streamer is exhausted only once generate() has finished, so this
    # join returns immediately; it makes worker cleanup explicit instead of
    # leaving the thread dangling. (The old trailing `return history` was
    # dead in a generator — its value is discarded — and has been removed.)
    generation_thread.join()
39
 
40
+ msg.submit(user, [msg, chatbot], [msg, chatbot]).then(
41
+ update_chatbot, chatbot, chatbot
42
  )
43
+ clear.click(lambda: None, None, chatbot)
44
 
45
+ demo.queue()
46
  demo.launch()
requirements.txt CHANGED
@@ -1 +1 @@
1
- git+https://github.com/huggingface/transformers.git # transformers from `main`
 
1
+ git+https://github.com/gante/transformers.git@streamer_iterator # transformers from dev branch