Add tons of suggestions; Improved matching logic
- app.py +84 -119
- general_suggestions.py +156 -0
- model_suggestions.py +62 -0
- task_suggestions.py +85 -0
app.py
CHANGED
@@ -1,11 +1,16 @@
 import gradio as gr
 import huggingface_hub as hfh
 from requests.exceptions import HTTPError
+from functools import lru_cache
+
+from general_suggestions import GENERAL_SUGGESTIONS
+from model_suggestions import MODEL_SUGGESTIONS
+from task_suggestions import TASK_SUGGESTIONS

 # =====================================================================================================================
 # DATA
 # =====================================================================================================================
-# Dict with the tasks considered in this spaces, {
+# Dict with the tasks considered in this spaces, {task: task tag}
 TASK_TYPES = {
     "✍️ Text Generation": "txtgen",
     "🤏 Summarization": "summ",
@@ -17,7 +22,7 @@ TASK_TYPES = {
     "🌇 Image to Text": "img2txt",
 }

-# Dict matching all task types with their possible hub tags, {
+# Dict matching all task types with their possible hub tags, {task tag: (possible, hub, tags)}
 HUB_TAGS = {
     "txtgen": ("text-generation", "text2text-generation"),
     "summ": ("summarization", "text-generation", "text2text-generation"),
@@ -31,16 +36,16 @@ HUB_TAGS = {
 assert len(TASK_TYPES) == len(TASK_TYPES)
 assert all(tag in HUB_TAGS for tag in TASK_TYPES.values())

-# Dict with the problems considered in this spaces, {problem:
+# Dict with the problems considered in this spaces, {problem: problem tag}
 PROBLEMS = {
-    "🤔 Baseline.
+    "🤔 Baseline. I'm getting gibberish and I want a baseline": "baseline",
     "😵 Crashes. I want to prevent my model from crashing again": "crashes",
-    "
-    "🌈 Interactivity. I would like a ChatGPT-like model": "interactity",
+    "🤥 Hallucinations. I would like to reduce them": "hallucinations",
     "📏 Length. I want to control the length of the output": "length",
-    "
+    "🌈 Prompting. I want better outputs without changing my generation options": "prompting",
+    "😵💫 Repetitions. Make them stop make them stop": "repetitions",
+    "📈 Quality. I want better outputs without changing my prompt": "quality",
     "🏎 Speed! Make it faster!": "speed",
-    "❓ ??? Something else, looking for ideas": "other",
 }

 INIT_MARKDOWN = """
@@ -68,9 +73,16 @@ DEMO_MARKDOWN = """
 """

 SUGGETIONS_HEADER = """
-✨ Here is a list of suggestions for you -- click to expand ✨
+#### ✨ Here is a list of suggestions for you -- click to expand ✨
 """

+PERFECT_MATCH_EMOJI = "✅"
+POSSIBLE_MATCH_EMOJI = "❓"
+MISSING_INPUTS = """
+💡 You can filter suggestions with {} if you add more inputs. Suggestions with {} are a perfect match.
+""".format(POSSIBLE_MATCH_EMOJI, PERFECT_MATCH_EMOJI)
+
+# The space below is reserved for suggestions that require advanced logic and/or formatting
 TASK_MODEL_MISMATCH = """
 <details><summary>{count}. Select a model better suited for your task.</summary>

@@ -89,94 +101,9 @@ models as a starting point.

 1. The tags of a model are defined by the community and are not always accurate. If you think the model is incorrectly
 tagged or missing a tag, please open an issue on the [model card](https://huggingface.co/{model_name}/tree/main).
-</details>
-"""
-
-SET_MAX_NEW_TOKENS = """
-<details><summary>{count}. Control the maximum output length with `max_new_tokens`.</summary>
-
-
-🤔 Why?
-
-All text generation calls have a length-related stopping condition. Depending on the model and/or the tool you're
-using to generate text, the default value may be too small or too large. I'd recommend ALWAYS setting this option.
-
-
-🤗 How?
-
-Our text generation interfaces accept a `max_new_tokens` option. Set it to define the maximum number of tokens
-that can be generated.
-
-😱 Caveats
-
-1. Allowing a longer output doesn't necessarily mean that the model will generate longer outputs. By default,
-the model will stop generating when it generates a special `eos_token_id` token.
-2. You shouldn't set `max_new_tokens` to a value larger than the maximum sequence length of the model. If you need a
-longer output, consider using a model with a larger maximum sequence length.
-3. The longer the output, the longer it will take to generate.
-</details>
-"""
-
-SET_MIN_LENGTH = """
-<details><summary>{count}. Force a minimum output length with `min_new_tokens`.</summary>
-
-
-🤔 Why?
-
-Text generation stops when the model generates a special `eos_token_id`. If you prevent it from happening, the model is
-forced to continue generating.
-
-🤗 How?
-
-Our text generation interfaces accept a `min_new_tokens` argument. Set it to prevent `eos_token_id` from being
-generated until `min_new_tokens` tokens are generated.
-
-😱 Caveats
-
-1. The quality of the output may suffer if the model is forced to generate beyond its own original expectations.
-2. `min_new_tokens` must be smaller than than `max_new_tokens` (see related tip).
-</details>
-"""
-
-REMOVE_EOS_TOKEN = """
-<details><summary>{count}. Prevent the model of halting generation by removing `eos_token_id`.</summary>
-
-
-🤔 Why?
-
-Text generation stops when the model generates a special `eos_token_id`. If there is no `eos_token_id`, the model can't
-stop.
-
-
-🤗 How?
-
-Our text generation interfaces accept a `eos_token_id` argument. Set it to a null value (e.g., in Python,
-`eos_token_id=None`) to prevent generation to stop before it reaches other stopping conditions.
-
-😱 Caveats
-
-1. The quality of the output may suffer if the model is forced to generate beyond its own original expectations.
-</details>
-"""
-
-LIST_EOS_TOKEN = """
-<details><summary>{count}. Add a stop word through `eos_token_id`.</summary>
-
-
-🤔 Why?
-
-Text generation stops when the model generates a special `eos_token_id`. Actually, this attribute can be a list of
-tokens, which means you can define arbitrary stop words.
-
-
-🤗 How?
-
-Our text generation interfaces accept a `eos_token_id` argument. You can pass a list of tokens to make generation
-stop in the presence of any of those tokens.
-
-😱 Caveats
-
-1. When passing a list of tokens, you probably shouldn't forget to include the default `eos_token_id` there.
+
+_________________
+
 </details>
 """
 # =====================================================================================================================
@@ -185,50 +112,88 @@ stop in the presence of any of those tokens.
 # =====================================================================================================================
 # SUGGESTIONS LOGIC
 # =====================================================================================================================
-def is_valid_task_for_model(
+def is_valid_task_for_model(model_tags, user_task):
+    if len(model_tags) == 0 or user_task == "":
+        return True  # No model / no tags = no problem :)
+
+    possible_tags = HUB_TAGS[user_task]
+    return any(tag in model_tags for tag in possible_tags)
+
+
+@lru_cache(maxsize=int(2e10))
+def get_model_tags(model_name):
     if model_name == "":
-        return
+        return []
     try:
         model_tags = hfh.HfApi().model_info(model_name).tags
     except HTTPError:
-
-
-    possible_tags = HUB_TAGS[TASK_TYPES[task_type]]
-    return any(tag in model_tags for tag in possible_tags)
+        model_tags = []
+    return model_tags


 def get_suggestions(task_type, model_name, problem_type):
     # Check if the inputs were given
-    if task_type == ""
+    if all([task_type == "", model_name == "", problem_type == ""]):
         return INIT_MARKDOWN

-    suggestions =
+    suggestions = ""
     counter = 0
+    model_tags = get_model_tags(model_name)
+
+    user_problem = PROBLEMS.get(problem_type, "")
+    user_task = TASK_TYPES.get(task_type, "")

     # Check if the model is valid for the task. If not, return straight away
-    if not is_valid_task_for_model(
+    if not is_valid_task_for_model(model_tags, user_task):
         counter += 1
-        possible_tags = " ".join("`" + tag + "`" for tag in HUB_TAGS[
+        possible_tags = " ".join("`" + tag + "`" for tag in HUB_TAGS[user_task])
         suggestions += TASK_MODEL_MISMATCH.format(
-            count=counter, model_name=model_name, task_type=
+            count=counter, model_name=model_name, task_type=user_task, tags=possible_tags
         )
         return suggestions

     # Demo shortcut: only a few sections are working
-    if
+    if user_problem not in ("", "length", "quality"):
         return DEMO_MARKDOWN

-
-
-
-
-
-
-
-
-
-
-
+    # First: model-specific suggestions
+    has_model_specific_suggestions = False
+    match_emoji = POSSIBLE_MATCH_EMOJI if (user_problem == "" or len(model_tags) == 0) else PERFECT_MATCH_EMOJI
+    for model_tag, problem_tags, suggestion in MODEL_SUGGESTIONS:
+        if user_problem == "" or user_problem in problem_tags:
+            if len(model_tags) == 0 or model_tag in model_tags:
+                counter += 1
+                suggestions += suggestion.format(count=counter, match_emoji=match_emoji)
+                has_model_specific_suggestions = True
+
+    # Second: task-specific suggestions
+    has_task_specific_suggestions = False
+    match_emoji = POSSIBLE_MATCH_EMOJI if (user_problem == "" or user_task == "") else PERFECT_MATCH_EMOJI
+    for task_tags, problem_tags, suggestion in TASK_SUGGESTIONS:
+        if user_problem == "" or user_problem in problem_tags:
+            if user_task == "" or user_task in task_tags:
+                counter += 1
+                suggestions += suggestion.format(count=counter, match_emoji=match_emoji)
+                has_task_specific_suggestions = True
+
+    # Finally: general suggestions for the problem
+    has_problem_specific_suggestions = False
+    match_emoji = POSSIBLE_MATCH_EMOJI if user_problem == "" else PERFECT_MATCH_EMOJI
+    for problem_tags, suggestion in GENERAL_SUGGESTIONS:
+        if user_problem == "" or user_problem in problem_tags:
+            counter += 1
+            suggestions += suggestion.format(count=counter, match_emoji=match_emoji)
+            has_problem_specific_suggestions = True
+
+    # Prepends needed bits
+    if (
+        (task_type == "" and has_task_specific_suggestions)
+        or (model_name == "" and has_model_specific_suggestions)
+        or (problem_type == "" and has_problem_specific_suggestions)
+    ):
+        suggestions = MISSING_INPUTS + suggestions
+
+    return SUGGETIONS_HEADER + suggestions
 # =====================================================================================================================


@@ -255,7 +220,7 @@ with demo:
         value="",
     )
     task_type = gr.Dropdown(
-        label="
+        label="Which task are you working on?",
         choices=[""] + list(TASK_TYPES.keys()),
         interactive=True,
         value="",
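For context on the refactor above, the sketch below exercises the new matching helpers outside of Gradio. It is a minimal sketch, not part of the commit: the model name, the reduced `HUB_TAGS` dict, and the smaller cache size are illustrative assumptions.

from functools import lru_cache

import huggingface_hub as hfh
from requests.exceptions import HTTPError

# Assumption: a one-entry subset of the HUB_TAGS dict defined in app.py
HUB_TAGS = {"txtgen": ("text-generation", "text2text-generation")}


@lru_cache(maxsize=1024)  # the commit uses a much larger maxsize; any bound works for this sketch
def get_model_tags(model_name):
    # Look up the model's Hub tags once, then serve repeated queries from the cache
    if model_name == "":
        return ()
    try:
        return tuple(hfh.HfApi().model_info(model_name).tags)  # tuple so the cached value cannot be mutated
    except HTTPError:
        return ()


def is_valid_task_for_model(model_tags, user_task):
    # Mirrors the new logic: with no model or no task selected there is nothing to mismatch
    if len(model_tags) == 0 or user_task == "":
        return True
    return any(tag in model_tags for tag in HUB_TAGS[user_task])


print(is_valid_task_for_model(get_model_tags("distilgpt2"), "txtgen"))  # expected: True ("distilgpt2" is an arbitrary example)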
general_suggestions.py
ADDED
@@ -0,0 +1,156 @@
"""
This is a file holding task and model agnostic suggestions.

How to add a new suggestion:
1. Add a new constant at the bottom of the file with your suggestion. Please try to follow the same format as the
existing suggestions.
2. Add a new entry to the `GENERAL_SUGGESTIONS`, with format `((problem tags,), suggestion constant)`.
    a. See `app.py` for the existing problem tags.
    c. Make sure the problem tags are a tuple.
"""

SET_MAX_NEW_TOKENS = """
<details><summary>{match_emoji} {count}. Control the maximum output length.</summary>


🤔 Why?

All text generation calls have a length-related stopping condition. Depending on the model and/or the tool you're
using to generate text, the default value may be too small or too large. I'd recommend ALWAYS setting this option.


🤗 How?

Our text generation interfaces accept a `max_new_tokens` option. Set it to define the maximum number of tokens
that can be generated.

😱 Caveats

1. Allowing a longer output doesn't necessarily mean that the model will generate longer outputs. By default,
the model will stop generating when it generates a special `eos_token_id` token.
2. You shouldn't set `max_new_tokens` to a value larger than the maximum sequence length of the model. If you need a
longer output, consider using a model with a larger maximum sequence length.
3. The longer the output, the longer it will take to generate.
_________________
</details>
"""

SET_MIN_LENGTH = """
<details><summary>{match_emoji} {count}. Force a minimum output length.</summary>


🤔 Why?

Text generation stops when the model generates a special `eos_token_id`. If you prevent it from happening, the model is
forced to continue generating.

🤗 How?

Our text generation interfaces accept a `min_new_tokens` argument. Set it to prevent `eos_token_id` from being
generated until `min_new_tokens` tokens are generated.

😱 Caveats

1. The quality of the output may suffer if the model is forced to generate beyond its own original expectations.
2. `min_new_tokens` must be smaller than than `max_new_tokens` (see related tip).
_________________
</details>
"""

REMOVE_EOS_TOKEN = """
<details><summary>{match_emoji} {count}. Force the model to generate until it reaches the maximum output length.</summary>


🤔 Why?

Text generation stops when the model generates a special `eos_token_id`. If there is no `eos_token_id`, the model can't
stop.


🤗 How?

Our text generation interfaces accept a `eos_token_id` argument. Set it to a null value (e.g., in Python,
`eos_token_id=None`) to prevent generation to stop before it reaches other stopping conditions.

😱 Caveats

1. The quality of the output may suffer if the model is forced to generate beyond its own original expectations.
_________________
</details>
"""

LIST_EOS_TOKEN = """
<details><summary>{match_emoji} {count}. Add a stop word.</summary>


🤔 Why?

Text generation stops when the model generates a special `eos_token_id`. Actually, this attribute can be a list of
tokens, which means you can define arbitrary stop words.


🤗 How?

Our text generation interfaces accept a `eos_token_id` argument. You can pass a list of tokens to make generation
stop in the presence of any of those tokens.

😱 Caveats

1. When passing a list of tokens, you probably shouldn't forget to include the default `eos_token_id` there.
_________________
</details>
"""

TRY_CONTRASTIVE_SEARCH = """
<details><summary>{match_emoji} {count}. Try Contrastive Search.</summary>


🤔 Why?

Contrastive Search is a greedy decoding strategy that strikes a balance between picking the best token and avoiding
repetition in the representation space. Despite being a greedy decoding strategy, it can also perform well on tasks
that require creativity (i.e. Sampling territory). In some models, it greatly reduces the problem of repetition.


🤗 How?

Our text generation interfaces accept two arguments: `top_k` and `penalty_alpha`. The authors recomment starting with
`top_k=4` and `penalty_alpha=0.6`.

😱 Caveats

1. Contrastive Search does not work well with all models -- it depends on how distributed their representation spaces
are. See [this thread](https://huggingface.co/spaces/joaogante/contrastive_search_generation/discussions/1#63764a108623a4a7954a5be5)
for further information.
_________________
</details>
"""

BLOCK_BAD_WORDS = """
<details><summary>{match_emoji} {count}. Prevent certain words from being generated.</summary>


🤔 Why?

You might want to prevent your model from generating certain tokens, such as swear words.


🤗 How?

Our text generation interfaces accept a `bad_words_ids` argument. There, you can pass a list of lists, where each
inner list contains a forbidden sequence of tokens.
Remember that you can get the token IDs for the words you want to block through
`bad_word_ids = tokenizer(bad_words, add_prefix_space=True, add_special_tokens=False).input_ids`
_________________
</details>
"""

GENERAL_SUGGESTIONS = (
    (("length",), SET_MAX_NEW_TOKENS),
    (("length",), SET_MIN_LENGTH),
    (("length",), REMOVE_EOS_TOKEN),
    (("length",), LIST_EOS_TOKEN),
    (("quality", "repetitions"), TRY_CONTRASTIVE_SEARCH),
    (("quality",), BLOCK_BAD_WORDS),
)
assert all(isinstance(problem_tags, tuple) for problem_tags, _ in GENERAL_SUGGESTIONS)
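The cards in this file all point at length-related knobs of the text generation API. The sketch below shows roughly how they map onto `transformers`' `generate()`; it is an illustration under assumptions (the model name, prompt, and numeric values are placeholders, not something the commit prescribes).

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")  # arbitrary small example model
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
inputs = tokenizer("The Hugging Face Hub is", return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=64,    # cap on newly generated tokens (SET_MAX_NEW_TOKENS)
    min_new_tokens=16,    # suppress eos_token_id until 16 new tokens exist (SET_MIN_LENGTH)
    # eos_token_id=None,  # uncomment to generate all the way to max_new_tokens (REMOVE_EOS_TOKEN)
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Contrastive Search and `bad_words_ids` follow the same pattern: they are just extra keyword arguments to the same call (e.g. `top_k=4, penalty_alpha=0.6` for the values suggested in TRY_CONTRASTIVE_SEARCH).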
model_suggestions.py
ADDED
@@ -0,0 +1,62 @@
"""
This is a file holding model-specific suggestions.

How to add a new suggestion:
1. Add a new constant at the bottom of the file with your suggestion. Please try to follow the same format as the
existing suggestions.
2. Add a new entry to the `MODEL_SUGGESTIONS`, with format `(model tag, (problem tags,), suggestion constant)`.
    a. Make sure the model tag matches the exact same tag as on the Hub (e.g. GPT-J is `gptj`)
    b. See `app.py` for the existing problem tags.
    c. Make sure the problem tags are a tuple.
"""


GPTJ_USE_SAMPLING = """
<details><summary>{match_emoji} {count}. GPT-J - Avoid using Greedy Search and Beam Search.</summary>


🤔 Why?

According to its creators, "generating without sampling was actually surprisingly suboptimal".


🤗 How?

Our text generation interfaces accept a `do_sample` argument. Set it to `True` to ensure sampling-based strategies
are used.

💡 Source

1. [This tweet](https://twitter.com/EricHallahan/status/1627785461723721729) by a core member of EleutherAI, the
creator of GPT-J
_________________
</details>
"""


T5_FLOAT16 = """
<details><summary>{match_emoji} {count}. T5 - If you're using int8 or float16, make sure you have `transformers>=4.26.1`.</summary>


🤔 Why?

In a nutshell, some layers in T5 don't work well in lower precision unless they are in bf16. Newer versions of
`transformers` take care of upcasting the layers when needed.


🤗 How?

Make sure the dependencies in your workflow have `transformers>=4.26.1`

💡 Source

1. See [this thread](https://github.com/huggingface/transformers/issues/20287) for the full discussion.
_________________
</details>
"""

MODEL_SUGGESTIONS = (
    ("gptj", ("quality",), GPTJ_USE_SAMPLING),
    ("t5", ("quality", "baseline", "speed"), T5_FLOAT16),
)
assert all(isinstance(problem_tags, tuple) for _, problem_tags, _ in MODEL_SUGGESTIONS)
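Both cards here reduce to either a single generate-time flag or a dependency bump. A rough illustration of the GPT-J tip, assuming the public `EleutherAI/gpt-j-6B` checkpoint (loading it needs a lot of memory, so treat this as a sketch rather than a copy-paste recipe):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
inputs = tokenizer("Once upon a time", return_tensors="pt")

# GPTJ_USE_SAMPLING: sampling instead of greedy/beam search
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

For the T5 card, the fix is on the environment side, e.g. `pip install "transformers>=4.26.1"`, rather than a generation argument.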
task_suggestions.py
ADDED
@@ -0,0 +1,85 @@
"""
This is a file holding task-specific suggestions.

How to add a new suggestion:
1. Add a new constant at the bottom of the file with your suggestion. Please try to follow the same format as the
existing suggestions.
2. Add a new entry to the `TASK_SUGGESTIONS`, with format `((task tags,), (problem tag,), suggestion constant)`.
    a. See `app.py` for the existing task tags.
    b. See `app.py` for the existing problem tags.
    c. Make sure the task tags and the problem tags are a tuple.
"""

USE_SAMPLING = """
<details><summary>{match_emoji} {count}. Use sampling decoding strategies.</summary>


🤔 Why?

The selected task benefits from creativity. Sampling-based decoding strategies typically yield better results in
creativity-based tasks


🤗 How?

Our text generation interfaces accept a `do_sample` argument. Set it to `True` to ensure sampling-based strategies
are used.
_________________
</details>
"""


BLOCK_LOW_PROBA_TOKENS = """
<details><summary>{match_emoji} {count}. Sampling - Block low probability tokens.</summary>


🤔 Why?

When decoding with sampling-based strategies, ANY token in the model vocabulary can be selected. This means there is
always a chance that the generated text drifts off-topic.


🤗 How?

There are a few different strategies you can try. They all discard (i.e. set probability to 0) tokens at each
generation step. Unless stated otherwise, they can be used with each other.
1. Top K: discards all but the K most likely tokens (suggeted value: `top_k=50`);
2. Top P: sorts tokens by probability in descending order, computes the cumulative probability, then discards all
tokens with a cumulative probability above the threshold P (suggested value: `top_p=0.95`);
3. ETA Cutoff: A balance between Top K and Top P. See the [corresponding paper](https://arxiv.org/abs/2210.15191) for
more details (should not be used with others; suggested value: `eta_cutoff=1e-3`).
_________________
</details>
"""


USE_DETERMINISTIC = """
<details><summary>{match_emoji} {count}. Use greedy decoding strategies.</summary>


🤔 Why?

The selected task is factual, it does not benefit from creativity. Greedy decoding strategies (like Greedy Search and
Beam Search) are preferred in those situations.


🤗 How?

Our text generation interfaces accept a `do_sample` argument. Set it to `False` to ensure greedy strategies
are used.
_________________
</details>
"""

# task tags that should use sampling and benefit from sampling-related advice
sampling = ("txtgen", "chat", "img2txt")
# task tags that should NOT use sampling and benefit from greedy/beam search advice
greedy = ("summ", "trans", "txtqa", "otherqa", "asr")

TASK_SUGGESTIONS = (
    (sampling, ("quality",), USE_SAMPLING),
    (sampling, ("quality", "hallucinations"), BLOCK_LOW_PROBA_TOKENS),
    (greedy, ("quality",), USE_DETERMINISTIC),
)
assert all(isinstance(problem_tags, tuple) for _, problem_tags, _ in TASK_SUGGESTIONS)
assert all(isinstance(task_tags, tuple) for task_tags, _, _ in TASK_SUGGESTIONS)
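The sampling-versus-greedy advice above comes down to a couple of `generate()` flags. A minimal sketch with assumed values (model name, prompt, and parameters are placeholders; the `top_k`/`top_p` values are the ones quoted in BLOCK_LOW_PROBA_TOKENS):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")  # arbitrary example model
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
inputs = tokenizer("Write a short story about a robot:", return_tensors="pt")

# USE_SAMPLING + BLOCK_LOW_PROBA_TOKENS: creative tasks
creative = model.generate(**inputs, do_sample=True, top_k=50, top_p=0.95, max_new_tokens=48)

# USE_DETERMINISTIC: factual tasks (greedy/beam search)
factual = model.generate(**inputs, do_sample=False, num_beams=4, max_new_tokens=48)

print(tokenizer.decode(creative[0], skip_special_tokens=True))
print(tokenizer.decode(factual[0], skip_special_tokens=True))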