import gradio as gr
import torch
import torchvision.transforms.v2 as transforms
from PIL import Image
from transformers import AutoTokenizer

from model import ImageModel, TextModel
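# NOTE (assumption): model.py ships with the Space but is not shown here. From the way
# the models are used below, ImageModel(num_classes) and TextModel(model_name, num_classes)
# are each expected to return a 3-tuple (score, award_logit, genre_logits), where score is
# a regression output, award_logit is a single logit, and genre_logits holds one logit per
# entry in class_names.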
# Load the tokenizer and the trained weights for both branches.
MODEL_NAME = "distilbert/distilroberta-base"
class_names = ['Action', 'Adventure', 'Comedy', 'Drama', 'Fantasy', 'Romance', 'Sci-Fi']

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

cp = torch.load("model_only.pt", map_location="cpu")
model_img = ImageModel(len(class_names))
model_img.load_state_dict(cp['w_i'])
model_text = TextModel(MODEL_NAME, len(class_names))
model_text.load_state_dict(cp['w_t'])

# Inference mode: disable dropout and use running batch-norm statistics.
model_img.eval()
model_text.eval()
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
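# After these transforms, any RGB PIL image becomes a float tensor of shape
# (3, 224, 224) with values roughly in [-1, 1], for example:
#   image_transforms(Image.new("RGB", (640, 480))).shape  # torch.Size([3, 224, 224])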
def text_predictor(title, synopsis):
    # Tokenize "title </s> synopsis" as a single sequence for the text model.
    encoded_synopsis = tokenizer(
        f"{title} </s> {synopsis}",
        add_special_tokens=True,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        score, isAward, genres = model_text((encoded_synopsis['input_ids'], encoded_synopsis['attention_mask']))
    # Threshold the award and genre logits at a probability of 0.5.
    score = score.squeeze(0)
    isAward = torch.sigmoid(isAward.squeeze(0)) >= 0.5
    genres = torch.sigmoid(genres.squeeze(0))
    preds_name = [cls for prob, cls in zip(genres, class_names) if prob >= 0.5]
    return round(score.item(), 2), isAward.item(), {"genres": preds_name}
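# Quick local sanity check before launching the UI (hypothetical inputs):
#   print(text_predictor("Example Title", "A short synopsis used only for smoke-testing."))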
def img_predictor(img):
    # Preprocess the image: NumPy array -> PIL image -> normalized tensor with a batch dimension.
    img = Image.fromarray(img.astype('uint8'), 'RGB')
    img = image_transforms(img).unsqueeze(0)
    # Make predictions
    with torch.no_grad():
        output = model_img(img)
    score = output[0].squeeze(0)
    isAward = torch.sigmoid(output[1].squeeze(0)) >= 0.5
    genres = torch.sigmoid(output[2].squeeze(0))
    preds_name = [cls for prob, cls in zip(genres, class_names) if prob >= 0.5]
    return round(score.item(), 2), isAward.item(), {"genres": preds_name}
def combine_predictor(title, synopsis, img):
    # Run both branches and average their raw outputs before applying the sigmoids.
    encoded_synopsis = tokenizer(
        f"{title} </s> {synopsis}",
        add_special_tokens=True,
        max_length=128,
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    )
    img = Image.fromarray(img.astype('uint8'), 'RGB')  # NumPy array -> PIL image
    img = image_transforms(img).unsqueeze(0)           # preprocess and add a batch dimension
    with torch.no_grad():
        output_text = model_text((encoded_synopsis['input_ids'], encoded_synopsis['attention_mask']))
        output_img = model_img(img)
    score = (output_img[0].squeeze(0) + output_text[0].squeeze(0)) / 2
    isAward = torch.sigmoid((output_img[1].squeeze(0) + output_text[1].squeeze(0)) / 2) >= 0.5
    genres = torch.sigmoid((output_img[2].squeeze(0) + output_text[2].squeeze(0)) / 2)
    preds_name = [cls for prob, cls in zip(genres, class_names) if prob >= 0.5]
    return round(score.item(), 2), isAward.item(), {"genres": preds_name}
# One tab per input modality; each tab returns the score, the award flag, and the predicted genres.
iface_1 = gr.Interface(text_predictor, [gr.Text(placeholder="Input title here"), gr.Text(placeholder="Input synopsis here")], [gr.Label(label='Score'), gr.Label(label='Is Winning Award?'), "json"])
iface_2 = gr.Interface(img_predictor, gr.Image(height=224, width=224), [gr.Label(label='Score'), gr.Label(label='Is Winning Award?'), "json"])
iface_3 = gr.Interface(combine_predictor, [gr.Text(placeholder="Input title here"), gr.Text(placeholder="Input synopsis here"), gr.Image(height=224, width=224)], [gr.Label(label='Score'), gr.Label(label='Is Winning Award?'), "json"])
demo = gr.TabbedInterface([iface_1, iface_2, iface_3], ["From Text", "From Image", "From Text and Image"])
demo.launch()  # Launches the mini app!
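# Example of querying the running app programmatically -- a sketch that assumes the
# default local URL and that the text tab is exposed as "/predict" (the actual
# endpoint names can be listed with client.view_api()):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict("Some Title", "Some synopsis text.", api_name="/predict"))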