joaogante (HF staff) committed
Commit 1be532a (1 parent: 7cb6518)

limit to pythia 6.9b

Files changed (1): app.py (+13 −25)
app.py CHANGED

@@ -1,9 +1,8 @@
 from threading import Thread
-from functools import lru_cache
 
 import torch
 import gradio as gr
-from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
 
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -11,26 +10,16 @@ print("Running on device:", torch_device)
 print("CPU threads:", torch.get_num_threads())
 
 
-@lru_cache(maxsize=1)  # only cache the latest model
-def get_model_and_tokenizer(model_id):
-    config = AutoConfig.from_pretrained(model_id)
-    if config.is_encoder_decoder:
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-    else:
-        model = AutoModelForCausalLM.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-6.9b-deduped", load_in_8bit=True, device_map="auto")
+tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-6.9b-deduped")
 
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = model.to(torch_device)
-    return model, tokenizer
 
-
-def run_generation(model_id, user_text, top_p, temperature, top_k, max_new_tokens, history):
+def run_generation(user_text, top_p, temperature, top_k, max_new_tokens, history):
     if history is None:
         history = []
     history.append([user_text, ""])
 
     # Get the model and tokenizer, and tokenize the user text.
-    model, tokenizer = get_model_and_tokenizer(model_id)
     model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)
 
     # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
@@ -66,17 +55,16 @@ with gr.Blocks(
     with gr.Column(elem_id="col_container"):
         duplicate_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming?duplicate=true"
         gr.Markdown(
-            f"""
-            # 🤗 Transformers Gradio 🔥Streaming🔥
-            This demo showcases how to use the streaming feature of 🤗 Transformers with Gradio to generate text in real-time.
-            ⚠️ [Duplicate this Space]({duplicate_link}) if ⚠️
-            - You want to use a large model (> 1GB). Otherwise, this public space will become slow for others 💛
-            - You want to build your own app, using this demo as a template 🚀
-            - You want to bypass the queue and/or add hardware resources 👾
-            """
+            "# 🤗 Transformers Gradio 🔥Streaming🔥\n"
+            "This demo showcases the use of the "
+            "[streaming feature](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming) "
+            "of 🤗 Transformers with Gradio to generate text in real-time. It uses "
+            "[EleutherAI/pythia-6.9b-deduped](https://huggingface.co/EleutherAI/pythia-6.9b-deduped), "
+            "a 6.9B parameter GPT-NeoX model by EleutherAI, loaded in 8-bit quantized form.\n\n"
+            f"Feel free to [duplicate this Space]({duplicate_link}) to try your own models or to use this space as a "
+            "template! 💛"
         )
 
-        model_id = gr.Textbox(value='EleutherAI/pythia-410m', label="🤗 Hub Model repo")
         chatbot = gr.Chatbot(elem_id='chatbot', label="Message history")
         user_text = gr.Textbox(placeholder="Is pineapple a pizza topping?", label="Type an input and press Enter")
         button = gr.Button(value="Clear message history")
@@ -97,7 +85,7 @@ with gr.Blocks(
 
     user_text.submit(
         run_generation,
-        [model_id, user_text, top_p, temperature, top_k, max_new_tokens, chatbot],
+        [user_text, top_p, temperature, top_k, max_new_tokens, chatbot],
         chatbot
     )
     button.click(reset_textbox, [], [user_text])
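
A note on the new loading line: load_in_8bit=True and device_map="auto" pull in dependencies that the diff itself does not show. A minimal sketch of what that step assumes (the pip requirements and the comment annotations are assumptions, not part of this commit):

# Sketch of the 8-bit loading path used above. Assumptions: a CUDA GPU is
# available, and the Space installs the extra packages, e.g.
#   pip install accelerate bitsandbytes
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-6.9b-deduped",
    load_in_8bit=True,   # int8-quantize the weights via bitsandbytes (~1 byte per parameter)
    device_map="auto",   # let accelerate place the weights on the available GPU(s)/CPU
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-6.9b-deduped")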
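
The diff cuts off right after the comment about starting generation on a separate thread. For context, a minimal sketch of the TextIteratorStreamer pattern that comment refers to (the stand-in model and sampling values are illustrative, not taken from this commit):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Small stand-in model so the sketch runs on CPU; the Space uses pythia-6.9b-deduped.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m")

model_inputs = tokenizer(["Is pineapple a pizza topping?"], return_tensors="pt")

# skip_prompt=True omits the input prompt from the streamed text.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
    model_inputs, streamer=streamer, max_new_tokens=100,
    do_sample=True, top_p=0.95, temperature=1.0, top_k=50,
)

# generate() blocks until done, so it runs on a worker thread while the
# caller consumes tokens from the streamer as they are produced.
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text  # in the app, each partial string updates the Chatbot

Running generate() off the main thread is what lets the Gradio callback yield partial strings and keep the UI responsive while tokens stream in.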