Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import os # utility library | |
| # libraries to load the model and serve inference | |
| import tensorflow_text | |
| import tensorflow as tf | |
# Directory where huggingface_hub stores the downloaded repository files.
CACHE_DIR = "hfhub_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Download the model repository (or reuse the cached copy) and load it with
# TensorFlow. snapshot_download returns the local folder that holds the files
# of the requested revision, so use that path directly. Guessing it from
# os.listdir(CACHE_DIR) is unreliable: listdir order is arbitrary. Recent
# huggingface_hub versions also nest files under
# "models--<org>--<name>/snapshots/<hash>" next to bookkeeping folders
# (e.g. ".locks"), so the first entry is generally not the SavedModel dir.
saved_model_path = snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
model = tf.saved_model.load(saved_model_path, ["serve"])
# Heading and blurb displayed at the top of the Gradio page.
title = "Interactive demo: T5 Multitasking Demo"
description = (
    "Demo for T5's different tasks including machine translation, "
    "text summarization, document similarity, and grammatical correctness of sentences."
)
def predict_fn(x):
    """Run the loaded SavedModel's default serving signature on live inputs.

    params:
        x   input text to run inference on (``predict`` passes a one-element
            list of strings); converted to a tensor with ``tf.constant``
    returns:
        the ``'outputs'`` entry of the signature's result as a numpy array
        (presumably encoded byte strings, one per input -- confirm against
        the exported model)
    """
    serving_signature = model.signatures['serving_default']
    signature_result = serving_signature(tf.constant(x))
    return signature_result['outputs'].numpy()
def predict(task_type, sentence):
    """Turn the user's inputs into a T5 prompt, run it, and decode the answer.

    params:
        task_type   label of the selected task; must be one of the keys of
                    ``task_dict`` below (the Radio choices shown in the UI)
        sentence    text to run inference on. For the document-similarity
                    task it must contain two sentences separated by ``---``;
                    everything after the first ``---`` is the second sentence.
    returns:
        the model's output for the prompt, decoded from UTF-8.
    raises:
        KeyError    if ``task_type`` is not a known task label.
        ValueError  if the document-similarity input has no ``---`` separator.
    """
    # Maps each UI label to the task prefix placed in front of the input.
    # NOTE(review): T5's published translation prefixes are lowercase
    # ("translate English to French"); confirm this model handles the
    # capitalized form used here.
    task_dict = {
        "Translate English to French": "Translate English to French",
        "Translate English to German": "Translate English to German",
        "Translate English to Romanian": "Translate English to Romanian",
        "Grammatical Correctness of Sentence": "cola sentence",
        "Text Summarization": "summarize",
        "Document Similarity Score (separate the 2 sentences with 3 dashes `---`)": "stsb",
    }
    prefix = task_dict[task_type]

    if task_type.startswith("Document Similarity"):
        # Document similarity scores a sentence pair, so split on the first
        # separator. Trim surrounding whitespace so "a --- b" and "a---b"
        # produce the same prompt.
        first, separator, second = sentence.partition('---')
        if not separator:
            raise ValueError(
                "Document similarity needs two sentences separated by '---'."
            )
        question = f"{prefix} sentence1: {first.strip()} sentence2: {second.strip()}"
    else:
        # Every other task uses the "<prefix>: <text>" input format.
        question = f"{prefix}: {sentence}"

    return predict_fn([question])[0].decode('utf-8')
# Gradio front end: a task selector plus a free-text box, wired to predict().
# NOTE(review): gr.inputs.*, theme="huggingface" and enable_queue=True are
# gradio 2.x-era APIs. Later gradio releases deprecated or removed them in
# favour of gr.Radio / gr.Textbox and iface.queue(). Confirm which gradio
# version this Space pins.
_TASK_CHOICES = (
    "Translate English to French",
    "Translate English to German",
    "Translate English to Romanian",
    "Grammatical Correctness of Sentence",
    "Text Summarization",
    "Document Similarity Score (separate the 2 sentences with 3 dashes `---`)",
)
_SAMPLE_SENTENCE = "I am Steven and I live in Lagos, Nigeria"

task_input = gr.inputs.Radio(choices=list(_TASK_CHOICES), label="Task Type")
sentence_input = gr.inputs.Textbox(label="Sentence")

# The three translation tasks and the grammar check share one sample sentence;
# summarization and similarity get task-specific samples.
demo_examples = [[task, _SAMPLE_SENTENCE] for task in _TASK_CHOICES[:4]]
demo_examples.append([
    "Text Summarization",
    "I don't care about those doing the comparison, but comparing the "
    "Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians",
])
demo_examples.append([
    "Document Similarity Score (separate the 2 sentences with 3 dashes `---`)",
    "I reside in the commercial capital city of Nigeria, which is Lagos---I live in Lagos",
])

iface = gr.Interface(
    fn=predict,
    inputs=[task_input, sentence_input],
    outputs="text",
    title=title,
    description=description,
    examples=demo_examples,
    theme="huggingface",
    enable_queue=True,
)
iface.launch()