import os
import torch
import gradio as gr
import time
import clip
import requests
import csv
import json
import wget
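
# Fetch the precomputed CLIP ViT-L/14 zero-shot classifier weights into ./prompts.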
url_dict = {
    'clip_ViTL14_openimage_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_openimage_classifier_weights.pt',
    'clip_ViTL14_place365_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_place365_classifier_weights.pt',
    'clip_ViTL14_tencentml_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_tencentml_classifier_weights.pt',
}

os.makedirs('./prompts', exist_ok=True)
for k, v in url_dict.items():
    wget.download(v, out='./prompts')

# Run on CPU only (the hosted demo has no GPU); the HF token authenticates requests
# to the hosted BLOOM Inference API.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
HF_TOKEN = os.environ["HF_TOKEN"]
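
# Class-name loaders: each returns a dict mapping class index -> class name.
# The index order is assumed to match the row order of the precomputed classifier weights.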
def load_openimage_classnames(csv_path):
    with open(csv_path) as csv_data:
        csv_reader = csv.reader(csv_data)
        classnames = {idx: row[-1] for idx, row in enumerate(csv_reader)}
    return classnames

def load_tencentml_classnames(txt_path):
    with open(txt_path) as txt_data:
        lines = txt_data.readlines()
    classnames = {idx: line.strip() for idx, line in enumerate(lines)}
    return classnames
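
# Build a small zero-shot text classifier on the fly: encode one templated prompt per
# class with CLIP's text encoder and L2-normalize the embeddings.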
def build_simple_classifier(clip_model, text_list, template, device):
    with torch.no_grad():
        texts = [template(text) for text in text_list]
        text_inputs = clip.tokenize(texts).to(device)
        text_features = clip_model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features, {idx: text for idx, text in enumerate(text_list)}
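
# Load CLIP ViT-L/14, the precomputed OpenImages / Tencent ML-Images / Places365
# classifier weights, and three small on-the-fly classifiers (image type,
# people present, number of people).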
def load_models():
    # build model and tokenizer
    model_dict = {}

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print('\tLoading CLIP ViT-L/14')
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)

    print('\tLoading precomputed zeroshot classifier')
    openimage_classifier_weights = torch.load('./prompts/clip_ViTL14_openimage_classifier_weights.pt', map_location=device).type(torch.FloatTensor)
    openimage_classnames = load_openimage_classnames('./prompts/openimage-classnames.csv')
    tencentml_classifier_weights = torch.load('./prompts/clip_ViTL14_tencentml_classifier_weights.pt', map_location=device).type(torch.FloatTensor)
    tencentml_classnames = load_tencentml_classnames('./prompts/tencent-ml-classnames.txt')
    place365_classifier_weights = torch.load('./prompts/clip_ViTL14_place365_classifier_weights.pt', map_location=device).type(torch.FloatTensor)
    place365_classnames = load_tencentml_classnames('./prompts/place365-classnames.txt')

    print('\tBuilding simple zeroshot classifier')
    img_types = ['photo', 'cartoon', 'sketch', 'painting']
    ppl_texts = ['no people', 'people']
    ifppl_texts = ['is one person', 'are two people', 'are three people', 'are several people', 'are many people']
    imgtype_classifier_weights, imgtype_classnames = build_simple_classifier(clip_model, img_types, lambda c: f'This is a {c}.', device)
    ppl_classifier_weights, ppl_classnames = build_simple_classifier(clip_model, ppl_texts, lambda c: f'There are {c} in this photo.', device)
    ifppl_classifier_weights, ifppl_classnames = build_simple_classifier(clip_model, ifppl_texts, lambda c: f'There {c} in this photo.', device)

    model_dict['clip_model'] = clip_model
    model_dict['clip_preprocess'] = clip_preprocess
    model_dict['openimage_classifier_weights'] = openimage_classifier_weights
    model_dict['openimage_classnames'] = openimage_classnames
    model_dict['tencentml_classifier_weights'] = tencentml_classifier_weights
    model_dict['tencentml_classnames'] = tencentml_classnames
    model_dict['place365_classifier_weights'] = place365_classifier_weights
    model_dict['place365_classnames'] = place365_classnames
    model_dict['imgtype_classifier_weights'] = imgtype_classifier_weights
    model_dict['imgtype_classnames'] = imgtype_classnames
    model_dict['ppl_classifier_weights'] = ppl_classifier_weights
    model_dict['ppl_classnames'] = ppl_classnames
    model_dict['ifppl_classifier_weights'] = ifppl_classifier_weights
    model_dict['ifppl_classnames'] = ifppl_classnames
    model_dict['device'] = device

    return model_dict
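
# Move a tensor to the CPU (if a GPU is in use) and convert it to a NumPy array.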
def drop_gpu(tensor):
    if torch.cuda.is_available():
        return tensor.cpu().numpy()
    else:
        return tensor.numpy()
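
# Zero-shot classification with CLIP: embed the image, take the similarity with each
# set of classifier weights (scaled by 100 and softmaxed), and keep the top-k classes.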
def zeroshot_classifier(image):
    image_input = model_dict['clip_preprocess'](image).unsqueeze(0).to(model_dict['device'])

    with torch.no_grad():
        image_features = model_dict['clip_model'].encode_image(image_input)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        sim = (100.0 * image_features @ model_dict['openimage_classifier_weights'].T).softmax(dim=-1)
        openimage_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(10)]
        openimage_classes = [model_dict['openimage_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['tencentml_classifier_weights'].T).softmax(dim=-1)
        tencentml_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(10)]
        tencentml_classes = [model_dict['tencentml_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['place365_classifier_weights'].T).softmax(dim=-1)
        place365_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(10)]
        place365_classes = [model_dict['place365_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['imgtype_classifier_weights'].T).softmax(dim=-1)
        imgtype_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(len(model_dict['imgtype_classnames']))]
        imgtype_classes = [model_dict['imgtype_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['ppl_classifier_weights'].T).softmax(dim=-1)
        ppl_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(len(model_dict['ppl_classnames']))]
        ppl_classes = [model_dict['ppl_classnames'][idx] for idx in indices]

        sim = (100.0 * image_features @ model_dict['ifppl_classifier_weights'].T).softmax(dim=-1)
        ifppl_scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(len(model_dict['ifppl_classnames']))]
        ifppl_classes = [model_dict['ifppl_classnames'][idx] for idx in indices]

    return image_features, openimage_scores, openimage_classes, tencentml_scores, tencentml_classes,\
        place365_scores, place365_classes, imgtype_scores, imgtype_classes,\
        ppl_scores, ppl_classes, ifppl_scores, ifppl_classes
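
# Assemble the Socratic prompt from the zero-shot results: image type, people count,
# the three most likely places, and the detected objects.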
def generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes):
    img_type = imgtype_classes[0]
    ppl_result = ppl_classes[0]
    if ppl_result == 'people':
        ppl_result = ifppl_classes[0]
    else:
        ppl_result = 'are %s' % ppl_result

    sorted_places = place365_classes

    object_list = ''
    for cls in tencentml_classes:
        object_list += f'{cls}, '
    for cls in openimage_classes[:2]:
        object_list += f'{cls}, '
    object_list = object_list[:-2]

    prompt_caption = f'''I am an intelligent image captioning bot.
This image is a {img_type}. There {ppl_result}.
I think this photo was taken at a {sorted_places[0]}, {sorted_places[1]}, or {sorted_places[2]}.
I think there might be a {object_list} in this {img_type}.
A creative short caption I can generate to describe this image is:'''

    #prompt_search = f'''Let's list keywords that include the following description.
    #This image is a {img_type}. There {ppl_result}.
    #I think this photo was taken at a {sorted_places[0]}, {sorted_places[1]}, or {sorted_places[2]}.
    #I think there might be a {object_list} in this {img_type}.
    #Relevant keywords which we can list and are separated with comma are:'''

    return prompt_caption
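
# Query the hosted BLOOM Inference API. Greedy decoding with at most 16 new tokens;
# the prompt is stripped from the response and the continuation is cut at the first period.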
def generate_captions(prompt, num_captions=3):
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    max_length = 16
    seed = 42
    sample_or_greedy = 'Greedy'
    input_sentence = prompt

    if sample_or_greedy == "Sample":
        parameters = {
            "max_new_tokens": max_length,
            "top_p": 0.7,
            "do_sample": True,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }
    else:
        parameters = {
            "max_new_tokens": max_length,
            "do_sample": False,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }

    payload = {"inputs": input_sentence, "parameters": parameters, "options": {"use_cache": False}}

    bloom_results = []
    for _ in range(num_captions):
        response = requests.post(API_URL, headers=headers, json=payload)
        output = response.json()
        generated_text = output[0]['generated_text'].replace(prompt, '').split('.')[0] + '.'
        bloom_results.append(generated_text)

    return bloom_results
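
# Rerank candidate captions by CLIP image-text similarity (highest-scoring caption first).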
def sorting_texts(image_features, captions):
    with torch.no_grad():
        text_inputs = clip.tokenize(captions).to(model_dict['device'])
        text_features = model_dict['clip_model'].encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        sim = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(len(captions))]
        sorted_captions = [captions[idx] for idx in indices]
    return scores, sorted_captions

def postprocess_results(scores, classes):
    scores = [float('%.4f' % float(val)) for val in scores]

    outputs = []
    for score, cls in zip(scores, classes):
        outputs.append({'score': score, 'output': cls})
    return outputs
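
# End-to-end pipeline used by the Gradio demo: CLIP zero-shot reasoning -> prompt ->
# BLOOM caption -> CLIP reranking, returned as a JSON-style dict with timings.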
def image_captioning(image):
    start_time = time.time()
    image_features, openimage_scores, openimage_classes, tencentml_scores, tencentml_classes, place365_scores, place365_classes, imgtype_scores, imgtype_classes, ppl_scores, ppl_classes, ifppl_scores, ifppl_classes = zeroshot_classifier(image)
    end_zeroshot = time.time()

    prompt_caption = generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes)
    generated_captions = generate_captions(prompt_caption, num_captions=1)
    end_bloom = time.time()

    caption_scores, sorted_captions = sorting_texts(image_features, generated_captions)

    output_dict = {}
    output_dict['inference_time'] = {'CLIP inference': end_zeroshot - start_time,
                                     'BLOOM request': end_bloom - end_zeroshot}
    output_dict['generated_captions'] = postprocess_results(caption_scores, sorted_captions)
    output_dict['reasoning'] = {'openimage_results': postprocess_results(openimage_scores, openimage_classes),
                                'tencentml_results': postprocess_results(tencentml_scores, tencentml_classes),
                                'place365_results': postprocess_results(place365_scores, place365_classes),
                                'imgtype_results': postprocess_results(imgtype_scores, imgtype_classes),
                                'ppl_results': postprocess_results(ppl_scores, ppl_classes),
                                'ifppl_results': postprocess_results(ifppl_scores, ifppl_classes)}
    return output_dict
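
# Load the models once at startup, then launch the Gradio interface.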
if __name__ == '__main__':
    print('\tinit models')
    model_dict = load_models()

    # define gradio demo
    inputs = [gr.Image(type="pil", label="Image")]
    outputs = gr.JSON()

    title = "Socratic models for image captioning with BLOOM"
description = """ | |
## Details | |
**Without any fine-tuning**, we can do image captioning using Visual-Language models (e.g., CLIP, SLIP, ...) and Large language models (e.g., GPT, BLOOM, ...). | |
In this demo, I choose BLOOM as the language model and CLIP ViT-L/14 as the visual-language model. | |
The order of generating image caption is as follow: | |
1. Classify whether there are people, where the location is, and what objects are in the input image using the visual-language model. | |
2. Then, build a prompt using classified results. | |
3. Request BLOOM API with the prompt. | |
This demo is slightly different with the original method proposed in the socratic model paper. | |
I used not only tencent ml class names, but also OpenImage class names and I adopt BLOOM for the large language model | |
If you want the demo using GPT3 from OpenAI, check https://github.com/geonm/socratic-models-demo. | |
Demo is running on CPU. | |
""" | |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.00598'>Socratic Models: Composing Zero-Shot Multimodal Reasoning with Language</a></p>" | |
examples = ['k21-1.jpg'] | |
gr.Interface(image_captioning, | |
inputs, | |
outputs, | |
title=title, | |
description=description, | |
article=article, | |
examples=examples, | |
#examples_per_page=50, | |
).launch() | |