File size: 3,730 Bytes
a7b8ada
 
 
8962d34
a7b8ada
8962d34
a7b8ada
 
8962d34
a7b8ada
8962d34
a7b8ada
8962d34
a7b8ada
8962d34
a7b8ada
 
 
b71e116
 
a7b8ada
79c9925
b71e116
6cfdc97
8962d34
 
 
 
 
616e7e7
 
 
 
b8780b2
616e7e7
 
 
 
 
 
 
 
f7d54e4
616e7e7
 
 
a7b8ada
f7d54e4
616e7e7
a7b8ada
616e7e7
a7b8ada
8962d34
616e7e7
8962d34
616e7e7
8962d34
616e7e7
f1d7a0d
8962d34
f1d7a0d
8962d34
616e7e7
7185cc1
809bad6
 
6920a9d
 
 
616e7e7
b71e116
4ff9edd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
import random

import torch
import torchvision.transforms as transforms

from PIL import Image
from models.tag2text import tag2text_caption

import gradio as gr

# Run on GPU when available; the model and every input tensor are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Input resolution expected by the Swin-B Tag2Text checkpoint below.
image_size = 384

# Standard ImageNet channel statistics for normalization.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
# Preprocessing pipeline applied to every uploaded PIL image in inference().
transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])


#######Swin Version
# Path to the local checkpoint — NOTE(review): assumed to sit next to this
# script; confirm it is downloaded before the app starts.
pretrained = 'tag2text_swin_14m.pth'

model = tag2text_caption(pretrained=pretrained, image_size=image_size, vit='swin_b' )

# Inference only: freeze dropout/batch-norm behavior and move weights to device.
model.eval()
model = model.to(device)


def inference(raw_image, input_tag):
    """Generate tags and a caption for an uploaded image.

    Args:
        raw_image: PIL image from the Gradio Image widget.
        input_tag: comma-separated user-specified tags; '' / 'none' / 'None'
            means "no user guidance".

    Returns:
        Tuple of (model_identified_tags, user_specified_tags, caption) —
        each the first element of the corresponding model output;
        user_specified_tags is 'none' when no guidance was given.
    """
    raw_image = raw_image.resize((image_size, image_size))
    image = transform(raw_image).unsqueeze(0).to(device)

    model.threshold = 0.68  # tag-confidence threshold used by model.generate

    # Treat empty input and the 'none'/'None' sentinels as "no user tags".
    if input_tag in ('', 'none', 'None'):
        input_tag_list = None
    else:
        # model.generate expects tags joined with ' | ' rather than commas.
        input_tag_list = [input_tag.replace(',', ' | ')]

    with torch.no_grad():
        caption, tag_predict = model.generate(
            image,
            tag_input=input_tag_list,
            max_length=50,
            return_tag_predict=True,
        )
        if input_tag_list is None:  # was '== None': compare singletons with 'is'
            # No guidance: the predicted tags are the model's own tags.
            tag_1 = tag_predict
            tag_2 = ['none']
        else:
            # With guidance: re-run without tag_input to recover the
            # model's own tags alongside the user-guided prediction.
            _, tag_1 = model.generate(image, tag_input=None, max_length=50,
                                      return_tag_predict=True)
            tag_2 = tag_predict

        return tag_1[0], tag_2[0], caption[0]


# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) are the
# legacy Gradio (pre-3.x) API — confirm the pinned gradio version still
# ships them before upgrading dependencies.
# Widgets: the uploaded image plus an optional comma-separated tag prompt.
inputs = [gr.inputs.Image(type='pil'),gr.inputs.Textbox(lines=2, label="User Specified Tags (Optional, Enter with commas)")]

# Three text outputs matching inference()'s (tag_1, tag_2, caption) return.
outputs = [gr.outputs.Textbox(label="Model Identified Tags"),gr.outputs.Textbox(label="User Specified Tags"), gr.outputs.Textbox(label="Image Caption") ]

title = "Tag2Text"
description = "Welcome to Tag2Text demo! (Supported by Fudan University, OPPO Research Institute, International Digital Economy Academy) <br/> Upload your image to get the <b>tags</b> and <b>caption</b> of the image. Optional: You can also input specified tags to get the corresponding caption."

article = "<p style='text-align: center'>Tag2text training on open-source datasets, and we are persisting in refining and iterating upon it.<br/><a href='https://arxiv.org/abs/2303.05657' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='https://github.com/xinyu1205/Tag2Text' target='_blank'>Github Repo</a></p>"

# Example rows: [image path, user tags] — "none" exercises the no-guidance path,
# the repeated first image demonstrates how different tag prompts steer the caption.
demo = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['images/COCO_val2014_000000483108.jpg',"none"],
                                                                                                                 ['images/COCO_val2014_000000483108.jpg',"power line"],
                                                                                                                 ['images/COCO_val2014_000000483108.jpg',"track, train"] ,   
                                                                                                                 ['images/bdf391a6f4b1840a.jpg',"none"],
                                                                                                                 ['images/64891_194270823.jpg',"none"],
                                                                                                                 ['images/2800737_834897251.jpg',"none"],
                                                                                                                 ['images/1641173_2291260800.jpg',"none"],
                                                                                                                ])

demo.launch(enable_queue=True)