from PIL import Image
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from transformers import CLIPProcessor, CLIPModel
import numpy as np
import time
import gradio as gr
import re

# Load the Taiyi Chinese text encoder
text_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
text_encoder = BertForSequenceClassification.from_pretrained(
    "IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese"
).eval()

# Load the CLIP image encoder
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def img_classify(query_texts, img_url):
    start_time = time.time()
    # Split the input on either ASCII or fullwidth commas
    query_texts = re.split(",|,", query_texts)
    text = text_tokenizer(query_texts, return_tensors="pt", padding=True)["input_ids"]
    image = processor(images=Image.open(img_url), return_tensors="pt")
    with torch.no_grad():
        image_features = clip_model.get_image_features(**image)
        text_features = text_encoder(text).logits
        # L2-normalize both sets of features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        # Cosine similarity, scaled by logit_scale (the learned temperature)
        logit_scale = clip_model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
        # Return the category with the highest probability
        res = query_texts[np.argmax(probs)]
    end_time = time.time()
    print("Elapsed:", end_time - start_time)
    return res


if __name__ == "__main__":
    with gr.Blocks(title="Image Classification with Custom Categories") as demo:
        # Page header
        gr.HTML("<hr>")
        gr.HTML('<h1 style="text-align: center;">Image Classification with Custom Categories</h1>')
        gr.HTML("<hr>")
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="filepath")
                out_input = gr.Textbox(
                    label="Custom categories",
                    placeholder="Enter custom categories in Chinese, e.g. 猫,狗,兔子",
                )
                text_btn = gr.Button("Submit")
            with gr.Column(scale=5):
                img_out = gr.Textbox(label="Predicted category")
        text_btn.click(fn=img_classify, inputs=[out_input, img_input], outputs=[img_out])
    demo.launch(show_api=False, inbrowser=True)
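# A minimal sketch of calling the classifier directly, bypassing the Gradio UI
# ("test.jpg" is a hypothetical local image path; categories should be Chinese
# words, since the text encoder is a Chinese model):
#
#   result = img_classify("猫,狗,兔子", "test.jpg")
#   print(result)  # prints the best-matching category, e.g. "猫"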