import csv
import json
import os
import time

import clip
import gradio as gr
import requests
import torch
import wget
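
# Pre-computed CLIP ViT-L/14 prompt classifiers (text embeddings) published in
# the original socratic-models-demo repository.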
url_dict = {'clip_ViTL14_openimage_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_openimage_classifier_weights.pt',
            'clip_ViTL14_place365_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_place365_classifier_weights.pt',
            'clip_ViTL14_tencentml_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_tencentml_classifier_weights.pt'}

os.makedirs('./prompts', exist_ok=True)
for filename, url in url_dict.items():
    out_path = os.path.join('./prompts', filename)
    # wget.download saves a duplicate ("file (1).pt") instead of overwriting,
    # so skip weights that are already on disk.
    if not os.path.exists(out_path):
        wget.download(url, out=out_path)
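
# Hide all GPUs so inference runs on CPU (the demo is hosted without a GPU).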
os.environ['CUDA_VISIBLE_DEVICES'] = ''
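
# BLOOM is queried through the Hugging Face Inference API; requests are
# authenticated with the HF_TOKEN environment variable.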
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
HF_TOKEN = os.environ["HF_TOKEN"]
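

# Parse the OpenImages class-description CSV; the human-readable class name is
# the last column of each row.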
def load_openimage_classnames(csv_path):
    with open(csv_path) as csv_data:
        csv_reader = csv.reader(csv_data)
        classnames = {idx: row[-1] for idx, row in enumerate(csv_reader)}
    return classnames
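

# Load one class name per line; this loader is reused for the Places365 list.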
def load_tencentml_classnames(txt_path):
    with open(txt_path) as txt_data:
        lines = txt_data.readlines()
    classnames = {idx: line.strip() for idx, line in enumerate(lines)}
    return classnames
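

# Build a zero-shot classifier on the fly: encode one templated prompt per
# class with CLIP's text encoder and L2-normalize, so classification reduces
# to a dot product between image and text embeddings.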
def build_simple_classifier(clip_model, text_list, template, device):
    with torch.no_grad():
        texts = [template(text) for text in text_list]
        text_inputs = clip.tokenize(texts).to(device)
        text_features = clip_model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    return text_features, {idx: text for idx, text in enumerate(text_list)}
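

# Load CLIP and every zero-shot classifier into a single global dictionary.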
def load_models():
    model_dict = {}

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('\tLoading CLIP ViT-L/14')
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)
    print('\tLoading precomputed zero-shot classifiers')
    # .float() casts the weights to float32 while keeping them on `device`.
    openimage_classifier_weights = torch.load('./prompts/clip_ViTL14_openimage_classifier_weights.pt', map_location=device).float()
    openimage_classnames = load_openimage_classnames('./prompts/openimage-classnames.csv')
    tencentml_classifier_weights = torch.load('./prompts/clip_ViTL14_tencentml_classifier_weights.pt', map_location=device).float()
    tencentml_classnames = load_tencentml_classnames('./prompts/tencent-ml-classnames.txt')
    place365_classifier_weights = torch.load('./prompts/clip_ViTL14_place365_classifier_weights.pt', map_location=device).float()
    place365_classnames = load_tencentml_classnames('./prompts/place365-classnames.txt')

    print('\tBuilding simple zero-shot classifiers')
    img_types = ['photo', 'cartoon', 'sketch', 'painting']
    ppl_texts = ['no people', 'people']
    ifppl_texts = ['is one person', 'are two people', 'are three people', 'are several people', 'are many people']
    imgtype_classifier_weights, imgtype_classnames = build_simple_classifier(clip_model, img_types, lambda c: f'This is a {c}.', device)
    ppl_classifier_weights, ppl_classnames = build_simple_classifier(clip_model, ppl_texts, lambda c: f'There are {c} in this photo.', device)
    ifppl_classifier_weights, ifppl_classnames = build_simple_classifier(clip_model, ifppl_texts, lambda c: f'There {c} in this photo.', device)

    model_dict['clip_model'] = clip_model
    model_dict['clip_preprocess'] = clip_preprocess
    model_dict['openimage_classifier_weights'] = openimage_classifier_weights
    model_dict['openimage_classnames'] = openimage_classnames
    model_dict['tencentml_classifier_weights'] = tencentml_classifier_weights
    model_dict['tencentml_classnames'] = tencentml_classnames
    model_dict['place365_classifier_weights'] = place365_classifier_weights
    model_dict['place365_classnames'] = place365_classnames
    model_dict['imgtype_classifier_weights'] = imgtype_classifier_weights
    model_dict['imgtype_classnames'] = imgtype_classnames
    model_dict['ppl_classifier_weights'] = ppl_classifier_weights
    model_dict['ppl_classnames'] = ppl_classnames
    model_dict['ifppl_classifier_weights'] = ifppl_classifier_weights
    model_dict['ifppl_classnames'] = ifppl_classnames
    model_dict['device'] = device

    return model_dict
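

# Move a tensor to host memory (if needed) and convert it to NumPy.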
def drop_gpu(tensor):
    # .cpu() is a no-op for tensors that already live in host memory.
    return tensor.cpu().numpy()
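

# classify_image scores an image embedding against one prompt classifier and
# returns the top-k (softmax score, class name) pairs, highest first.
# zeroshot_classifier below embeds the image once with CLIP and evaluates all
# six prompt classifiers against it.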
def classify_image(image_features, weights_key, classnames_key, topk=10):
    sim = (100.0 * image_features @ model_dict[weights_key].T).softmax(dim=-1)
    scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(topk)]
    classes = [model_dict[classnames_key][idx] for idx in indices]
    return scores, classes


def zeroshot_classifier(image):
    image_input = model_dict['clip_preprocess'](image).unsqueeze(0).to(model_dict['device'])
    with torch.no_grad():
        image_features = model_dict['clip_model'].encode_image(image_input)
        image_features /= image_features.norm(dim=-1, keepdim=True)

    openimage_scores, openimage_classes = classify_image(image_features, 'openimage_classifier_weights', 'openimage_classnames')
    tencentml_scores, tencentml_classes = classify_image(image_features, 'tencentml_classifier_weights', 'tencentml_classnames')
    place365_scores, place365_classes = classify_image(image_features, 'place365_classifier_weights', 'place365_classnames')
    imgtype_scores, imgtype_classes = classify_image(image_features, 'imgtype_classifier_weights', 'imgtype_classnames', topk=len(model_dict['imgtype_classnames']))
    ppl_scores, ppl_classes = classify_image(image_features, 'ppl_classifier_weights', 'ppl_classnames', topk=len(model_dict['ppl_classnames']))
    ifppl_scores, ifppl_classes = classify_image(image_features, 'ifppl_classifier_weights', 'ifppl_classnames', topk=len(model_dict['ifppl_classnames']))

    return (image_features, openimage_scores, openimage_classes, tencentml_scores, tencentml_classes,
            place365_scores, place365_classes, imgtype_scores, imgtype_classes,
            ppl_scores, ppl_classes, ifppl_scores, ifppl_classes)
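

# Assemble the Socratic prompt from the classifier outputs. With illustrative
# (hypothetical) classifier results it renders like:
#   I am an intelligent image captioning bot.
#   This image is a photo. There are two people.
#   I think this photo was taken at a beach, harbor, or pier.
#   I think there might be a surfboard, sea, sand in this photo.
#   A creative short caption I can generate to describe this image is: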
def generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes):
    img_type = imgtype_classes[0]
    ppl_result = ppl_classes[0]
    if ppl_result == 'people':
        # Refine "people" with the people-count classifier.
        ppl_result = ifppl_classes[0]
    else:
        ppl_result = 'are %s' % ppl_result

    sorted_places = place365_classes

    # Candidate objects: the top Tencent ML-Images classes plus the top-2
    # OpenImages classes.
    object_list = ', '.join(tencentml_classes + openimage_classes[:2])

    prompt_caption = f'''I am an intelligent image captioning bot.
This image is a {img_type}. There {ppl_result}.
I think this photo was taken at a {sorted_places[0]}, {sorted_places[1]}, or {sorted_places[2]}.
I think there might be a {object_list} in this {img_type}.
A creative short caption I can generate to describe this image is:'''

    return prompt_caption
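

# Ask BLOOM to continue the prompt via the Hugging Face Inference API: greedy
# decoding by default, nucleus sampling (top_p=0.7) when sample_or_greedy is
# set to "Sample".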
def generate_captions(prompt, num_captions=3):
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}

    max_length = 16
    seed = 42
    sample_or_greedy = 'Greedy'
    input_sentence = prompt
    parameters = {
        "max_new_tokens": max_length,
        "do_sample": sample_or_greedy == "Sample",
        "seed": seed,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    if sample_or_greedy == "Sample":
        parameters["top_p"] = 0.7

    payload = {"inputs": input_sentence, "parameters": parameters, "options": {"use_cache": False}}

    bloom_results = []
    for _ in range(num_captions):
        response = requests.post(API_URL, headers=headers, json=payload)
        output = response.json()
        if isinstance(output, dict) and 'error' in output:
            # The Inference API reports failures (e.g. model still loading)
            # as {"error": ...} instead of a list of generations.
            raise RuntimeError(output['error'])
        # Keep only the first sentence of the continuation.
        generated_text = output[0]['generated_text'].replace(prompt, '').split('.')[0] + '.'
        bloom_results.append(generated_text)
    return bloom_results
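

# Rank candidate captions by CLIP image-text similarity, best match first.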
def sorting_texts(image_features, captions):
    with torch.no_grad():
        text_inputs = clip.tokenize(captions).to(model_dict['device'])
        text_features = model_dict['clip_model'].encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    sim = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(len(captions))]
    sorted_captions = [captions[idx] for idx in indices]

    return scores, sorted_captions
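

# Pair each rounded score with its class/caption for the JSON output.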
def postprocess_results(scores, classes):
    scores = [round(float(val), 4) for val in scores]
    return [{'score': score, 'output': cls} for score, cls in zip(scores, classes)]
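

# Full pipeline: CLIP zero-shot reasoning -> prompt -> BLOOM caption -> CLIP
# re-ranking, timed per stage.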
def image_captioning(image):
    start_time = time.time()
    (image_features, openimage_scores, openimage_classes, tencentml_scores, tencentml_classes,
     place365_scores, place365_classes, imgtype_scores, imgtype_classes,
     ppl_scores, ppl_classes, ifppl_scores, ifppl_classes) = zeroshot_classifier(image)
    end_zeroshot = time.time()
    prompt_caption = generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes)
    generated_captions = generate_captions(prompt_caption, num_captions=1)
    end_bloom = time.time()
    caption_scores, sorted_captions = sorting_texts(image_features, generated_captions)

    output_dict = {}
    output_dict['inference_time'] = {'CLIP inference': end_zeroshot - start_time,
                                     'BLOOM request': end_bloom - end_zeroshot}
    output_dict['generated_captions'] = postprocess_results(caption_scores, sorted_captions)
    output_dict['reasoning'] = {'openimage_results': postprocess_results(openimage_scores, openimage_classes),
                                'tencentml_results': postprocess_results(tencentml_scores, tencentml_classes),
                                'place365_results': postprocess_results(place365_scores, place365_classes),
                                'imgtype_results': postprocess_results(imgtype_scores, imgtype_classes),
                                'ppl_results': postprocess_results(ppl_scores, ppl_classes),
                                'ifppl_results': postprocess_results(ifppl_scores, ifppl_classes)}
    return output_dict
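

# Entry point: load all models once at startup, then serve the Gradio demo.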
if __name__ == '__main__':
    print('\tinit models')
    model_dict = load_models()

    inputs = [gr.Image(type="pil", label="Image")]
    outputs = gr.JSON()

    title = "Socratic models for image captioning with BLOOM"

    description = """
## Details
**Without any fine-tuning**, we can do image captioning by combining visual-language models (e.g., CLIP, SLIP, ...) with large language models (e.g., GPT, BLOOM, ...).
In this demo, I choose BLOOM as the language model and CLIP ViT-L/14 as the visual-language model.
A caption is generated in three steps:
1. Classify whether there are people, where the location is, and what objects are in the input image using the visual-language model.
2. Build a prompt from the classification results.
3. Send the prompt to the BLOOM API.

This demo differs slightly from the method proposed in the Socratic Models paper:
I use not only the Tencent ML-Images class names but also the OpenImages class names, and I adopt BLOOM as the large language model.

If you want a demo that uses GPT-3 from OpenAI, see https://github.com/geonm/socratic-models-demo.

Demo is running on CPU.
"""

    article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.00598'>Socratic Models: Composing Zero-Shot Multimodal Reasoning with Language</a></p>"
    examples = ['k21-1.jpg']

    gr.Interface(image_captioning,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 article=article,
                 examples=examples,
                 ).launch()