import numpy as np
import random
import torch
import torchvision.transforms as transforms
from PIL import Image
from models.tag2text import tag2text_caption
import gradio as gr
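# Select the compute device and define the image preprocessing pipeline:
# resize to the model's input resolution, convert to a tensor, and normalize with ImageNet statistics.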
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_size = 384
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])
# Load the Tag2Text captioning model with a Swin-B vision backbone from the pretrained checkpoint.
pretrained = 'tag2text_swin_14m.pth'
model = tag2text_caption(pretrained=pretrained, image_size=image_size, vit='swin_b')
model.eval()
model = model.to(device)
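# inference() preprocesses the uploaded image and runs Tag2Text, returning
# (model-identified tags, user-specified tags, generated caption).
# When the user supplies tags, the caption is conditioned on them and the
# model's own tags are generated separately for display.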
def inference(raw_image, input_tag):
    raw_image = raw_image.resize((image_size, image_size))
    image = transform(raw_image).unsqueeze(0).to(device)

    model.threshold = 0.69
    if input_tag in ('', 'none', 'None'):
        input_tag_list = None
    else:
        input_tag_list = [input_tag.replace(',', ' | ')]

    with torch.no_grad():
        caption, tag_predict = model.generate(image, tag_input=input_tag_list, return_tag_predict=True)
        if input_tag_list is None:
            tag_1 = tag_predict
            tag_2 = ['none']
        else:
            _, tag_1 = model.generate(image, tag_input=None, return_tag_predict=True)
            tag_2 = tag_predict

    return tag_1[0], tag_2[0], caption[0]
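# Gradio UI: an image upload and an optional free-text tag field as inputs;
# three textboxes (model-identified tags, user-specified tags, caption) as outputs.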
inputs = [
    gr.inputs.Image(type='pil'),
    gr.inputs.Textbox(lines=2, label="User Specified Tags (optional, separated by commas)"),
]
outputs = [
    gr.outputs.Textbox(label="Model Identified Tags"),
    gr.outputs.Textbox(label="User Specified Tags"),
    gr.outputs.Textbox(label="Image Caption"),
]
title = "Tag2Text"
description = "Welcome to the Tag2Text demo! (Supported by Fudan University, OPPO Research Institute, and the International Digital Economy Academy) <br/> Upload an image to get its tags and caption. Optional: you can also enter your own tags to get the corresponding caption."
article = "<p style='text-align: center'><a href='' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='' target='_blank'>Github Repo</a></p>"
demo = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=[
        ['images/COCO_val2014_000000483108.jpg', "none"],
        ['images/COCO_val2014_000000483108.jpg', "electric cable"],
        ['images/COCO_val2014_000000483108.jpg', "track, train"],
    ],
)
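# If nothing else launches the interface (e.g. when running this script directly
# outside a Gradio-aware runtime such as Hugging Face Spaces), uncomment the line below:
# demo.launch()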