Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from matplotlib import cm | |
import torch | |
from transformers import AutoTokenizer, AutoModel | |
from model import ImageModel, TextModel | |
import torch.nn.functional as F | |
import torchvision.transforms.v2 as transforms | |
# --- Model and preprocessing setup (runs once, at import time) ---

# Hugging Face checkpoint name; used for the tokenizer and passed to TextModel.
MODEL_NAME = "distilbert/distilroberta-base"

# Genre labels. The predictors zip this list positionally against each model's
# genre output, so the order here must match the order used in training.
class_names = ['Action', 'Adventure', 'Comedy', 'Drama', 'Fantasy', 'Romance', 'Sci-Fi']

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# A single checkpoint file holds both state dicts:
# 'w_i' for the image model and 'w_t' for the text model.
cp = torch.load(r"model_only.pt", map_location="cpu")

model_img = ImageModel(len(class_names))
model_img.load_state_dict(cp['w_i'])
# This app only runs inference. Freshly constructed modules default to training
# mode, so switch to eval mode to disable training-only behaviour (dropout,
# batch-norm batch statistics) in case the model contains such layers.
# NOTE(review): assumes ImageModel/TextModel are torch.nn.Module subclasses.
model_img.eval()

model_text = TextModel(MODEL_NAME, len(class_names))
model_text.load_state_dict(cp['w_t'])
model_text.eval()

# Image preprocessing: resize to 224x224, convert to a float tensor, then
# normalise each channel with mean 0.5 / std 0.5.
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def text_predictor(title, synopsis):
    """Predict score, award flag and genres from a movie's title and synopsis.

    The title and synopsis are joined with a literal " </s> " separator,
    tokenized to exactly 128 tokens (padded or truncated), and passed to the
    text model as an ``(input_ids, attention_mask)`` tuple.

    Args:
        title: Movie title text.
        synopsis: Movie synopsis text.

    Returns:
        A 3-tuple ``(score, is_award, genres)``:
        the score rounded to two decimals, a bool that is True when the
        sigmoid of the award output is >= 0.5, and ``{"genres": [...]}``
        listing every class whose sigmoid probability is >= 0.5.
    """
    prompt = f"{title} </s> {synopsis}"
    batch = tokenizer(
        prompt,
        add_special_tokens=True,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    )
    model_inputs = (batch['input_ids'], batch['attention_mask'])

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        raw_score, award_logit, genre_logits = model_text(model_inputs)

    # Drop the batch dimension (batch size is always 1 here).
    rating = raw_score.squeeze(0)
    wins_award = F.sigmoid(award_logit.squeeze(0)) >= 0.5
    genre_probs = F.sigmoid(genre_logits.squeeze(0))

    # Multi-label decision: keep every genre at or above the 0.5 threshold.
    selected = [name for p, name in zip(genre_probs, class_names) if p >= 0.5]
    return round(rating.item(), 2), wins_award.item(), {"genres": selected}
def img_predictor(img):
    """Predict score, award flag and genres from a movie poster image.

    Args:
        img: Image as a NumPy array, as delivered by the ``gr.Image`` input.
            Gradio passes ``None`` when the user submits without an image.

    Returns:
        A 3-tuple ``(score, is_award, genres)``:
        the score rounded to two decimals, a bool that is True when the
        sigmoid of the award output is >= 0.5, and ``{"genres": [...]}``
        listing every class whose sigmoid probability is >= 0.5.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard a missing image fails with an opaque AttributeError.
    if img is None:
        raise gr.Error("Please upload an image before submitting.")

    # Let PIL infer the mode from the array shape, then convert to RGB so that
    # grayscale or RGBA inputs also end up as the 3 channels Normalize expects.
    pil_img = Image.fromarray(img.astype('uint8')).convert('RGB')
    batch = image_transforms(pil_img).unsqueeze(0)  # add batch dimension

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model_img(batch)

    # Drop the batch dimension (batch size is always 1 here).
    score = output[0].squeeze(0)
    is_award = torch.sigmoid(output[1].squeeze(0)) >= 0.5
    genre_probs = torch.sigmoid(output[2].squeeze(0))

    # Multi-label decision: keep every genre at or above the 0.5 threshold.
    preds_name = [cls for prob, cls in zip(genre_probs, class_names) if prob >= 0.5]
    return round(score.item(), 2), is_award.item(), {"genres": preds_name}
def combine_predictor(title, synopsis, img):
    """Predict score, award flag and genres from text and poster together.

    The text and image models are run independently and their raw outputs are
    averaged element-wise; the sigmoid and the 0.5 threshold are applied to the
    averaged award and genre outputs.

    Args:
        title: Movie title text.
        synopsis: Movie synopsis text.
        img: Image as a NumPy array, as delivered by the ``gr.Image`` input.
            Gradio passes ``None`` when the user submits without an image.

    Returns:
        A 3-tuple ``(score, is_award, genres)``:
        the averaged score rounded to two decimals, a bool for the award
        decision, and ``{"genres": [...]}`` listing every class whose
        probability is >= 0.5.

    Raises:
        gr.Error: If no image was provided.
    """
    # Without this guard a missing image fails with an opaque AttributeError.
    if img is None:
        raise gr.Error("Please upload an image before submitting.")

    # Title and synopsis are joined with a literal " </s> " separator and
    # padded/truncated to exactly 128 tokens.
    encoded_synopsis = tokenizer(
        f"{title} </s> {synopsis}",
        add_special_tokens=True,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    )

    # Let PIL infer the mode from the array shape, then convert to RGB so that
    # grayscale or RGBA inputs also end up as the 3 channels Normalize expects.
    pil_img = Image.fromarray(img.astype('uint8')).convert('RGB')
    img_batch = image_transforms(pil_img).unsqueeze(0)  # add batch dimension

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output_text = model_text(
            (encoded_synopsis['input_ids'], encoded_synopsis['attention_mask'])
        )
        output_img = model_img(img_batch)

    # Average the two models' raw outputs, dropping the batch dimension.
    score = (output_img[0].squeeze(0) + output_text[0].squeeze(0)) / 2
    is_award = torch.sigmoid(
        (output_img[1].squeeze(0) + output_text[1].squeeze(0)) / 2
    ) >= 0.5
    genre_probs = torch.sigmoid(
        (output_img[2].squeeze(0) + output_text[2].squeeze(0)) / 2
    )

    # Multi-label decision: keep every genre at or above the 0.5 threshold.
    preds_name = [cls for prob, cls in zip(genre_probs, class_names) if prob >= 0.5]
    return round(score.item(), 2), is_award.item(), {"genres": preds_name}
# --- Gradio UI: one tab per predictor ---
# Every tab shows the same three outputs: a score label, an award label,
# and a JSON panel holding the predicted genres.

# Tab 1: title + synopsis only.
iface_1 = gr.Interface(
    fn=text_predictor,
    inputs=[
        gr.Text(placeholder="Input title here"),
        gr.Text(placeholder="Input synopsis here"),
    ],
    outputs=[
        gr.Label(label='Score'),
        gr.Label(label='Is Winning Award?'),
        "json",
    ],
)

# Tab 2: poster image only.
iface_2 = gr.Interface(
    fn=img_predictor,
    inputs=gr.Image(height=224, width=224),
    outputs=[
        gr.Label(label='Score'),
        gr.Label(label='Is Winning Award?'),
        "json",
    ],
)

# Tab 3: text and poster combined.
iface_3 = gr.Interface(
    fn=combine_predictor,
    inputs=[
        gr.Text(placeholder="Input title here"),
        gr.Text(placeholder="Input synopsis here"),
        gr.Image(height=224, width=224),
    ],
    outputs=[
        gr.Label(label='Score'),
        gr.Label(label='Is Winning Award?'),
        "json",
    ],
)

demo = gr.TabbedInterface(
    [iface_1, iface_2, iface_3],
    ["From Text", "From Image", "From Text and Image"],
)

demo.launch()  # start the web app