# NOTE: page-capture residue, apparently a hosting status banner ("Spaces: Sleeping"); not part of the program.
import gradio as gr | |
from transformers import pipeline | |
# Pipeline loaders: one factory function per NLP task.
def load_qa_model():
    """Build and return the question-answering pipeline.

    The checkpoint is the BERT-large whole-word-masking model whose name
    marks it as fine-tuned on SQuAD. A new pipeline object is constructed
    on every call.
    """
    checkpoint = "bert-large-uncased-whole-word-masking-finetuned-squad"
    return pipeline("question-answering", model=checkpoint)
def load_classifier_model():
    """Build and return the zero-shot classification pipeline.

    A new pipeline object is constructed on every call.
    """
    checkpoint = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33"
    return pipeline("zero-shot-classification", model=checkpoint)
def load_translator_model(target_language):
    """Build an English-to-``target_language`` translation pipeline.

    Args:
        target_language: Language code appended to
            ``Helsinki-NLP/opus-mt-en-`` to form the checkpoint name.

    Returns:
        The translation pipeline, or ``None`` when construction raised
        (the error is printed in that case).
    """
    try:
        checkpoint = "Helsinki-NLP/opus-mt-en-{}".format(target_language)
        translator = pipeline("translation", model=checkpoint)
    except Exception as err:
        print(f"Error loading translation model: {err}")
        return None
    return translator
def load_generator_model():
    """Build and return the text-generation pipeline.

    Model and tokenizer are both taken from the same GPT-Neo checkpoint.

    Returns:
        The generation pipeline, or ``None`` when construction raised
        (the error is printed in that case).
    """
    checkpoint = "EleutherAI/gpt-neo-2.7B"
    try:
        generator = pipeline("text-generation", model=checkpoint, tokenizer=checkpoint)
    except Exception as err:
        print(f"Error loading text generation model: {err}")
        return None
    return generator
def load_summarizer_model():
    """Build and return the summarization pipeline.

    Returns:
        The summarization pipeline, or ``None`` when construction raised
        (the error is printed in that case).
    """
    checkpoint = "facebook/bart-large-cnn"
    try:
        summarizer = pipeline("summarization", model=checkpoint)
    except Exception as err:
        print(f"Error loading summarization model: {err}")
        return None
    return summarizer
# Task processors: each loads its model, runs it on the input, and returns a string for display.
def process_qa(context, question):
    """Answer ``question`` using ``context`` as the source passage.

    Args:
        context: Text the answer is extracted from.
        question: Question to answer.

    Returns:
        The ``"answer"`` field of the pipeline result, or the string
        ``"Error during question answering"`` if loading the model or
        running it raised.
    """
    try:
        # The loader does no error handling of its own, so it is called
        # inside the try: a failed load is reported as an error string
        # instead of escaping into the UI callback.
        qa_model = load_qa_model()
        return qa_model(context=context, question=question)["answer"]
    except Exception as e:
        print(f"Error during question answering: {e}")
        return "Error during question answering"
def process_classifier(text, labels):
    """Return the top-ranked candidate label for ``text``.

    Args:
        text: Text to classify.
        labels: Candidate labels passed to the zero-shot classifier.

    Returns:
        The first entry of the result's ``"labels"`` list, or the string
        ``"Error during classification"`` if loading the model or running
        it raised.
    """
    try:
        # The loader does no error handling of its own, so it is called
        # inside the try: a failed load is reported as an error string
        # instead of escaping into the UI callback.
        classifier_model = load_classifier_model()
        return classifier_model(text, labels)["labels"][0]
    except Exception as e:
        print(f"Error during classification: {e}")
        return "Error during classification"
def process_translation(text, target_language):
    """Translate ``text`` into ``target_language``.

    Args:
        text: Text to translate.
        target_language: Language code handed to the translator loader.

    Returns:
        The ``"translation_text"`` of the first result,
        ``"Model loading error"`` when no translator could be loaded, or
        ``"Translation error"`` when the translation call raised.
    """
    translator = load_translator_model(target_language)
    if not translator:
        return "Model loading error"

    try:
        first_result = translator(text)[0]
        return first_result["translation_text"]
    except Exception as err:
        print(f"Error during translation: {err}")
        return "Translation error"
def process_generation(prompt):
    """Generate a continuation of ``prompt``.

    Args:
        prompt: Seed text for generation.

    Returns:
        The ``"generated_text"`` of the first result, ``"Prompt is empty"``
        for a missing or whitespace-only prompt, ``"Model loading error"``
        when the generator could not be loaded, or
        ``"Text generation error"`` when the generation call raised.
    """
    # Validate first: loading the generator is expensive and pointless when
    # there is nothing to generate from. ``None`` is treated as empty.
    if not prompt or not prompt.strip():
        return "Prompt is empty"

    generator_model = load_generator_model()
    if not generator_model:
        return "Model loading error"

    try:
        # NOTE(review): in transformers, ``max_length`` bounds prompt plus
        # generated tokens; consider ``max_new_tokens`` if long prompts
        # must be supported.
        return generator_model(prompt, max_length=50)[0]["generated_text"]
    except Exception as e:
        print(f"Error during text generation: {e}")
        return "Text generation error"
def process_summarization(text):
    """Summarize ``text``.

    Args:
        text: Text to summarize.

    Returns:
        The ``"summary_text"`` of the first result, ``"Text is empty"`` for
        missing or whitespace-only input, ``"Model loading error"`` when
        the summarizer could not be loaded, or ``"Summarization error"``
        when the summarization call raised.
    """
    # Validate first: there is no reason to load the model for empty input.
    # ``None`` is treated as empty.
    if not text or not text.strip():
        return "Text is empty"

    summarizer_model = load_summarizer_model()
    if not summarizer_model:
        return "Model loading error"

    try:
        result = summarizer_model(text, max_length=150, min_length=40, do_sample=False)
        return result[0]["summary_text"]
    except Exception as e:
        print(f"Error during summarization: {e}")
        return "Summarization error"
# Gradio interface: one tab runs a single selected task, the other runs every task in sequence.
with gr.Blocks() as demo:
    gr.Markdown("Choose an NLP task and input the required text.")

    with gr.Tab("Single-Models"):
        gr.Markdown("This tab is for single models demonstration.")
        task_select_single = gr.Dropdown(
            [
                "Question Answering",
                "Zero-Shot Classification",
                "Translation",
                "Text Generation",
                "Summarization",
            ],
            label="Select Task",
        )
        input_text_single = gr.Textbox(label="Input Text")
        # Task-specific inputs start hidden; update_inputs reveals the one
        # that belongs to the selected task.
        context_input_single = gr.Textbox(label="Context", visible=False)
        label_input_single = gr.CheckboxGroup(
            ["positive", "negative", "neutral"], label="Labels", visible=False
        )
        target_language_input_single = gr.Dropdown(
            ["nl", "fr", "es", "de"], label="Target Language", visible=False
        )
        output_text_single = gr.Textbox(label="Output")
        execute_button_single = gr.Button("Execute")

        def update_inputs(task):
            """Show only the extra input that the selected task needs.

            Args:
                task: Value of the task dropdown.

            Returns:
                A dict mapping each task-specific input component to a
                visibility update. Tasks without an extra input (and
                unknown values) hide all three.
            """
            return {
                context_input_single: gr.update(visible=task == "Question Answering"),
                label_input_single: gr.update(visible=task == "Zero-Shot Classification"),
                target_language_input_single: gr.update(visible=task == "Translation"),
            }

        task_select_single.change(
            fn=update_inputs,
            inputs=task_select_single,
            outputs=[context_input_single, label_input_single, target_language_input_single],
        )

        def execute_task_single(task, input_text, context, labels, target_language):
            """Run the selected task and return its output string.

            Args:
                task: Value of the task dropdown.
                input_text: Main text input (the question for QA).
                context: Passage used by question answering.
                labels: Candidate labels for classification.
                target_language: Language code for translation.

            Returns:
                The task's result or error string, a prompt asking for the
                missing input, or ``"Invalid task selected."``.
            """
            if task == "Question Answering":
                return process_qa(context=context, question=input_text)
            if task == "Zero-Shot Classification":
                if not labels:
                    return "Please provide labels for classification."
                return process_classifier(text=input_text, labels=labels)
            if task == "Translation":
                if not target_language:
                    return "Please select a target language for translation."
                return process_translation(text=input_text, target_language=target_language)
            if task == "Text Generation":
                return process_generation(prompt=input_text)
            if task == "Summarization":
                return process_summarization(text=input_text)
            return "Invalid task selected."

        execute_button_single.click(
            execute_task_single,
            inputs=[
                task_select_single,
                input_text_single,
                context_input_single,
                label_input_single,
                target_language_input_single,
            ],
            outputs=output_text_single,
        )

    with gr.Tab("Multi-Model Task"):
        gr.Markdown("This tab allows you to execute all tasks sequentially.")
        # Inputs for all tasks
        input_text_multi = gr.Textbox(label="Input Text")
        context_input_multi = gr.Textbox(label="Context (for QA)")
        label_input_multi = gr.CheckboxGroup(
            ["positive", "negative", "neutral"], label="Labels (for Classification)"
        )
        target_language_input_multi = gr.Dropdown(
            ["nl", "fr", "es", "de"], label="Target Language (for Translation)"
        )
        # Outputs for all tasks
        output_qa = gr.Textbox(label="QA Output")
        output_classification = gr.Textbox(label="Classification Output")
        output_translation = gr.Textbox(label="Translation Output")
        output_generation = gr.Textbox(label="Text Generation Output")
        output_summarization = gr.Textbox(label="Summarization Output")
        execute_button_multi = gr.Button("Execute All Tasks")

        def execute_all_tasks(input_text, context, labels, target_language):
            """Run every task and return the five output strings.

            QA, classification and translation run on ``input_text``.
            Generation is seeded with the QA output, and summarization
            runs on the generated text.

            Args:
                input_text: Main text input (the question for QA).
                context: Passage used by question answering.
                labels: Candidate labels for classification.
                target_language: Language code for translation.

            Returns:
                A 5-tuple ``(qa, classification, translation, generation,
                summarization)`` in the order of the output components.
            """
            qa_output = process_qa(context=context, question=input_text)

            # Same input validation as the single-task tab, so a missing
            # selection gives an actionable message instead of a generic
            # model error.
            if labels:
                classification_output = process_classifier(text=input_text, labels=labels)
            else:
                classification_output = "Please provide labels for classification."

            if target_language:
                translation_output = process_translation(
                    text=input_text, target_language=target_language
                )
            else:
                translation_output = "Please select a target language for translation."

            # NOTE: the processors report failures as plain strings, so a
            # failed step's error text is passed on as input to the next
            # step of this chain.
            generation_output = process_generation(prompt=qa_output)
            summarization_output = process_summarization(text=generation_output)

            return (
                qa_output,
                classification_output,
                translation_output,
                generation_output,
                summarization_output,
            )

        execute_button_multi.click(
            execute_all_tasks,
            inputs=[
                input_text_multi,
                context_input_multi,
                label_input_multi,
                target_language_input_multi,
            ],
            outputs=[
                output_qa,
                output_classification,
                output_translation,
                output_generation,
                output_summarization,
            ],
        )

demo.launch()