# ZH-CLIP / app.py
import gradio as gr
import torch
from PIL import Image
from zhclip import ZhCLIPProcessor, ZhCLIPModel  # from https://www.github.com/thu-ml/zh-clip

# Load the ZH-CLIP model and its paired processor from the Hugging Face Hub.
version = 'thu-ml/zh-clip-vit-roberta-large-patch14'
model = ZhCLIPModel.from_pretrained(version)
processor = ZhCLIPProcessor.from_pretrained(version)

def inference(image, texts):
    # Gradio's "list" input delivers rows like [['一只狗'], ['一只猫']];
    # flatten them into a plain list of candidate captions.
    texts = [x[0] for x in texts]
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    image_features = outputs.image_features
    text_features = outputs.text_features
    # Image-text similarity scores, softmax-normalized into probabilities.
    text_probs = (image_features @ text_features.T).softmax(dim=-1)[0].cpu().numpy()
    return {i: float(text_probs[i]) for i in range(len(text_probs))}
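
# A minimal usage sketch (not executed by the app), assuming ./images/dog.jpeg
# exists next to this script; the second argument mirrors Gradio's "list" input:
#
#   img = Image.open('./images/dog.jpeg')
#   probs = inference(img, [['一只狗'], ['一只猫']])  # "a dog" / "a cat"
#   # -> e.g. {0: 0.98, 1: 0.02}, where index i is the probability for texts[i]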

title = "ZH-CLIP zero-shot classification"
description = "Zero-shot image classification with the Chinese CLIP model (ZH-CLIP)"
article = "<p style='text-align: center'><a href='https://www.github.com/thu-ml/zh-clip' target='_blank'>github: zh-clip</a> <a href='https://huggingface.co/thu-ml/zh-clip-vit-roberta-large-patch14' target='_blank'>huggingface model: thu-ml/zh-clip-vit-roberta-large-patch14</a></p>"
examples = [['./images/dog.jpeg', [['一只狗'], ['一只猫']]]]  # '一只狗' = "a dog", '一只猫' = "a cat"
interpretation = 'default'
enable_queue = True

iface = gr.Interface(fn=inference, inputs=["image", "list"], outputs="label",
                     title=title, description=description, article=article,
                     examples=examples, interpretation=interpretation,
                     enable_queue=enable_queue)
iface.launch(server_name='0.0.0.0')
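
# Note: `interpretation` and `enable_queue` are Gradio 3.x-era arguments; on newer
# Gradio releases `enable_queue` was removed from gr.Interface in favor of calling
# iface.queue() before launch.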