File size: 4,630 Bytes
a7b8ada
 
 
8962d34
a7b8ada
8962d34
a7b8ada
4a16ef8
8962d34
a7b8ada
8962d34
a7b8ada
8962d34
a7b8ada
8962d34
a7b8ada
 
 
b71e116
4a16ef8
d7dce5e
4a16ef8
 
 
 
 
 
b71e116
d7dce5e
 
b71e116
4a16ef8
8962d34
4a16ef8
 
8962d34
 
4a16ef8
616e7e7
 
4a16ef8
 
 
 
 
616e7e7
4a16ef8
 
 
 
 
 
 
 
616e7e7
 
4a16ef8
 
 
 
 
 
 
 
 
 
a7b8ada
4a16ef8
 
 
 
 
 
 
 
a7b8ada
4a16ef8
8962d34
4a16ef8
 
8962d34
4a16ef8
8962d34
 
4a16ef8
8962d34
4a16ef8
 
 
 
 
 
616e7e7
b71e116
4a16ef8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import random

import torch
import torchvision.transforms as transforms

from PIL import Image
from models.tag2text import tag2text_caption, ram

import gradio as gr

# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Square side length (pixels) every input image is resized to; the same value
# is passed to both model constructors.
image_size = 384

# Per-channel mean/std normalisation (the commonly used ImageNet statistics).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# PIL image -> resized, normalised float tensor.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])

####### Tag2Text Model
# Produces image tags plus a caption (used for the "Tag2Text Model" choice).
# eval mode first, then moved to `device`; nn.Module.eval()/.to() both return
# the module itself, so the calls chain.
pretrained = 'tag2text_swin_14m.pth'
model_tag2text = tag2text_caption(
    pretrained=pretrained, image_size=image_size, vit='swin_b'
).eval().to(device)


####### RAM Model
# Produces English and Chinese image tags (used for the
# "Recognize Anything Model" choice).
pretrained = 'ram_swin_large_14m.pth'
model_ram = ram(
    pretrained=pretrained, image_size=image_size, vit='swin_l'
).eval().to(device)


def inference(raw_image, model_n, input_tag):
    """Run the selected model on one image and return the three demo outputs.

    Args:
        raw_image: PIL image from the Gradio image input.
        model_n: Model name from the radio input. ``'Recognize Anything Model'``
            selects RAM; any other value selects Tag2Text.
        input_tag: Optional comma-separated tags typed by the user (Tag2Text
            only). ``None``, empty/whitespace, or ``'none'`` in any letter case
            means "no user tags".

    Returns:
        A ``(tags, chinese_tags, caption)`` tuple. RAM returns ``'none'`` as the
        caption; Tag2Text returns ``'none'`` as the Chinese tags.
    """
    # Uploaded PNGs may be RGBA or single-channel; the 3-channel Normalize in
    # `transform` needs RGB input.
    raw_image = raw_image.convert('RGB')
    # Pre-resize with PIL's default filter; `transform` then resizes to the
    # same size again.
    raw_image = raw_image.resize((image_size, image_size))
    image = transform(raw_image).unsqueeze(0).to(device)

    if model_n == 'Recognize Anything Model':
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            tags, tags_chinese = model_ram.generate_tag(image)
        return tags[0], tags_chinese[0], 'none'

    model = model_tag2text
    model.threshold = 0.68

    tag_text = (input_tag or '').strip()
    if tag_text.lower() in ('', 'none'):
        input_tag_list = None
    else:
        # Commas separate tags in the UI; the model is fed ' | ' separators.
        input_tag_list = [tag_text.replace(',', ' | ')]

    with torch.no_grad():
        caption, tag_predict = model.generate(
            image, tag_input=input_tag_list, max_length=50, return_tag_predict=True
        )
        if input_tag_list is None:
            # No user tags: the captioning pass already produced the tags.
            tag_1 = tag_predict
        else:
            # The caption was steered by user tags, so run a separate pass
            # without them to report the model's own predicted tags.
            _, tag_1 = model.generate(
                image, tag_input=None, max_length=50, return_tag_predict=True
            )

    return tag_1[0], 'none', caption[0]


# Radio options; the first one is the default selection. `inference` compares
# against the first name to decide which model to run.
_MODEL_CHOICES = ['Recognize Anything Model', "Tag2Text Model"]

# Order must match the `inference(raw_image, model_n, input_tag)` parameters.
inputs = [
    gr.inputs.Image(type='pil'),
    gr.inputs.Radio(
        choices=_MODEL_CHOICES,
        type="value",
        default=_MODEL_CHOICES[0],
        label="Model",
    ),
    gr.inputs.Textbox(
        lines=2,
        label="User Specified Tags (Optional, Enter with commas, Currently only Tag2Text is supported)",
    ),
]

# Order must match the (tags, chinese_tags, caption) tuple `inference` returns;
# the second label is Chinese for "Tags".
outputs = [
    gr.outputs.Textbox(label="Tags"),
    gr.outputs.Textbox(label="标签"),
    gr.outputs.Textbox(label="Caption (currently only Tag2Text is supported)"),
]

# Header, introduction and footer text for the Gradio interface. The strings
# contain inline HTML markup.
title = "<font size='10'> Recognize Anything Model</font>"

description = "Welcome to the Recognize Anything Model (RAM) and Tag2Text Model demo! <li><b>Recognize Anything Model:</b>  Upload your image to get the <b>English and Chinese outputs of the image tags</b>!</li><li><b>Tag2Text Model:</b> Upload your image to get the <b>tags</b> and <b>caption</b> of the image. Optional: You can also input specified tags to get the corresponding caption.</li> "


# Footer links: project pages and the GitHub repository.
article = "<p style='text-align: center'>RAM and Tag2Text is training on open-source datasets, and we are persisting in refining and iterating upon it.<br/><a href='https://recognize-anything.github.io/' target='_blank'>Recognize Anything: A Strong Image Tagging Model</a> | <a href='https://tag2text.github.io/' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='https://github.com/xinyu1205/Tag2Text' target='_blank'>Github Repo</a></p>"

# Example rows, one value per input component: (image path, model choice,
# user tags).
examples = [
    ['images/demo1.jpg', "Recognize Anything Model", "none"],
    ['images/demo2.jpg', "Recognize Anything Model", "none"],
    ['images/demo4.jpg', "Recognize Anything Model", "none"],
    ['images/demo4.jpg', "Tag2Text Model", "power line"],
    ['images/demo4.jpg', "Tag2Text Model", "track, train"],
]

demo = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Start the web server only when run as a script, so importing this module
# (e.g. to reuse `demo`) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch(enable_queue=True)