# Hugging Face Space: community-curated guide for improving text generation results
import gradio as gr | |
import huggingface_hub as hfh | |
from requests.exceptions import HTTPError | |
from functools import lru_cache | |
from general_suggestions import GENERAL_SUGGESTIONS | |
from model_suggestions import MODEL_SUGGESTIONS | |
from task_suggestions import TASK_SUGGESTIONS | |
# =====================================================================================================================
# DATA
# =====================================================================================================================
# Dict with the tasks considered in this Space, {task display name: task tag}
TASK_TYPES = {
    "✍️ Text Generation": "txtgen",
    "🤏 Summarization": "summ",
    "🫂 Translation": "trans",
    "💬 Conversational / Chatbot": "chat",
    "🤷 Text Question Answering": "txtqa",
    "🕵️ (Table/Document/Visual) Question Answering": "otherqa",
    "🎤 Automatic Speech Recognition": "asr",
    "🌇 Image to Text": "img2txt",
}

# Dict matching all task types with their possible hub tags, {task tag: (possible, hub, tags)}
HUB_TAGS = {
    "txtgen": ("text-generation", "text2text-generation"),
    "summ": ("summarization", "text-generation", "text2text-generation"),
    "trans": ("translation", "text-generation", "text2text-generation"),
    "chat": ("conversational", "text-generation", "text2text-generation"),
    "txtqa": ("text-generation", "text2text-generation"),
    "otherqa": ("table-question-answering", "document-question-answering", "visual-question-answering"),
    "asr": ("automatic-speech-recognition",),
    "img2txt": ("image-to-text",),
}
# Sanity checks: the two dicts must stay in sync -- every task tag needs an entry in HUB_TAGS.
# (Bug fix: the first assert previously compared TASK_TYPES against itself, which always passed.)
assert len(TASK_TYPES) == len(HUB_TAGS)
assert all(tag in HUB_TAGS for tag in TASK_TYPES.values())

# Dict with the problems considered in this Space, {problem display name: problem tag}
PROBLEMS = {
    "🤔 Baseline. I'm getting gibberish and I want a baseline": "baseline",
    "😵 Crashes. I want to prevent my model from crashing again": "crashes",
    "🤥 Hallucinations. I would like to reduce them": "hallucinations",
    "📏 Length. I want to control the length of the output": "length",
    "🌈 Prompting. I want better outputs without changing my generation options": "prompting",
    "😵💫 Repetitions. Make them stop make them stop": "repetitions",
    "📈 Quality. I want better outputs without changing my prompt": "quality",
    "🏎 Speed! Make it faster!": "speed",
}

# Markdown shown before the user has filled in anything.
INIT_MARKDOWN = """

👈 Fill in as much information as you can...






👈 ... then click here!
"""

# Markdown shown for problem sections that are not implemented yet.
DEMO_MARKDOWN = """
⛔️ This is still a demo 🤗 Working sections include "Length" and "Quality" ⛔️
"""

# Markdown shown when the model's Hub tags could not be fetched.
MODEL_PROBLEM = """
😱 Could not retrieve model tags for the specified model, `{model_name}`. Ensure that the model name matches a Hub
model repo, that it is a public model, and that it has Hub tags.
"""

# NOTE: "SUGGETIONS" is a typo, but the name is kept as-is for compatibility with its users in this file.
SUGGETIONS_HEADER = """
#### ✨ Here is a list of suggestions for you -- click to expand ✨
"""
PERFECT_MATCH_EMOJI = "✅"
POSSIBLE_MATCH_EMOJI = "❓"
MISSING_INPUTS = """
💡 You can filter suggestions with {} if you add more inputs. Suggestions with {} are a perfect match.
""".format(POSSIBLE_MATCH_EMOJI, PERFECT_MATCH_EMOJI)

# The space below is reserved for suggestions that require advanced logic and/or formatting
TASK_MODEL_MISMATCH = """
<details><summary>{count}. Select a model better suited for your task.</summary>

🤔 Why?
The selected model (`{model_name}`) doesn't have a tag compatible with the task you selected ("{task_type}").
Expected tags for this task are: {tags}
🤗 How?
Our recommendation is to go to our [tasks page](https://huggingface.co/tasks) and select one of the suggested
models as a starting point.
😱 Caveats
1. The tags of a model are defined by the community and are not always accurate. If you think the model is incorrectly
tagged or missing a tag, please open an issue on the [model card](https://huggingface.co/{model_name}/tree/main).
_________________
</details>
"""
# ===================================================================================================================== | |
# ===================================================================================================================== | |
# SUGGESTIONS LOGIC | |
# ===================================================================================================================== | |
def is_valid_task_for_model(model_tags, user_task):
    """Return True when the model's Hub tags are compatible with the selected task.

    A missing model (no tags) or an unselected task is always considered valid,
    since there is nothing to contradict.
    """
    # No model / no task tag = no problem :)
    if not model_tags or not user_task:
        return True
    # Valid when at least one of the task's acceptable Hub tags appears on the model.
    acceptable = HUB_TAGS[user_task]
    for tag in acceptable:
        if tag in model_tags:
            return True
    return False
def get_model_tags(model_name):
    """Fetch the Hub tags for `model_name`, returning [] when they can't be retrieved.

    Args:
        model_name: a Hub model id such as "google/flan-t5-xl"; "" means no model selected.

    Returns:
        The list of Hub tags, or [] for an empty name or any lookup failure.
    """
    if model_name == "":
        return []
    try:
        model_tags = hfh.HfApi().model_info(model_name).tags
    # Bug fix: an ill-formed repo id (e.g. containing spaces) raises HFValidationError,
    # a ValueError subclass, which the previous HTTPError-only clause let crash the app.
    except (HTTPError, ValueError):
        model_tags = []
    # `tags` may be None on some repos; normalize to a list so callers can use len().
    return model_tags or []
def get_suggestions(task_type, model_name, problem_type):
    """Build the markdown suggestion list from the user's selections.

    Args:
        task_type: display name from TASK_TYPES, or "" when unset.
        model_name: Hub model id (e.g. "google/flan-t5-xl"), or "" when unset.
        problem_type: display name from PROBLEMS, or "" when unset.

    Returns:
        A markdown string: the initial prompt, an error panel, or a numbered
        list of expandable suggestions.
    """
    # Check if the inputs were given
    if all([task_type == "", model_name == "", problem_type == ""]):
        return INIT_MARKDOWN

    suggestions = ""
    counter = 0
    model_tags = get_model_tags(model_name)

    # If there is a model name but no model tags, something went wrong
    if model_name != "" and len(model_tags) == 0:
        return MODEL_PROBLEM.format(model_name=model_name)

    user_problem = PROBLEMS.get(problem_type, "")
    user_task = TASK_TYPES.get(task_type, "")

    # Check if the model is valid for the task. If not, return straight away
    if not is_valid_task_for_model(model_tags, user_task):
        counter += 1
        possible_tags = " ".join("`" + tag + "`" for tag in HUB_TAGS[user_task])
        # Bug fix: show the human-readable task name (task_type) in the message,
        # not the internal tag (user_task, e.g. "summ") that was passed before.
        suggestions += TASK_MODEL_MISMATCH.format(
            count=counter, model_name=model_name, task_type=task_type, tags=possible_tags
        )
        return suggestions

    # Demo shortcut: only a few sections are working
    if user_problem not in ("", "length", "quality"):
        return DEMO_MARKDOWN

    # First: model-specific suggestions. A suggestion is a perfect match only when
    # both the problem and the model are known; otherwise it is merely possible.
    has_model_specific_suggestions = False
    match_emoji = POSSIBLE_MATCH_EMOJI if (user_problem == "" or len(model_tags) == 0) else PERFECT_MATCH_EMOJI
    for model_tag, problem_tags, suggestion in MODEL_SUGGESTIONS:
        if user_problem == "" or user_problem in problem_tags:
            if len(model_tags) == 0 or model_tag in model_tags:
                counter += 1
                suggestions += suggestion.format(count=counter, match_emoji=match_emoji)
                has_model_specific_suggestions = True

    # Second: task-specific suggestions
    has_task_specific_suggestions = False
    match_emoji = POSSIBLE_MATCH_EMOJI if (user_problem == "" or user_task == "") else PERFECT_MATCH_EMOJI
    for task_tags, problem_tags, suggestion in TASK_SUGGESTIONS:
        if user_problem == "" or user_problem in problem_tags:
            if user_task == "" or user_task in task_tags:
                counter += 1
                suggestions += suggestion.format(count=counter, match_emoji=match_emoji)
                has_task_specific_suggestions = True

    # Finally: general suggestions for the problem
    has_problem_specific_suggestions = False
    match_emoji = POSSIBLE_MATCH_EMOJI if user_problem == "" else PERFECT_MATCH_EMOJI
    for problem_tags, suggestion in GENERAL_SUGGESTIONS:
        if user_problem == "" or user_problem in problem_tags:
            counter += 1
            suggestions += suggestion.format(count=counter, match_emoji=match_emoji)
            has_problem_specific_suggestions = True

    # Prepend a hint when a missing input caused loosely-matched (❓) suggestions to appear.
    if (
        (task_type == "" and has_task_specific_suggestions)
        or (model_name == "" and has_model_specific_suggestions)
        or (problem_type == "" and has_problem_specific_suggestions)
    ):
        suggestions = MISSING_INPUTS + suggestions
    return SUGGETIONS_HEADER + suggestions
# ===================================================================================================================== | |
# ===================================================================================================================== | |
# GRADIO | |
# ===================================================================================================================== | |
# Build the Gradio UI: three inputs on the left, the suggestion panel on the right.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🚀💬 Improving Generated Text 💬🚀

        This is a ever-evolving guide on how to improve your text generation results. It is community-led and
        curated by Hugging Face 🤗
        """
    )

    with gr.Row():
        with gr.Column():
            # All three inputs default to "" so the user can leave any of them unset.
            problem_type = gr.Dropdown(
                label="What would you like to improve?",
                choices=[""] + list(PROBLEMS.keys()),
                interactive=True,
                value="",
            )
            task_type = gr.Dropdown(
                label="Which task are you working on?",
                choices=[""] + list(TASK_TYPES.keys()),
                interactive=True,
                value="",
            )
            model_name = gr.Textbox(
                label="Which model are you using?",
                placeholder="e.g. google/flan-t5-xl",
                interactive=True,
            )
            button = gr.Button(value="Get Suggestions!")
        with gr.Column(scale=2):
            # Output panel; starts with the "fill in the form" instructions.
            suggestions = gr.Markdown(value=INIT_MARKDOWN)

    # Clicking the button regenerates the suggestion list from the current inputs.
    button.click(get_suggestions, inputs=[task_type, model_name, problem_type], outputs=suggestions)

    gr.Markdown(
        """

        Is your problem not on the list? Need more suggestions? Have you spotted an error? Please open a
        [new discussion](https://huggingface.co/spaces/joaogante/generate_quality_improvement/discussions) 🙏
        """
    )
# ===================================================================================================================== | |
# Launch the Gradio server only when this file is run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()