File size: 10,001 Bytes
93a8576
03844f7
 
88d43b3
 
 
 
 
93a8576
03844f7
 
 
88d43b3
03844f7
 
 
 
 
 
 
 
 
 
93a8576
88d43b3
03844f7
 
 
 
 
 
 
d9f5161
 
03844f7
 
 
 
88d43b3
03844f7
88d43b3
84ce1c5
88d43b3
e6473c1
88d43b3
 
 
e6473c1
03844f7
 
5a8ffa4
90f9b30
 
d9f5161
90f9b30
 
 
28bb13c
bfee484
28bb13c
bfee484
28bb13c
bfee484
28bb13c
bfee484
28bb13c
bfee484
383f0d7
5a8ffa4
f59b89f
30d343c
66e8a8e
30d343c
 
6244f43
 
 
 
 
cf403e7
88d43b3
30d343c
cf403e7
88d43b3
 
 
 
 
 
 
a591609
87ac548
28bb13c
bfee484
28bb13c
bfee484
d9f5161
 
bfee484
28bb13c
bfee484
d9f5161
bfee484
 
28bb13c
bfee484
d9f5161
a312531
d9f5161
88d43b3
d9f5161
 
 
03844f7
 
 
 
 
 
88d43b3
 
6244f43
88d43b3
 
 
 
 
 
 
03844f7
88d43b3
03844f7
 
 
88d43b3
 
15a9b76
 
08f9035
15a9b76
30d343c
88d43b3
5a8ffa4
a4fa1bd
88d43b3
a4fa1bd
88d43b3
 
6244f43
 
 
 
88d43b3
 
cf403e7
30d343c
88d43b3
a4fa1bd
88d43b3
d9f5161
88d43b3
d9f5161
cf403e7
 
30d343c
88d43b3
30d343c
 
88d43b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03844f7
 
 
 
 
 
cdf3b28
 
 
 
cf403e7
cdf3b28
d9ee818
 
cdf3b28
 
 
d9ee818
2b2223b
30d343c
a4fa1bd
30d343c
 
 
 
cdf3b28
88d43b3
30d343c
5265793
30d343c
d9ee818
 
a4fa1bd
03844f7
5265793
d9ee818
15a9b76
03844f7
5a8ffa4
15a9b76
 
 
383f0d7
 
b258678
 
383f0d7
 
 
 
 
cdf3b28
03844f7
5265793
cdf3b28
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import gradio as gr
import huggingface_hub as hfh
from requests.exceptions import HTTPError
from functools import lru_cache

from general_suggestions import GENERAL_SUGGESTIONS
from model_suggestions import MODEL_SUGGESTIONS
from task_suggestions import TASK_SUGGESTIONS

# =====================================================================================================================
# DATA
# =====================================================================================================================
# Dict with the tasks considered in this Space, {task: task tag}. Keys are the labels shown in the task dropdown.
TASK_TYPES = {
    "✍️ Text Generation": "txtgen",
    "🤏 Summarization": "summ",
    "🫂 Translation": "trans",
    "💬 Conversational / Chatbot": "chat",
    "🤷 Text Question Answering": "txtqa",
    "🕵️ (Table/Document/Visual) Question Answering": "otherqa",
    "🎤 Automatic Speech Recognition": "asr",
    "🌇 Image to Text": "img2txt",
}

# Dict matching all task types with their possible hub tags, {task tag: (possible, hub, tags)}. A model is considered
# suitable for a task when it carries at least one of the listed tags.
HUB_TAGS = {
    "txtgen": ("text-generation", "text2text-generation"),
    "summ": ("summarization", "text-generation", "text2text-generation"),
    "trans": ("translation", "text-generation", "text2text-generation"),
    "chat": ("conversational", "text-generation", "text2text-generation"),
    "txtqa": ("text-generation", "text2text-generation"),
    "otherqa": ("table-question-answering", "document-question-answering", "visual-question-answering"),
    "asr": ("automatic-speech-recognition",),
    "img2txt": ("image-to-text",),
}
# Both dicts must describe the same set of tasks: every task tag has exactly one HUB_TAGS entry, with no duplicate
# task tags and no leftover HUB_TAGS entries.
assert len(TASK_TYPES) == len(HUB_TAGS)
assert set(TASK_TYPES.values()) == set(HUB_TAGS)

# Problems the user can ask about, as (label shown in the UI, internal problem tag) pairs. The pair order is the order
# of the choices in the problem dropdown; get_suggestions maps the selected label back to its tag.
_PROBLEM_CHOICES = (
    ("🤔 Baseline. I'm getting gibberish and I want a baseline", "baseline"),
    ("😵 Crashes. I want to prevent my model from crashing again", "crashes"),
    ("🤥 Hallucinations. I would like to reduce them", "hallucinations"),
    ("📏 Length. I want to control the length of the output", "length"),
    ("🌈 Prompting. I want better outputs without changing my generation options", "prompting"),
    ("😵‍💫 Repetitions. Make them stop make them stop", "repetitions"),
    ("📈 Quality. I want better outputs without changing my prompt", "quality"),
    ("🏎 Speed! Make it faster!", "speed"),
)
# {problem label: problem tag}; dicts keep insertion order, so the dropdown order above is preserved.
PROBLEMS = dict(_PROBLEM_CHOICES)

INIT_MARKDOWN = """
 

👈 Fill in as much information as you can...

 

 

 

 

 

 

👈 ... then click here!
"""

DEMO_MARKDOWN = """
⛔️ This is still a demo 🤗 Working sections include "Length" and "Quality" ⛔️
"""

# Shown when a model name was given but no Hub tags came back for it; `{model_name}` is filled in via str.format.
MODEL_PROBLEM = (
    "\n"
    "😱 Could not retrieve model tags for the specified model, `{model_name}`. Ensure that the model name matches a Hub\n"
    "model repo, that it is a public model, and that it has Hub tags.\n"
)

# Title placed above the list of suggestions (the constant's spelling is kept for existing references).
SUGGETIONS_HEADER = "\n#### ✨ Here is a list of suggestions for you -- click to expand ✨\n"

# Markers telling the user how well each suggestion matches the inputs they provided.
POSSIBLE_MATCH_EMOJI = "❓"
PERFECT_MATCH_EMOJI = "✅"
# Hint prepended to the suggestions when an empty input let unfiltered suggestions through.
MISSING_INPUTS = (
    f"\n💡 You can filter suggestions with {POSSIBLE_MATCH_EMOJI} if you add more inputs. "
    f"Suggestions with {PERFECT_MATCH_EMOJI} are a perfect match.\n"
)

# The space below is reserved for suggestions that require advanced logic and/or formatting.
#
# TASK_MODEL_MISMATCH is a str.format template with the placeholders `count` (suggestion number), `model_name` (Hub
# repo id, also used to build the repo link), `task_type` (the task named in the message) and `tags` (the expected
# Hub tags, already rendered as inline code). Any literal brace added here must be doubled ({{ }}) to survive .format.
# NOTE(review): the "model card" link points at the repo's file tree (/tree/main) rather than the card itself --
# confirm this is the intended destination.
TASK_MODEL_MISMATCH = """
<details><summary>{count}. Select a model better suited for your task.</summary>
&nbsp;

🤔 Why? &nbsp;

The selected model (`{model_name}`) doesn't have a tag compatible with the task you selected ("{task_type}").
Expected tags for this task are: {tags} &nbsp;

🤗 How? &nbsp;

Our recommendation is to go to our [tasks page](https://huggingface.co/tasks) and select one of the suggested
models as a starting point. &nbsp;

😱 Caveats &nbsp;

1. The tags of a model are defined by the community and are not always accurate. If you think the model is incorrectly
tagged or missing a tag, please open an issue on the [model card](https://huggingface.co/{model_name}/tree/main).

_________________

</details>
"""
# =====================================================================================================================


# =====================================================================================================================
# SUGGESTIONS LOGIC
# =====================================================================================================================
def is_valid_task_for_model(model_tags, user_task):
    """Check whether a model's Hub tags fit the task the user selected.

    Missing information is never reported as a mismatch: with no model tags, or with an empty task tag, the answer
    is True. Otherwise the model fits when at least one of the Hub tags listed for `user_task` in HUB_TAGS appears
    in `model_tags`.
    """
    nothing_to_compare = len(model_tags) == 0 or user_task == ""
    if nothing_to_compare:
        return True  # No model / no task tag = no problem :)

    for accepted_tag in HUB_TAGS[user_task]:
        if accepted_tag in model_tags:
            return True
    return False


@lru_cache(maxsize=1024)
def get_model_tags(model_name):
    """Return the Hub tags of the model repo `model_name`.

    Args:
        model_name: Hub repo id, e.g. "google/flan-t5-xl". Surrounding whitespace is ignored; "" (or None) means
            that no model was given.

    Returns:
        The model's tags as a list, or [] when no model name was given, the name is not a valid repo id, the Hub
        request failed with an HTTP error, or the repo reports no tags.

    Results, including the empty results of failed lookups, are memoized (bounded, since the keys are free user
    text). The same list object is therefore returned to every caller asking for the same name: do not mutate it.
    """
    model_name = (model_name or "").strip()
    if model_name == "":
        return []
    try:
        model_info = hfh.HfApi().model_info(model_name)
    except (HTTPError, ValueError):
        # HTTPError: the Hub rejected the request (e.g. unknown or private repo). ValueError: huggingface_hub's
        # HFValidationError, raised client-side for strings that are not valid repo ids (e.g. containing spaces);
        # without this the UI would crash on a typo instead of showing the "could not retrieve tags" message.
        return []
    # Always hand back a list, even if the Hub reports no tags at all
    return list(model_info.tags or [])


@lru_cache(maxsize=1024)
def get_suggestions(task_type, model_name, problem_type):
    """Build the markdown shown in the suggestions panel for the user's inputs.

    Args:
        task_type: Task label selected in the UI (a key of TASK_TYPES), or "" when no task was selected.
        model_name: Hub repo id typed by the user, or "" when no model was given. Surrounding whitespace is ignored.
        problem_type: Problem label selected in the UI (a key of PROBLEMS), or "" when no problem was selected.

    Returns:
        Markdown text, which is one of:
        - INIT_MARKDOWN, when no input was given;
        - MODEL_PROBLEM, when a model name was given but no Hub tags could be retrieved for it;
        - a single TASK_MODEL_MISMATCH suggestion, when the model's tags don't fit the selected task;
        - DEMO_MARKDOWN, when the selected problem has no working section yet;
        - otherwise SUGGETIONS_HEADER followed by the numbered model-, task- and problem-specific suggestions,
          with the MISSING_INPUTS hint in between when an empty input let unfiltered suggestions through.

    Results are memoized per (task_type, model_name, problem_type); the cache is bounded because model_name is
    free user text.
    """
    # A cleared Gradio component may arrive as None; treat it like an empty input. Stray spaces around a pasted
    # repo id would otherwise make the Hub lookup fail.
    task_type = task_type or ""
    problem_type = problem_type or ""
    model_name = (model_name or "").strip()

    # Nothing to work with yet: keep showing the placeholder
    if task_type == "" and model_name == "" and problem_type == "":
        return INIT_MARKDOWN

    # If there is a model name but no model tags, something went wrong
    model_tags = get_model_tags(model_name)
    if model_name != "" and len(model_tags) == 0:
        return MODEL_PROBLEM.format(model_name=model_name)

    user_problem = PROBLEMS.get(problem_type, "")
    user_task = TASK_TYPES.get(task_type, "")

    # Check if the model is valid for the task. If not, return that single suggestion straight away
    if not is_valid_task_for_model(model_tags, user_task):
        possible_tags = " ".join(f"`{tag}`" for tag in HUB_TAGS[user_task])
        # The message says "the task you selected", so show the label the user picked, not the internal tag
        return TASK_MODEL_MISMATCH.format(count=1, model_name=model_name, task_type=task_type, tags=possible_tags)

    # Demo shortcut: only a few sections are working
    if user_problem not in ("", "length", "quality"):
        return DEMO_MARKDOWN

    rendered = []  # formatted suggestions, numbered 1..N in the order they are added

    def add(template, match_emoji):
        # Number each suggestion after the ones already collected
        rendered.append(template.format(count=len(rendered) + 1, match_emoji=match_emoji))

    def problem_matches(problem_tags):
        # With no problem selected, every suggestion is a candidate
        return user_problem == "" or user_problem in problem_tags

    # First: model-specific suggestions (a perfect match only when both the problem and the model are known)
    has_model_specific_suggestions = False
    match_emoji = POSSIBLE_MATCH_EMOJI if (user_problem == "" or len(model_tags) == 0) else PERFECT_MATCH_EMOJI
    for model_tag, problem_tags, suggestion in MODEL_SUGGESTIONS:
        if problem_matches(problem_tags) and (len(model_tags) == 0 or model_tag in model_tags):
            add(suggestion, match_emoji)
            has_model_specific_suggestions = True

    # Second: task-specific suggestions (a perfect match only when both the problem and the task are known)
    has_task_specific_suggestions = False
    match_emoji = POSSIBLE_MATCH_EMOJI if (user_problem == "" or user_task == "") else PERFECT_MATCH_EMOJI
    for task_tags, problem_tags, suggestion in TASK_SUGGESTIONS:
        if problem_matches(problem_tags) and (user_task == "" or user_task in task_tags):
            add(suggestion, match_emoji)
            has_task_specific_suggestions = True

    # Finally: general suggestions for the problem
    has_problem_specific_suggestions = False
    match_emoji = POSSIBLE_MATCH_EMOJI if user_problem == "" else PERFECT_MATCH_EMOJI
    for problem_tags, suggestion in GENERAL_SUGGESTIONS:
        if problem_matches(problem_tags):
            add(suggestion, match_emoji)
            has_problem_specific_suggestions = True

    # Hint that more inputs would narrow things down when an empty input left some category unfiltered
    show_missing_inputs = (
        (task_type == "" and has_task_specific_suggestions)
        or (model_name == "" and has_model_specific_suggestions)
        or (problem_type == "" and has_problem_specific_suggestions)
    )
    return SUGGETIONS_HEADER + (MISSING_INPUTS if show_missing_inputs else "") + "".join(rendered)
# =====================================================================================================================


# =====================================================================================================================
# GRADIO
# =====================================================================================================================
# Build the UI: the inputs and the submit button on the left, the suggestions markdown on the right.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🚀💬 Improving Generated Text 💬🚀

        This is an ever-evolving guide on how to improve your text generation results. It is community-led and
        curated by Hugging Face 🤗
        """
    )

    with gr.Row():
        with gr.Column():
            # "" is offered as a choice (and used as the default) to mean "not specified"
            problem_type = gr.Dropdown(
                label="What would you like to improve?",
                choices=[""] + list(PROBLEMS.keys()),
                interactive=True,
                value="",
            )
            task_type = gr.Dropdown(
                label="Which task are you working on?",
                choices=[""] + list(TASK_TYPES.keys()),
                interactive=True,
                value="",
            )
            model_name = gr.Textbox(
                label="Which model are you using?",
                placeholder="e.g. google/flan-t5-xl",
                interactive=True,
            )
            button = gr.Button(value="Get Suggestions!")
        with gr.Column(scale=2):
            # Starts with the placeholder and is replaced by get_suggestions' output on each click
            suggestions = gr.Markdown(value=INIT_MARKDOWN)

    # Inputs are passed positionally, so their order must match get_suggestions(task_type, model_name, problem_type)
    button.click(get_suggestions, inputs=[task_type, model_name, problem_type], outputs=suggestions)

    gr.Markdown(
        """
        &nbsp;

        Is your problem not on the list? Need more suggestions? Have you spotted an error? Please open a
        [new discussion](https://huggingface.co/spaces/joaogante/generate_quality_improvement/discussions) 🙏
        """
    )


# =====================================================================================================================

if __name__ == "__main__":
    demo.launch()