Commit 2d2dd9a
Author: joaogante (HF staff)
1 parent: 8445393

visual tweaks

Files changed (2):
  1. .gitignore +169 -0
  2. app.py +30 -28
.gitignore ADDED
@@ -0,0 +1,169 @@
+# Initially taken from Github's Python gitignore file
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# tests and logs
+tests/fixtures/cached_*_text.txt
+logs/
+lightning_logs/
+lang_code_data/
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# vscode
+.vs
+.vscode
+
+# Pycharm
+.idea
+
+# TF code
+tensorflow_code
+
+# Models
+proc_data
+
+# examples
+runs
+/runs_old
+/wandb
+/examples/runs
+/examples/**/*.args
+/examples/rag/sweep
+
+# data
+/data
+serialization_dir
+
+# emacs
+*.*~
+debug.env
+
+# vim
+.*.swp
+
+#ctags
+tags
+
+# pre-commit
+.pre-commit*
+
+# .lock
+*.lock
+
+# DS_Store (MacOS)
+.DS_Store
+
+# ruff
+.ruff_cache
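
As an aside, if you want to check which paths these patterns would ignore without asking git itself, a small sketch using the third-party `pathspec` package could look like the following (an assumption on my part: `pathspec` reimplements gitignore-style matching, but git remains the authority on ignore semantics):

```python
# Sketch: test candidate paths against the .gitignore added above.
# Assumes `pip install pathspec`; "gitwildmatch" is pathspec's gitignore-style syntax.
import pathspec

with open(".gitignore") as fh:
    spec = pathspec.PathSpec.from_lines("gitwildmatch", fh)

for path in ["app.py", "__pycache__/app.cpython-310.pyc", "wandb/latest-run", ".DS_Store"]:
    status = "ignored" if spec.match_file(path) else "kept"
    print(f"{path}: {status}")
```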
app.py CHANGED
@@ -1,10 +1,14 @@
 from threading import Thread
 from functools import lru_cache
 
+import torch
 import gradio as gr
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, TextIteratorStreamer
 
 
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
 @lru_cache(maxsize=1)  # only cache the latest model
 def get_model_and_tokenizer(model_id):
     config = AutoConfig.from_pretrained(model_id)
@@ -14,21 +18,22 @@ def get_model_and_tokenizer(model_id):
         model = AutoModelForCausalLM.from_pretrained(model_id)
 
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = model.to(torch_device)
     return model, tokenizer
 
 
-def run_generation(model_id, user_text, top_p, temperature, top_k, chat_counter, max_new_tokens, history):
+def run_generation(model_id, user_text, top_p, temperature, top_k, max_new_tokens, history):
     if history is None:
         history = []
-    history.append[[user_text, ""]]
+    history.append([user_text, ""])
 
     # Get the model and tokenizer, and tokenize the user text.
     model, tokenizer = get_model_and_tokenizer(model_id)
-    model_inputs = tokenizer([user_text], return_tensors="pt")
+    model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)
 
     # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
     # in the main thread.
-    streamer = TextIteratorStreamer(tokenizer)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
@@ -52,26 +57,32 @@ def reset_textbox():
     return gr.update(value='')
 
 
-title = """<h1 align="center">🔥Transformers + Gradio 🚀Streaming🚀</h1>"""
-
-
 with gr.Blocks(
     css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
     #chatbot {height: 520px; overflow: auto;}"""
 ) as demo:
-    gr.HTML(title)
-    demo_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming"
-    img_src = "https://bit.ly/3gLdBN6"
-    button_desc = "Duplicate the Space to bypass queues, add hardware resources, or to use this demo as a template!"
-    gr.HTML(f'''<center><a href="{demo_link}?duplicate=true"><img src="{img_src}" alt="Duplicate Space"></a>{button_desc}</center>''')
-
     with gr.Column(elem_id="col_container"):
+        demo_link = "https://huggingface.co/spaces/joaogante/chatbot_transformers_streaming"
+        gr.Markdown(
+            f"""
+            # 🤗 Transformers Gradio 🔥Streaming🔥
+            This demo showcases how to use the streaming feature of 🤗 Transformers with Gradio to generate text in real-time.
+            ⚠️ [Duplicate this Space]({demo_link}) if ⚠️
+            - You want to use a large model (> 1GB). Otherwise, this public space will become slow for others 💛
+            - You want to build your own app, using this demo as a template 🚀
+            - You want to bypass the queue and/or add hardware resources 👾
+            """
+        )
+
         model_id = gr.Textbox(value='EleutherAI/pythia-410m', label="🤗 Hub Model repo")
-        chatbot = gr.Chatbot(elem_id='chatbot')
+        chatbot = gr.Chatbot(elem_id='chatbot', label="Message history")
         user_text = gr.Textbox(placeholder="Is pineapple a pizza topping?", label="Type an input and press Enter")
-        button = gr.Button()
+        button = gr.Button(value="Clear message history")
 
-        with gr.Accordion("Parameters", open=False):
+        with gr.Accordion("Generation Parameters", open=False):
+            max_new_tokens = gr.Slider(
+                minimum=1, maximum=1000, value=100, step=1, interactive=True, label="Max New Tokens",
+            )
             top_p = gr.Slider(
                 minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
             )
@@ -81,21 +92,12 @@ with gr.Blocks(
             top_k = gr.Slider(
                 minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
            )
-            max_new_tokens = gr.Slider(
-                minimum=1, maximum=1000, value=100, step=1, interactive=True, label="Max New Tokens",
-            )
 
        user_text.submit(
            run_generation,
-            [model_id, user_text, top_p, temperature, top_k, max_new_tokens, chatbot, chatbot],
-            [chatbot, chatbot]
-        )
-        button.click(
-            run_generation,
-            [model_id, user_text, top_p, temperature, top_k, max_new_tokens, chatbot, chatbot],
-            [chatbot, chatbot]
+            [model_id, user_text, top_p, temperature, top_k, max_new_tokens, chatbot],
+            chatbot
        )
        button.click(reset_textbox, [], [user_text])
-        user_text.submit(reset_textbox, [], [user_text])
 
-demo.queue().launch()
+demo.queue(max_size=32).launch()
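
The hunks above elide the part of `run_generation` that actually launches the worker thread and consumes the stream. For context, below is a minimal standalone sketch of the same pattern, assuming the standard `TextIteratorStreamer` usage; the prompt string and the final loop body are illustrative, not copied from the app:

```python
from functools import lru_cache
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

torch_device = "cuda" if torch.cuda.is_available() else "cpu"


@lru_cache(maxsize=1)  # as in the diff: keep only the most recently requested model in memory
def get_model_and_tokenizer(model_id):
    model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer


model, tokenizer = get_model_and_tokenizer("EleutherAI/pythia-410m")
model_inputs = tokenizer(["Is pineapple a pizza topping?"], return_tensors="pt").to(torch_device)

# The streamer receives tokens from generate() and re-exposes them as an iterator of
# decoded text chunks; skip_prompt drops the echoed input from the streamed output.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=100, do_sample=True)

# generate() blocks until generation finishes, so it runs on a worker thread while the
# main thread pulls partial text from the streamer. This is what keeps the UI responsive.
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

generated_text = ""
for new_text in streamer:
    generated_text += new_text  # a Gradio handler would update history[-1][1] and yield here
    print(new_text, end="", flush=True)
thread.join()
```

In the Gradio version, yielding the updated history on every chunk is what makes the Chatbot repaint incrementally, even though `generate()` itself is a single blocking call.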