xinyu1205's picture
Update app.py
d7dce5e
raw
history blame
4.63 kB
import numpy as np
import random
import torch
import torchvision.transforms as transforms
from PIL import Image
from models.tag2text import tag2text_caption, ram
import gradio as gr
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Side length (pixels) of the square image given to both models below.
image_size = 384

# Per-channel mean/std used by ``Normalize`` (the commonly used ImageNet values).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# PIL image -> resized, tensor-converted, normalised input.
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        normalize,
    ]
)
def _build_eval_model(factory, checkpoint, backbone):
    """Instantiate a model via ``factory`` from ``checkpoint``, put it in eval mode, move it to ``device``.

    ``factory`` is called with the keyword arguments ``pretrained``,
    ``image_size`` and ``vit``; the value returned by ``.to(device)`` is what
    the caller receives.
    """
    net = factory(pretrained=checkpoint, image_size=image_size, vit=backbone)
    net.eval()
    return net.to(device)


####### Tag2Text Model (Swin-B backbone)
pretrained = 'tag2text_swin_14m.pth'
model_tag2text = _build_eval_model(tag2text_caption, pretrained, 'swin_b')

####### RAM Model (Swin-L backbone)
pretrained = 'ram_swin_large_14m.pth'
model_ram = _build_eval_model(ram, pretrained, 'swin_l')
def _parse_input_tags(input_tag):
    """Turn the comma-separated tag textbox into Tag2Text's ``tag_input`` argument.

    Returns ``None`` when no tags were supplied (``None``, empty/whitespace-only
    text, or the literal ``'none'`` in any case). Otherwise it returns a
    one-element list holding the stripped, non-empty tags joined with ``' | '``.
    """
    if input_tag is None:
        return None
    tags = [tag.strip() for tag in input_tag.split(',')]
    tags = [tag for tag in tags if tag]
    if not tags or (len(tags) == 1 and tags[0].lower() == 'none'):
        return None
    return [' | '.join(tags)]


def inference(raw_image, model_n, input_tag):
    """Run the selected model on one image and return the three UI outputs.

    Args:
        raw_image: PIL image from the Gradio ``Image(type='pil')`` input.
        model_n: Radio-button choice; ``'Recognize Anything Model'`` selects
            RAM, any other value selects Tag2Text.
        input_tag: Optional comma-separated user tags (used by Tag2Text only).

    Returns:
        A 3-tuple ``(tags, chinese_tags, caption)``. For RAM the caption slot is
        ``'none'``; for Tag2Text the Chinese-tags slot is ``'none'``.
    """
    raw_image = raw_image.resize((image_size, image_size))
    # Add a batch dimension of 1 before moving to the model's device.
    image = transform(raw_image).unsqueeze(0).to(device)

    if model_n == 'Recognize Anything Model':
        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            tags, tags_chinese = model_ram.generate_tag(image)
        return tags[0], tags_chinese[0], 'none'

    model = model_tag2text
    # NOTE(review): presumably the tag-score threshold used by Tag2Text; it is
    # re-set on every call because the model object is shared module state.
    model.threshold = 0.68
    input_tag_list = _parse_input_tags(input_tag)

    with torch.no_grad():
        caption, tag_predict = model.generate(
            image, tag_input=input_tag_list, max_length=50, return_tag_predict=True
        )
        if input_tag_list is None:
            tag_1 = tag_predict
        else:
            # When user tags guide the caption, a second unconditioned pass is
            # used to get the model's own tags (the first pass's tag_predict
            # presumably reflects the user tags -- confirm in model.generate).
            _, tag_1 = model.generate(
                image, tag_input=None, max_length=50, return_tag_predict=True
            )
    return tag_1[0], 'none', caption[0]
# Widgets passed positionally to ``inference``: image, model choice, user tags.
inputs = [
    gr.inputs.Image(type='pil'),
    gr.inputs.Radio(
        choices=['Recognize Anything Model', "Tag2Text Model"],
        type="value",
        default="Recognize Anything Model",
        label="Model",
    ),
    gr.inputs.Textbox(
        lines=2,
        label="User Specified Tags (Optional, Enter with commas, Currently only Tag2Text is supported)",
    ),
]

# One textbox per element of the 3-tuple returned by ``inference``.
outputs = [
    gr.outputs.Textbox(label="Tags"),
    gr.outputs.Textbox(label="标签"),  # Chinese for "Tags" (RAM's Chinese output)
    gr.outputs.Textbox(label="Caption (currently only Tag2Text is supported)"),
]
# Page header, HTML intro shown above the demo, and footer shown below it.
title = "<font size='10'> Recognize Anything Model</font>"
description = (
    "Welcome to the Recognize Anything Model (RAM) and Tag2Text Model demo! "
    "<ul>"
    "<li><b>Recognize Anything Model:</b> Upload your image to get the "
    "<b>English and Chinese outputs of the image tags</b>!</li>"
    "<li><b>Tag2Text Model:</b> Upload your image to get the <b>tags</b> and "
    "<b>caption</b> of the image. Optional: You can also input specified tags "
    "to get the corresponding caption.</li>"
    "</ul>"
)
article = (
    "<p style='text-align: center'>RAM and Tag2Text are trained on open-source "
    "datasets, and we are persisting in refining and iterating upon them.<br/>"
    "<a href='https://recognize-anything.github.io/' target='_blank'>"
    "Recognize Anything: A Strong Image Tagging Model</a> | "
    # Previously 'https://https://tag2text.github.io/', which is a broken link.
    "<a href='https://tag2text.github.io/' target='_blank'>"
    "Tag2Text: Guiding Language-Image Model via Image Tagging</a> | "
    "<a href='https://github.com/xinyu1205/Tag2Text' target='_blank'>Github Repo</a></p>"
)
# Clickable examples; each row is (image path, model choice, user tags),
# matching the order of ``inputs``.
_examples = [
    ['images/demo1.jpg', "Recognize Anything Model", "none"],
    ['images/demo2.jpg', "Recognize Anything Model", "none"],
    ['images/demo4.jpg', "Recognize Anything Model", "none"],
    ['images/demo4.jpg', "Tag2Text Model", "power line"],
    ['images/demo4.jpg', "Tag2Text Model", "track, train"],
]

demo = gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=_examples,
)
# Serve requests through Gradio's queue.
demo.launch(enable_queue=True)