# app.py — zero-shot image classification with user-defined (Chinese) categories,
# pairing the Taiyi Chinese text encoder with OpenAI CLIP's image encoder.
from PIL import Image
import torch
from transformers import BertForSequenceClassification, BertConfig, BertTokenizer
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import time
import gradio as gr
import re
# Load the Taiyi Chinese text encoder and its tokenizer.
# .eval() puts the model in inference mode (disables dropout etc.).
text_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
text_encoder = BertForSequenceClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese").eval()
# Load the CLIP image encoder (ViT-B/32) and its preprocessing pipeline.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def imgclassfiy(query_texts, img_url):
    """Classify an image against user-supplied candidate labels (zero-shot).

    Encodes the image with CLIP's image encoder and each candidate label with
    the Taiyi Chinese text encoder, then picks the label whose embedding has
    the highest cosine similarity to the image embedding.

    Parameters
    ----------
    query_texts : str
        Comma-separated candidate class names (e.g. "猫,狗,兔子").
    img_url : str
        Filesystem path to the image to classify.

    Returns
    -------
    str
        The candidate label most similar to the image.
    """
    start_time = time.time()
    # Split on either comma variant in the pattern — presumably ASCII and
    # full-width CJK commas; TODO confirm the second alternative's codepoint.
    labels = re.split(",|,", query_texts)
    # Tokenize all candidate labels as one padded batch of input ids.
    text = text_tokenizer(labels, return_tensors='pt', padding=True)['input_ids']
    image = processor(images=Image.open(img_url), return_tensors="pt")
    with torch.no_grad():
        image_features = clip_model.get_image_features(**image)
        text_features = text_encoder(text).logits
        # L2-normalize both embeddings so the dot product is cosine similarity.
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        # logit_scale is CLIP's learned temperature for the similarity logits.
        logit_scale = clip_model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    # Return the label with the highest image-text similarity.
    res = labels[np.argmax(probs)]
    end_time = time.time()
    print('用时:', end_time - start_time)
    return res
if __name__ =="__main__":
with gr.Blocks(title="自定义类别的图像分类") as demo:
# 标题
gr.HTML('<br>')
gr.HTML(
f'<center><p style="color:#4377ec;font-size:42px;font-weight:bold;text-shadow: #FDEDB7 2px 0 0, #FDEDB7 0 2px 0, #FDEDB7 -2px 0 0, #FDEDB7 0 -2px 0;">自定义类别的图像分类</p></center>')
gr.HTML('<br>')
with gr.Row() as row:
with gr.Column():
img_input = gr.Image(type="filepath")
out_input = gr.Textbox(lable='自定义类别',placeholder='输入自定义类别,例如:猫,狗,兔子')
text_btn = gr.Button("提交")
with gr.Column(scale=5):
img_out = gr.Textbox(lable='输出类别')
text_btn.click(fn=imgclassfiy, inputs=[out_input,img_input], outputs=[img_out])
demo.launch(show_api=False,inbrowser=True)