# Author: qingshan777 - commit 5f85845 "Update app.py" (verified)
import gradio as gr
import torch
import io
from PIL import Image
from transformers import (
AutoImageProcessor,
AutoTokenizer,
AutoModelForCausalLM,
)
import numpy as np
import ast
# Hub id of the FG-CLIP 2 checkpoint; the model, tokenizer and image
# processor below are all loaded from it.
model_root = "qihoo360/fg-clip2-base"
# trust_remote_code=True lets transformers run the modelling code shipped in
# the checkpoint repository.
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_root,
    trust_remote_code=True,
)
# Every input tensor is later moved onto the device the weights were loaded on.
device = model.device
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_root,
)
image_processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path=model_root,
)
def determine_max_value(image):
    """Choose the ``max_num_patches`` budget for the image processor.

    The image's width and height are each floor-divided by 16 (presumably the
    model's patch size - TODO confirm) and multiplied to get a patch count.
    The result is the smallest bucket in (128, 256, 576, 784) that holds that
    count, or 1024 when the count exceeds 784.

    Args:
        image: Any object with a ``size`` attribute of ``(width, height)``,
            such as a PIL image.

    Returns:
        One of 128, 256, 576, 784 or 1024.
    """
    width, height = image.size
    patch_count = (width // 16) * (height // 16)
    # (exclusive lower bound, bucket) pairs, checked from the largest bucket down.
    for lower_bound, bucket in ((784, 1024), (576, 784), (256, 576), (128, 256)):
        if patch_count > lower_bound:
            return bucket
    return 128
def postprocess_result(probs, labels):
    """Pair each label with its score, in label order, for a ``gr.Label``.

    Args:
        probs: Indexable sequence of scores; ``probs[i]`` belongs to ``labels[i]``.
        labels: Sequence of label strings.

    Returns:
        A dict ``{labels[i]: probs[i]}``. Scores beyond ``len(labels)`` are
        ignored, and a repeated label keeps the last score it was paired with.

    Raises:
        IndexError: If ``probs`` is shorter than ``labels``.
    """
    return {label: probs[position] for position, label in enumerate(labels)}
def Retrieval(image, candidate_labels, text_type):
    """Score each candidate caption against an image with FG-CLIP 2.

    Args:
        image: A PIL image (anything with ``.convert("RGB")`` and ``.size``).
        candidate_labels: A list of caption strings to rank against the image.
            Captions are lower-cased before tokenization.
        text_type: Text-encoder mode forwarded to ``get_text_features`` as
            ``walk_type`` (the UI offers "short", "long" and "box"). It is
            stripped and lower-cased. ``None`` or a blank value falls back to
            "long". "long" uses a 196-token context; any other value uses 64.

    Returns:
        A list of floats, one per caption in input order: the softmax over the
        image-text logits, so the values sum to 1. An empty list is returned
        when there are no captions.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Normalise free-form input such as " Long" or "" before it reaches the model.
    text_type = (text_type or "").strip().lower() or "long"
    labels = [label.lower() for label in candidate_labels]
    if not labels:
        return []
    max_length = 196 if text_type == "long" else 64

    image = image.convert("RGB")
    image_input = image_processor(
        images=image,
        max_num_patches=determine_max_value(image),
        return_tensors="pt",
    ).to(device)
    caption_input = tokenizer(
        labels,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # The whole scoring path runs under no_grad. model.logit_scale and
    # logit_bias are presumably trainable nn.Parameters (TODO confirm); without
    # no_grad, multiplying by them would build an autograd graph for nothing.
    with torch.no_grad():
        image_feature = model.get_image_features(**image_input)
        text_feature = model.get_text_features(**caption_input, walk_type=text_type)
        image_feature = image_feature / image_feature.norm(p=2, dim=-1, keepdim=True)
        text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)
        logit_scale = model.logit_scale.to(text_feature.device)
        logit_bias = model.logit_bias.to(text_feature.device)
        logits_per_image = (image_feature @ text_feature.T) * logit_scale.exp() + logit_bias
        # Softmax ranks the captions against each other. torch.sigmoid would
        # instead give an independent score per caption.
        probs = logits_per_image.softmax(dim=1)
    return probs[0].tolist()
def infer(image, candidate_labels, text_type):
    """Gradio click handler: parse the labels textbox and score the labels.

    Args:
        image: The uploaded PIL image (``None`` if nothing was uploaded).
        candidate_labels: Raw textbox text holding a Python list (or tuple)
            literal of strings, e.g. ``"['a','b','c']"``.
        text_type: Text-encoder mode from the UI, forwarded to ``Retrieval``.

    Returns:
        A dict mapping each label, exactly as typed, to its score, suitable
        for a ``gr.Label`` output.

    Raises:
        gr.Error: If the text is not a list/tuple literal of strings.
    """
    usage = "Labels must be a Python list of strings, e.g. ['a','b','c']."
    # ast.literal_eval only evaluates literals, so user text cannot run code.
    try:
        labels = ast.literal_eval(candidate_labels)
    except (ValueError, SyntaxError, TypeError) as err:
        raise gr.Error(usage) from err
    # Reject a bare string such as "'cat'", which would otherwise be scored
    # character by character. Also reject non-string items, which would fail
    # on .lower() in Retrieval.
    if not isinstance(labels, (list, tuple)) or not all(
        isinstance(label, str) for label in labels
    ):
        raise gr.Error(usage)
    labels = list(labels)
    fg_probs = Retrieval(image, labels, text_type)
    return postprocess_result(fg_probs, labels)
with gr.Blocks() as demo:
    gr.Markdown("# FG-CLIP 2 Retrieval")
    gr.Markdown(
        "This app uses the FG-CLIP 2 model (qihoo360/fg-clip2-base) for retrieval on CPU :"
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a list of labels, example:['a','b','c']")
            # The mode is a closed set, so use a Radio rather than free text.
            # Typos like "lng" or " Long" would otherwise be passed straight
            # through as walk_type. The Radio still yields a plain string, so
            # infer() is unchanged.
            text_type = gr.Radio(
                choices=["short", "long", "box"],
                value="long",
                label="Text type (short, long, box)",
            )
            run_button = gr.Button("Run Retrieval", visible=True)
        with gr.Column():
            fg_output = gr.Label(label="FG-CLIP 2 Output", num_top_classes=11)
    # Each example fills only the image and labels inputs; the text type keeps
    # its current selection.
    examples = [
        ["./000093.jpg", str([
            "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双浅色鞋子,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
            "一个简约风格的卧室角落,黑色金属衣架上挂着多件红色和蓝色的衣物,下方架子放着两双黑色高跟鞋,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
            "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双运动鞋,旁边是一盆仙人掌,左侧可见一张铺有白色床单和灰色枕头的床。",
            "一个繁忙的街头市场,摊位上摆满水果,背景是高楼大厦,人们在喧闹中购物。"
        ])],
        ["./000093.jpg", str([
            "A minimalist-style bedroom corner with a black metal clothing rack holding several beige and white garments, two pairs of light-colored shoes on the shelf below, a potted green plant nearby, and to the left, a bed made with white sheets and gray pillows.",
            "A minimalist-style bedroom corner with a black metal clothing rack holding several red and blue garments, two pairs of black high heels on the shelf below, a potted green plant nearby, and to the left, a bed made with white sheets and gray pillows.",
            "A minimalist-style bedroom corner with a black metal clothing rack holding several beige and white garments, two pairs of sneakers on the shelf below, a potted cactus nearby, and to the left, a bed made with white sheets and gray pillows.",
            "A bustling street market with fruit-filled stalls, skyscrapers in the background, and people shopping amid the noise and activity."
        ])],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input, text_type],
    )
    run_button.click(fn=infer, inputs=[image_input, text_input, text_type], outputs=fg_output)

# Launch only when run as a script, so importing this module has no server
# side effect.
if __name__ == "__main__":
    demo.launch()