# Last commit: c822b09 by sanchit-gandhi - "Better error message when unable to load space (#3)"
import os
import gradio as gr
import soundfile as sf
import torch
from gradio_client import Client
from huggingface_hub import Repository
from pandas import read_csv
from transformers import pipeline
# Load the results file (usernames that have already passed) from the private dataset repo.
USERNAMES_DATASET_ID = "huggingface-course/audio-course-u7-hands-on"
# Token used to clone/push the dataset repo and to query student Spaces; None if the env var is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Build the URL with an explicit "/": os.path.join uses the OS path separator,
# which would produce an invalid URL on Windows.
usernames_url = f"https://huggingface.co/datasets/{USERNAMES_DATASET_ID}"
usernames_repo = Repository(local_dir="usernames", clone_from=usernames_url, use_auth_token=HF_TOKEN)
# Bring the local clone up to date before reading the results file.
usernames_repo.git_pull()
CSV_RESULTS_FILE = os.path.join("usernames", "usernames.csv")
all_results = read_csv(CSV_RESULTS_FILE)
# Load the language-identification (LID) checkpoint, placing it on the first
# CUDA device when one is available and falling back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
pipe = pipeline(
    task="audio-classification",
    model="facebook/mms-lid-126",
    device=device,
)
# Constants for the UI and the pass/fail logic.
# (The emoji below were previously mis-encoded as "πŸ€—" / "βœ…".)
TITLE = "🤗 Audio Transformers Course: Unit 7 Assessment"
DESCRIPTION = """
Check that you have successfully completed the hands-on exercise for Unit 7 of the 🤗 Audio Transformers Course by submitting your demo to this Space.
As a reminder, you should start with the template Space provided at [`course-demos/speech-to-speech-translation`](https://huggingface.co/spaces/course-demos/speech-to-speech-translation),
and update the Space to translate from any language X to a **non-English** language Y. Your demo should take as input an audio file, and return as output another audio file,
matching the signature of the [`speech_to_speech_translation`](https://huggingface.co/spaces/course-demos/speech-to-speech-translation/blob/3946ba6705a6632a63de8672ac52a482ab74b3fc/app.py#L35)
function in the template demo.
To submit your demo for assessment, give the repo id or URL to your demo. For the template demo, this would be `course-demos/speech-to-speech-translation`.
You should ensure that the visibility of your demo is set to **public**. This Space will submit a test file to your demo, and check that the output is
non-English audio. If your demo successfully returns an audio file, and this audio file is classified as being non-English, you will pass the Unit and
get a green tick next to your name on the overall [course progress space](https://huggingface.co/spaces/MariaK/Check-my-progress-Audio-Course) ✅
If you experience any issues with using this checker, [open an issue](https://huggingface.co/spaces/huggingface-course/audio-course-u7-assessment/discussions/new)
on this Space and tag [`@sanchit-gandhi`](https://huggingface.co/sanchit-gandhi).
"""
# Minimum score the top language prediction must reach to count as a confident prediction.
THRESHOLD = 0.5
# The literal token "USER" is replaced with the submitter's username on success.
PASS_MESSAGE = "Congratulations USER! Your demo passed the assessment!"
def verify_demo(repo_id):
    """Assess a student's speech-to-speech translation Space and record a pass.

    The Space is queried with the local ``test_short.wav`` file. The returned
    audio is classified with the language-identification pipeline, and the
    submission passes if the top prediction scores at least ``THRESHOLD`` and
    is not English (``"eng"``). On a pass, the username is appended to the
    results CSV, which is then pushed to the usernames dataset repo.

    Args:
        repo_id: Space repo id (``"user/space"``) or a URL whose last two path
            segments are ``user/space``. Surrounding whitespace, trailing
            slashes and empty path segments are ignored.

    Returns:
        A 4-tuple ``(message, "test_short.wav", (sampling_rate, audio), label_outputs)``
        where ``label_outputs`` maps each predicted language label to its score.

    Raises:
        gr.Error: if the repo id is malformed, the user has already passed,
            the Space cannot be loaded or queried, its output cannot be read as
            audio, or the top prediction is low-confidence or English.
    """
    # Normalise the input so that e.g. "user/space/" or " user/space " still
    # resolve to "user/space" instead of producing a wrong username.
    parts = [part for part in (repo_id or "").strip().split("/") if part]
    if len(parts) < 2:
        raise gr.Error(f"Ensure you pass a valid repo id to the assessor, got `{repo_id}`")
    # For both a bare id and a full URL, the last two segments are "<user>/<space>".
    user_name = parts[-2]
    repo_id = "/".join(parts[-2:])

    if (all_results["username"] == user_name).any():
        raise gr.Error(f"Username {user_name} has already passed the assessment!")

    try:
        client = Client(repo_id, hf_token=HF_TOKEN)
    except Exception as e:
        # Note the trailing spaces: adjacent string literals are concatenated verbatim.
        raise gr.Error(
            "Error with loading Space. First check that your Space has been built and is running. "
            "Then check that your Space takes an audio file as input and returns an audio as output. If it is working "
            f"as expected and the error persists, open an issue on this Space. Error: {e}"
        ) from e

    try:
        audio_file = client.predict("test_short.wav", api_name="/predict")
    except Exception as e:
        raise gr.Error(
            f"Error with querying Space, check that your Space takes an audio file as input and returns an audio as output: {e}"
        ) from e

    try:
        audio, sampling_rate = sf.read(audio_file)
    except Exception as e:
        raise gr.Error(
            f"Error with reading the output of your Space, check that your Space returns an audio file as output: {e}"
        ) from e

    # sf.read returns a (frames, channels) array for multi-channel files, while the
    # audio-classification pipeline expects a single-channel 1-D array, so
    # down-mix for classification only (the original audio is still returned for playback).
    lid_audio = audio.mean(axis=1) if audio.ndim > 1 else audio
    language_prediction = pipe({"array": lid_audio, "sampling_rate": sampling_rate})
    label_outputs = {pred["label"]: pred["score"] for pred in language_prediction}

    top_prediction = language_prediction[0]
    if top_prediction["score"] < THRESHOLD:
        raise gr.Error(
            f"Model made random predictions - predicted {top_prediction['label']} with probability {top_prediction['score']}"
        )
    if top_prediction["label"] == "eng":
        raise gr.Error(
            "Model generated an English audio - ensure the model is set to generate audio in a non-English language, e.g. Dutch"
        )

    # Record the pass: append the username, persist the CSV locally and push it to the Hub.
    all_results.loc[len(all_results)] = {"username": user_name}
    all_results.to_csv(CSV_RESULTS_FILE, index=False)
    usernames_repo.push_to_hub()

    message = PASS_MESSAGE.replace("USER", user_name)
    return message, "test_short.wav", (sampling_rate, audio), label_outputs
# Assemble the UI: a single text input for the Space repo id/URL, and four
# output components, one for each element of the tuple returned by `verify_demo`.
demo = gr.Interface(
    verify_demo,
    gr.Textbox(label="Repo id or URL of your demo", placeholder="course-demos/speech-to-speech-translation"),
    [
        gr.Textbox(label="Status"),
        gr.Audio(type="filepath", label="Source Speech"),
        gr.Audio(type="numpy", label="Generated Speech"),
        gr.Label(label="Language prediction"),
    ],
    description=DESCRIPTION,
    title=TITLE,
)
demo.launch()