import os
import sys
import json

import torch
import kelip
import gradio as gr


def load_model():
    """Build the KELIP ViT-B/32 model and bundle it with its helpers.

    The model is moved to CUDA when available (otherwise CPU) and switched
    to eval mode.

    Returns:
        A dict with the keys ``'model'``, ``'preprocess_img'`` and
        ``'tokenizer'``, holding the three objects returned by
        ``kelip.build_model``.
    """
    model, preprocess_img, tokenizer = kelip.build_model('ViT-B/32')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()

    return {
        'model': model,
        'preprocess_img': preprocess_img,
        'tokenizer': tokenizer,
    }


def _parse_labels(user_text):
    """Split comma-separated text into labels, stripped and without blanks."""
    stripped = (label.strip() for label in user_text.split(','))
    return [label for label in stripped if label]


def classify(img, user_text):
    """Score an image against user-supplied, comma-separated candidate labels.

    Reads the module-level ``model_dict`` (set in the ``__main__`` section
    from ``load_model()``).

    Args:
        img: Image in whatever form ``model_dict['preprocess_img']`` accepts
            (the demo's image input is configured with ``type="pil"``).
        user_text: Candidate labels separated by commas. Whitespace around
            each label is ignored and empty entries (e.g. from a trailing
            comma) are dropped.

    Returns:
        A dict mapping each label to its softmax probability, inserted in
        descending order of probability. Empty when ``user_text`` contains
        no labels.
    """
    labels = _parse_labels(user_text)
    if not labels:
        # Nothing to rank; avoid feeding an empty batch to the tokenizer/model.
        return {}

    model = model_dict['model']
    device = "cuda" if torch.cuda.is_available() else "cpu"

    input_img = model_dict['preprocess_img'](img).unsqueeze(0).to(device)
    input_texts = model_dict['tokenizer'].encode(labels).to(device)

    # Pure inference: keep both encoders and the similarity computation out of
    # autograd so no graph is built per request.
    with torch.no_grad():
        image_features = model.encode_image(input_img)
        text_features = model.encode_text(input_texts)

        # L2-normalize both feature sets before taking dot products.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Scaled by 100 before the softmax over the candidate labels.
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # topk over all labels yields them sorted by descending probability.
    values, indices = similarity[0].topk(len(labels))
    return {
        labels[index]: value
        for value, index in zip(values.tolist(), indices.tolist())
    }


if __name__ == '__main__':
    # Module-level name read by classify() on every request.
    model_dict = load_model()

    inputs = [
        gr.inputs.Image(type="pil", label="Image"),
        gr.inputs.Textbox(lines=5, label="Caption"),
    ]
    outputs = ['label']

    title = "KELIP"

    if torch.cuda.is_available():
        demo_status = "Demo is running on GPU"
    else:
        demo_status = "Demo is running on CPU"
    # NOTE(review): this literal was broken across two physical lines in the
    # source; the break is kept here as an explicit newline escape — confirm
    # the intended separator.
    description = f"Details: paper_url. \n{demo_status}"

    examples = [
        ["squid_sundae.jpg", "오징어 순대,김밥,순대,떡볶이"],
        ["seokchon_lake.jpg", "평화의문,올림픽공원,롯데월드,석촌호수"],
        ["seokchon_lake.jpg", "봄,여름,가을,겨울"],
        ["hwangchil_tree.jpg", "황칠 나무 묘목,황칠 나무,난,소나무 묘목,야자수"],
    ]
    article = ""

    iface = gr.Interface(
        fn=classify,
        inputs=inputs,
        outputs=outputs,
        examples=examples,
        title=title,
        description=description,
        article=article,
    )
    iface.launch()