import ruamel_yaml as yaml
import numpy as np
import random

import torch
import torchvision.transforms as transforms

from PIL import Image
from models.tag2text import tag2text_caption

import gradio as gr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

image_size = 384


# Preprocessing: resize to the model's 384x384 input and normalize with ImageNet statistics.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])


#######Swin Version
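# NOTE: the checkpoint path below is a local, machine-specific path; point it at your own Tag2Text checkpoint.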
pretrained = '/home/notebook/code/personal/S9049611/BLIP/output/blip_tagtotext_14m/blip_tagtotext_encoderdiv_tar_random_swin/caption_coco_finetune_tagparse_tagfinetune_threshold075_bceloss_tagsingle_5e6_epoch19_negative_1_05_pos_1_10/checkpoint_05.pth'

config_file = 'configs/tag2text_caption.yaml'
with open(config_file, 'r') as f:
    config = yaml.load(f, Loader=yaml.Loader)


# Build the Tag2Text captioning model from the config and load the finetuned Swin checkpoint;
# `threshold` is the confidence cutoff for predicted tags.
model = tag2text_caption(pretrained=pretrained, image_size=image_size, vit=config['vit'],
                         vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
                         prompt=config['prompt'], config=config, threshold=0.75)

model.eval()
model = model.to(device)


def inference(raw_image, model_n, input_tag, strategy):
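    """For the 'Image Captioning' task, return (model-identified tags, user-identified tags, caption).

    `input_tag` optionally guides tagging (comma-separated); `strategy` selects
    beam search or nucleus sampling for caption decoding.
    """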
    if model_n == 'Image Captioning':
        raw_image = raw_image.resize((image_size, image_size))
        image = transform(raw_image).unsqueeze(0).to(device)   
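        # Override the tag threshold set at construction (0.75).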
        model.threshold = 0.7
        # Treat empty or "none" input as no user-provided tags.
        if input_tag in ('', 'none', 'None'):
            input_tag_list = None
        else:
            # The model expects user tags joined by ' | ' rather than commas.
            input_tag_list = [input_tag.replace(',', ' | ')]
        with torch.no_grad():
            if strategy == "Beam search":
                caption, tag_predict = model.generate(image, tag_input=input_tag_list, return_tag_predict=True)
            else:
                # Nucleus sampling
                caption, tag_predict = model.generate(image, tag_input=input_tag_list, sample=True, top_p=0.9,
                                                      max_length=20, min_length=5, return_tag_predict=True)

            if input_tag_list is None:
                # No user tags: the prediction already reflects the model-identified tags.
                tag_1 = tag_predict
                tag_2 = ['none']
            else:
                # User tags were supplied: rerun tagging without guidance to get the model's own tags.
                _, tag_1 = model.generate(image, tag_input=None, return_tag_predict=True)
                tag_2 = tag_predict

        return tag_1[0], tag_2[0], caption[0]

    else:
        # VQA branch: not reachable from the UI below (only 'Image Captioning' is offered) and
        # relies on `transform_vq`, `model_vq`, and `question`, which are not defined in this script.
        image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
        with torch.no_grad():
            answer = model_vq(image_vq, question, train=False, inference='generate')
        return 'answer: ' + answer[0]
    
# Gradio interface components (older gr.inputs/gr.outputs API).
inputs = [
    gr.inputs.Image(type='pil'),
    gr.inputs.Radio(choices=['Image Captioning'], type="value", default="Image Captioning", label="Task"),
    gr.inputs.Textbox(lines=2, label="User Identified Tags (Optional, Enter with commas)"),
    gr.inputs.Radio(choices=['Beam search', 'Nucleus sampling'], type="value", default="Beam search",
                    label="Caption Decoding Strategy"),
]

outputs = [
    gr.outputs.Textbox(label="Model Identified Tags"),
    gr.outputs.Textbox(label="User Identified Tags"),
    gr.outputs.Textbox(label="Image Caption"),
]

title = "Tag2Text"

description = "Gradio demo for Tag2Text: Guiding Language-Image Model via Image Tagging (Fudan University, OPPO Research Institute, International Digital Economy Academy)."

article = "<p style='text-align: center'><a href='' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='' target='_blank'>Github Repo</a></p>"

examples = [
    ['images/COCO_val2014_000000551338.jpg', "Image Captioning", "none", "Beam search"],
    ['images/COCO_val2014_000000551338.jpg', "Image Captioning", "fence, sky", "Beam search"],
    # ['images/COCO_val2014_000000551338.jpg', "Image Captioning", "grass", "Beam search"],
    ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "none", "Beam search"],
    ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "electric cable", "Beam search"],
    # ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "sky, train", "Beam search"],
    ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "track, train", "Beam search"],
    ['images/COCO_val2014_000000483108.jpg', "Image Captioning", "grass", "Beam search"],
]

demo = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article,
                    examples=examples)
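
# Minimal sketch for running the demo standalone (assumption: the original script may call
# launch() elsewhere; demo.launch() is the standard Gradio entry point).
if __name__ == '__main__':
    demo.launch()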