Spaces:
Runtime error
Runtime error
File size: 3,730 Bytes
a7b8ada 8962d34 a7b8ada 8962d34 a7b8ada 8962d34 a7b8ada 8962d34 a7b8ada 8962d34 a7b8ada 8962d34 a7b8ada b71e116 a7b8ada 79c9925 b71e116 6cfdc97 8962d34 616e7e7 b8780b2 616e7e7 f7d54e4 616e7e7 a7b8ada f7d54e4 616e7e7 a7b8ada 616e7e7 a7b8ada 8962d34 616e7e7 8962d34 616e7e7 8962d34 616e7e7 f1d7a0d 8962d34 f1d7a0d 8962d34 616e7e7 7185cc1 809bad6 6920a9d 616e7e7 b71e116 4ff9edd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import numpy as np
import random
import torch
import torchvision.transforms as transforms
from PIL import Image
from models.tag2text import tag2text_caption
import gradio as gr
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Side length images are resized to; also passed to the model constructor below.
image_size = 384

# Per-channel mean/std (the standard ImageNet statistics) applied after ToTensor.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Preprocessing: square resize -> float tensor -> channel normalization.
_preprocess_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
]
transform = transforms.Compose(_preprocess_steps)

####### Swin Version
# Checkpoint path is relative, so it is resolved against the working directory.
pretrained = 'tag2text_swin_14m.pth'
model = tag2text_caption(
    pretrained=pretrained,
    image_size=image_size,
    vit='swin_b',
)
# Switch to inference mode, then move the weights to the selected device.
model.eval()
model = model.to(device)
def inference(raw_image, input_tag):
    """Tag and caption a single image with the module-level Tag2Text model.

    Args:
        raw_image: PIL image supplied by the Gradio image input.
        input_tag: Optional comma-separated tags used to guide the caption.
            ``None``, an empty/blank string, or the single word "none"
            (any case, surrounding whitespace ignored) means "no user tags".
            Blank entries from stray or trailing commas are ignored.

    Returns:
        Tuple ``(model_tags, user_tags, caption)``, each taken from the first
        element of what ``model.generate`` returns:

        * ``model_tags``: tags from a pass without user guidance.
        * ``user_tags``: tags reported by the guided pass, or ``'none'`` when
          no user tags were given.
        * ``caption``: the generated caption.

    Raises:
        ValueError: if no image was provided.
    """
    if raw_image is None:
        raise ValueError("Please upload an image.")

    raw_image = raw_image.resize((image_size, image_size))
    image = transform(raw_image).unsqueeze(0).to(device)
    # NOTE(review): presumably the tag-confidence threshold used by
    # model.generate; set on every call as in the original demo. Confirm in
    # models/tag2text.
    model.threshold = 0.68

    # Normalize the user's tags: split on commas, trim whitespace, and drop
    # blanks so inputs like "a," or ",," do not send empty tags to the model.
    tags = [tag.strip() for tag in (input_tag or '').split(',')]
    tags = [tag for tag in tags if tag]
    if not tags or (len(tags) == 1 and tags[0].lower() == 'none'):
        input_tag_list = None
    else:
        # The model receives user tags as one ' | '-separated string.
        input_tag_list = [' | '.join(tags)]

    with torch.no_grad():
        caption, tag_predict = model.generate(
            image, tag_input=input_tag_list, max_length=50, return_tag_predict=True)
        if input_tag_list is None:
            model_tags = tag_predict
            user_tags = ['none']
        else:
            # With user guidance, the guided pass's tags are reported as the
            # user's tags. A second, unguided pass supplies the tags the model
            # identifies on its own.
            _, model_tags = model.generate(
                image, tag_input=None, max_length=50, return_tag_predict=True)
            user_tags = tag_predict

    return model_tags[0], user_tags[0], caption[0]
# Gradio UI wiring: an image plus optional comma-separated tags go in; the
# model's own tags, the user tags used for captioning, and the caption come out.
# `gr.inputs` / `gr.outputs` and `launch(enable_queue=...)` were deprecated in
# Gradio 3.x and removed in 4.0. The top-level components and `demo.queue()`
# used here work on both major versions.
inputs = [
    gr.Image(type='pil'),
    gr.Textbox(lines=2, label="User Specified Tags (Optional, Enter with commas)"),
]
outputs = [
    gr.Textbox(label="Model Identified Tags"),
    gr.Textbox(label="User Specified Tags"),
    gr.Textbox(label="Image Caption"),
]
title = "Tag2Text"
description = "Welcome to Tag2Text demo! (Supported by Fudan University, OPPO Research Institute, International Digital Economy Academy) <br/> Upload your image to get the <b>tags</b> and <b>caption</b> of the image. Optional: You can also input specified tags to get the corresponding caption."
article = "<p style='text-align: center'>Tag2text training on open-source datasets, and we are persisting in refining and iterating upon it.<br/><a href='https://arxiv.org/abs/2303.05657' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='https://github.com/xinyu1205/Tag2Text' target='_blank'>Github Repo</a></p>"
# Example rows: [image path relative to the working directory, tag text].
examples = [
    ['images/COCO_val2014_000000483108.jpg', "none"],
    ['images/COCO_val2014_000000483108.jpg', "power line"],
    ['images/COCO_val2014_000000483108.jpg', "track, train"],
    ['images/bdf391a6f4b1840a.jpg', "none"],
    ['images/64891_194270823.jpg', "none"],
    ['images/2800737_834897251.jpg', "none"],
    ['images/1641173_2291260800.jpg', "none"],
]
demo = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
# Queueing replaces the removed `enable_queue=True` launch flag.
demo.queue().launch()
|