Spaces:
Running
Running
import gradio as gr | |
from transformers import pipeline | |
from PIL import Image | |
import torch | |
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |
from transformers import CLIPProcessor, CLIPModel, BartTokenizer, BartForConditionalGeneration, GPT2LMHeadModel, GPT2Tokenizer, T5ForConditionalGeneration, T5Tokenizer | |
from torchvision.models.detection import fasterrcnn_resnet50_fpn | |
from torchvision.transforms import functional as F | |
# Question-answering pipeline loaded from the deepset/roberta-base-squad2
# checkpoint. Loaded once at import time and shared by every request.
roberta_model = pipeline("question-answering", model="deepset/roberta-base-squad2")


def answer_question_roberta(context, question):
    """Answer ``question`` using the text in ``context``.

    Args:
        context: Passage the pipeline searches for an answer.
        question: Question asked about ``context``.

    Returns:
        The ``"answer"`` entry of the pipeline's prediction.
    """
    prediction = roberta_model(context=context, question=question)
    return prediction["answer"]


# Gradio tab for the pipeline: two text inputs (context, question) and one
# text output.
roberta_interface = gr.Interface(
    title="Question Answering with RoBERTa",
    description="Ask a question about the given context.",
    fn=answer_question_roberta,
    inputs=["text", "text"],
    outputs="text",
)
# Stand-in tab for a model that has not been wired up yet: it simply echoes
# the submitted text back to the user.
placeholder_interface1 = gr.Interface(
    title="Model 1",
    description="Placeholder for Model 1.",
    inputs="text",
    outputs="text",
    fn=lambda x: x,  # identity function until a real model is plugged in
)
# Load the CLIP model and its processor.
# They get CLIP-specific names on purpose: other sections of this file also
# bind the generic names ``model`` / ``tokenizer`` at module level, and a
# function that reads a generic global at call time picks up whichever model
# was loaded last instead of the one it was written for.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")


def classify_image(image: Image.Image, labels: str):
    """Pick the candidate label that CLIP scores highest for ``image``.

    Args:
        image: The uploaded picture. It is handed to the CLIP processor
            unchanged, so anything the processor accepts works here (a PIL
            image, or the numpy array Gradio's ``"image"`` input provides).
        labels: Comma-separated candidate labels, e.g. ``"cat, dog, car"``.
            Whitespace around each label is ignored and empty entries are
            dropped.

    Returns:
        The candidate label with the highest image-text similarity.

    Raises:
        ValueError: If no image was supplied or ``labels`` contains no
            usable label.
    """
    if image is None:
        raise ValueError("An image is required.")

    candidates = [part.strip() for part in (labels or "").split(",")]
    candidates = [label for label in candidates if label]
    if not candidates:
        raise ValueError("Enter at least one label (comma-separated).")

    # The processor performs CLIP's own resizing, cropping and normalisation.
    # Pre-processing the image by hand first would normalise it twice and
    # feed the model values outside the range it was trained on.
    inputs = clip_processor(
        text=candidates,
        images=image,
        return_tensors="pt",
        padding=True,
    )

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = clip_model(**inputs)

    # logits_per_image holds one row per image with one score per label.
    best_index = outputs.logits_per_image.argmax(dim=-1).item()
    return candidates[best_index]


# Gradio tab for zero-shot image classification.
patch16_interface = gr.Interface(
    fn=classify_image,  # The function for image classification
    inputs=["image", "text"],  # Input components
    outputs="text",  # Output component
    title="Image Classification with CLIP",  # Title of the interface
    description="Upload an image and enter a list of labels (comma-separated). The model will predict the label that best matches the image.",  # Description of the interface
)
# Repeat for other placeholder interfaces... | |
# Load the BART summarisation model and tokenizer.
# BART-specific names keep ``summarize_text`` bound to this model; the
# generic globals ``model`` / ``tokenizer`` are rebound further down the file
# and would otherwise resolve to a different model at call time.
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")


def summarize_text(input_text: str):
    """Summarise ``input_text`` with BART using beam search.

    Args:
        input_text: Text to summarise. Input longer than 1024 tokens is
            truncated to the first 1024 tokens.

    Returns:
        The generated summary, or an empty string when ``input_text`` is
        empty or only whitespace.
    """
    if not input_text or not input_text.strip():
        return ""

    # ``max_length`` only takes effect together with ``truncation=True``;
    # without it an over-long input is passed to the model untruncated.
    inputs = bart_tokenizer(
        [input_text],
        max_length=1024,
        truncation=True,
        return_tensors="pt",
    )

    summary_ids = bart_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        num_beams=4,
        max_length=150,
        early_stopping=True,
    )

    # Only one text was submitted, so only the first sequence is relevant.
    return bart_tokenizer.decode(
        summary_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )


# Create the Gradio interface
bart_large_cnn_interface = gr.Interface(
    fn=summarize_text,
    inputs="text",
    outputs="text",
    title="Text Summarization with BART",
    description="Enter a long piece of text. The model will generate a summary.",
)
# --- Object detection (disabled; kept commented out for reference) ---
# NOTE(review): the code below loads Faster R-CNN, while the interface title
# says "DETR" -- reconcile the two before re-enabling.
# NOTE(review): `input_image.draw.rectangle(...)` looks wrong; PIL images are
# normally drawn on through `ImageDraw.Draw(input_image)` -- confirm.
#
# Load the model
# model = fasterrcnn_resnet50_fpn(pretrained=True)
# model.eval()
#
# Define the interface function
# def detect_objects(input_image: Image.Image):
#     # Convert the image to a tensor
#     input_tensor = F.to_tensor(input_image)
#     # Add an extra dimension at the beginning of the tensor
#     input_tensor = input_tensor.unsqueeze(0)
#     # Get the model's output
#     output = model(input_tensor)
#     # Get the bounding boxes
#     boxes = output[0]["boxes"]
#     # Draw the bounding boxes on the image
#     for box in boxes:
#         input_image.draw.rectangle(list(box.detach().numpy()), outline="red")
#     return input_image
#
# Create the Gradio interface
# detr_resnet50_interface = gr.Interface(
#     fn=detect_objects,
#     inputs="image",
#     outputs="image",
#     title="Object Detection with DETR",
#     description="Upload an image. The model will detect objects in the image.",
# )
# Load the GPT-2 model and tokenizer.
# GPT-2-specific names keep ``generate_text`` bound to this model; the
# generic globals ``model`` / ``tokenizer`` are rebound elsewhere in the file
# and would otherwise resolve to a different model at call time.
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def generate_text(prompt: str):
    """Continue ``prompt`` with GPT-2 using temperature sampling.

    Args:
        prompt: Text the model should continue.

    Returns:
        The prompt followed by the generated continuation (at most 150
        tokens in total), or an empty string when ``prompt`` is empty.
    """
    # An empty prompt encodes to zero tokens, which generate() cannot extend.
    if not prompt:
        return ""

    inputs = gpt2_tokenizer(prompt, return_tensors="pt")

    outputs = gpt2_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=150,
        # ``temperature`` is only applied when sampling is switched on;
        # without ``do_sample=True`` decoding is greedy and it is ignored.
        do_sample=True,
        temperature=0.7,
        num_return_sequences=1,
        # GPT-2 defines no pad token; use EOS so generate() can pad.
        pad_token_id=gpt2_tokenizer.eos_token_id,
    )

    return gpt2_tokenizer.decode(outputs[0], skip_special_tokens=True)


# Create the Gradio interface
gpt2_interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Text Generation with GPT-2",
    description="Enter a prompt and the model will generate a continuation of the text.",
)
# Load the T5 grammar-correction model and tokenizer under T5-specific names
# so that nothing else in the file can accidentally pick them up through the
# generic globals ``model`` / ``tokenizer``.
t5_model = T5ForConditionalGeneration.from_pretrained("vennify/t5-base-grammar-correction")
t5_tokenizer = T5Tokenizer.from_pretrained("vennify/t5-base-grammar-correction")

# Value reported under "Roberta Classification" by multi_model_interface.
# TODO(review): no RoBERTa classifier is loaded in this file; the only RoBERTa
# object is the question-answering pipeline, which needs a separate question.
_ROBERTA_UNAVAILABLE = (
    "n/a (the loaded RoBERTa checkpoint is a question-answering pipeline, "
    "not a classifier)"
)


def correct_grammar_text(input_text: str):
    """Return a grammar-corrected version of ``input_text``.

    Args:
        input_text: Text to correct. Input longer than 512 tokens (prefix
            included) is truncated.

    Returns:
        The corrected text, or an empty string when ``input_text`` is empty
        or only whitespace.
    """
    if not input_text or not input_text.strip():
        return ""

    # The checkpoint's model card specifies the task prefix "grammar: ".
    inputs = t5_tokenizer(
        "grammar: " + input_text,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )

    outputs = t5_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=512,
        num_beams=4,
        early_stopping=True,
    )

    return t5_tokenizer.decode(outputs[0], skip_special_tokens=True)


# The tab list at the bottom of the file refers to this interface by the name
# ``correct_grammar``, so that name is kept for the interface; the function
# behind it has its own name and is no longer shadowed by the interface.
correct_grammar = gr.Interface(
    fn=correct_grammar_text,
    inputs="text",
    outputs="text",
)


def multi_model_interface(input_text: str):
    """Run ``input_text`` through each text model and collect the results.

    Delegates to the single-model helpers (``summarize_text``,
    ``generate_text`` and ``correct_grammar_text``) rather than duplicating
    their tokenise/generate/decode steps.

    Args:
        input_text: Text handed to every model.

    Returns:
        A dict with the keys ``"Roberta Classification"``, ``"Bart Summary"``,
        ``"GPT-2 Generation"`` and ``"T5 Correction"``.
    """
    return {
        "Roberta Classification": _ROBERTA_UNAVAILABLE,
        "Bart Summary": summarize_text(input_text),
        "GPT-2 Generation": generate_text(input_text),
        "T5 Correction": correct_grammar_text(input_text),
    }


# Create the Gradio interface. The function returns a dict, so a JSON output
# component is used to display it.
iface = gr.Interface(
    fn=multi_model_interface,
    inputs="text",
    outputs="json",
    title="Multi-Model Interface",
    description="Enter a text and the interface will display the output from each of the four models.",
)
# Combine the single-model interfaces into one tabbed interface.
# The two lists must line up one-to-one: the n-th name labels the n-th
# interface. (A sixth name, "Computer Vision: Object Detection", used to be
# listed here without a matching interface; it has been removed.)
demo = gr.TabbedInterface(
    [
        roberta_interface,
        patch16_interface,
        bart_large_cnn_interface,
        gpt2_interface,
        correct_grammar,
    ],
    [
        "Single-Model: Question Answering",
        "Single-Model: Image Classification",
        "Single-Model: Text Summarization",
        "Single-Model: Text Generation",
        "Single-Model: Correct Grammar",
    ],
)

# Start the servers only when this file is run as a script, so importing the
# module never launches anything.
if __name__ == "__main__":
    # Launch the tabbed interface
    demo.launch()
    # Launch the multi-model interface.
    # NOTE(review): this is only reached once the tabbed interface's launch()
    # call returns; consider adding ``iface`` as another tab instead.
    iface.launch()