Geonmo's picture
fix typo
e839607
raw
history blame
13.4 kB
import os
import torch
import gradio as gr
import time
import clip
import requests
import csv
import json
import wget
# Precomputed CLIP ViT-L/14 zero-shot classifier weights. Each key is the local
# file name that load_models() later reads from ./prompts.
url_dict = {'clip_ViTL14_openimage_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_openimage_classifier_weights.pt',
            'clip_ViTL14_place365_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_place365_classifier_weights.pt',
            'clip_ViTL14_tencentml_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_tencentml_classifier_weights.pt'}
os.makedirs('./prompts', exist_ok=True)
for k, v in url_dict.items():
    # wget.download never overwrites: when the target already exists it saves a
    # fresh copy as "name (1).pt", "name (2).pt", ... Only fetch missing files,
    # otherwise every restart re-downloads the weights and litters ./prompts.
    if not os.path.exists(os.path.join('./prompts', k)):
        wget.download(v, out='./prompts')
# Hide every GPU so the demo runs on CPU. This is set after `import torch`; it
# presumably still takes effect because CUDA is initialised lazily — confirm if
# GPU behaviour ever matters.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
# Fails fast with KeyError when the token is not configured.
HF_TOKEN = os.environ["HF_TOKEN"]
def load_openimage_classnames(csv_path):
    """Load OpenImage class names from a CSV file.

    The last column of every row is the class name. It is keyed by the row's
    0-based position, which must line up with the rows of the matching
    classifier weight matrix. Rows are therefore never skipped: an empty row
    raises IndexError instead of silently shifting every later index.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        dict mapping row index (int) to class name (str).
    """
    # `with` guarantees the handle is closed (it used to leak). newline='' is
    # how the csv module expects files to be opened, so quoted fields that
    # contain newlines are parsed correctly.
    with open(csv_path, newline='', encoding='utf-8') as csv_data:
        return {idx: row[-1] for idx, row in enumerate(csv.reader(csv_data))}
def load_tencentml_classnames(txt_path):
    """Load one class name per line from a plain-text file.

    Each line is stripped of surrounding whitespace and keyed by its 0-based
    line number. That number must line up with the rows of the matching
    classifier weight matrix.

    Args:
        txt_path: Path to the text file.

    Returns:
        dict mapping line index (int) to class name (str).
    """
    # `with` guarantees the handle is closed (it used to leak); iterating the
    # file yields the same lines as readlines() without building a list first.
    with open(txt_path, encoding='utf-8') as txt_data:
        return {idx: line.strip() for idx, line in enumerate(txt_data)}
def build_simple_classifier(clip_model, text_list, template, device):
    """Build a zero-shot CLIP text classifier from a list of labels.

    Each label is rendered into a sentence by ``template``, encoded with the
    CLIP text encoder, and divided by its L2 norm. The dot product with an
    equally normalised image embedding then acts as a cosine similarity.

    Args:
        clip_model: Loaded CLIP model providing ``encode_text``.
        text_list: Labels to classify between.
        template: Callable mapping a label to the prompt sentence.
        device: Device the tokenised prompts are moved to.

    Returns:
        (text_features, classnames): one normalised embedding per label, and a
        dict mapping each embedding's row index back to its label.
    """
    prompts = list(map(template, text_list))
    with torch.no_grad():
        tokens = clip.tokenize(prompts).to(device)
        embeddings = clip_model.encode_text(tokens)
        embeddings /= embeddings.norm(dim=-1, keepdim=True)
    index_to_label = dict(enumerate(text_list))
    return embeddings, index_to_label
def load_models():
    """Load CLIP plus every zero-shot classifier the demo uses.

    Returns:
        dict holding:
          * the CLIP model and its preprocessing transform;
          * weights and index->name maps for the OpenImage, Tencent-ML and
            Places365 classifiers, read from ./prompts;
          * three small classifiers built on the fly from text prompts: image
            type, people / no people, and number of people;
          * the device string.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print('\tLoading CLIP ViT-L/14')
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)

    def _load_weights(path):
        # NOTE(review): `.type(torch.FloatTensor)` yields a float32 *CPU*
        # tensor whatever `device` is. That is fine while CUDA is hidden at
        # the top of the file, but it would misplace the weights on a GPU run.
        return torch.load(path, map_location=device).type(torch.FloatTensor)

    print('\tLoading precomputed zeroshot classifier')
    openimage_w = _load_weights('./prompts/clip_ViTL14_openimage_classifier_weights.pt')
    openimage_names = load_openimage_classnames('./prompts/openimage-classnames.csv')
    tencentml_w = _load_weights('./prompts/clip_ViTL14_tencentml_classifier_weights.pt')
    tencentml_names = load_tencentml_classnames('./prompts/tencent-ml-classnames.txt')
    place365_w = _load_weights('./prompts/clip_ViTL14_place365_classifier_weights.pt')
    # Places365 names are read with the same line-per-name text loader.
    place365_names = load_tencentml_classnames('./prompts/place365-classnames.txt')

    print('\tBuilding simple zeroshot classifier')
    imgtype_w, imgtype_names = build_simple_classifier(
        clip_model,
        ['photo', 'cartoon', 'sketch', 'painting'],
        lambda c: f'This is a {c}.',
        device)
    ppl_w, ppl_names = build_simple_classifier(
        clip_model,
        ['no people', 'people'],
        lambda c: f'There are {c} in this photo.',
        device)
    ifppl_w, ifppl_names = build_simple_classifier(
        clip_model,
        ['is one person', 'are two people', 'are three people',
         'are several people', 'are many people'],
        lambda c: f'There {c} in this photo.',
        device)

    return {
        'clip_model': clip_model,
        'clip_preprocess': clip_preprocess,
        'openimage_classifier_weights': openimage_w,
        'openimage_classnames': openimage_names,
        'tencentml_classifier_weights': tencentml_w,
        'tencentml_classnames': tencentml_names,
        'place365_classifier_weights': place365_w,
        'place365_classnames': place365_names,
        'imgtype_classifier_weights': imgtype_w,
        'imgtype_classnames': imgtype_names,
        'ppl_classifier_weights': ppl_w,
        'ppl_classnames': ppl_names,
        'ifppl_classifier_weights': ifppl_w,
        'ifppl_classnames': ifppl_names,
        'device': device,
    }
def drop_gpu(tensor):
    """Return ``tensor`` as a NumPy array in host memory.

    ``.cpu()`` is a no-op for tensors already on the CPU, so no device check
    is needed. ``.detach()`` lets tensors that track gradients be converted as
    well, since ``.numpy()`` refuses those.

    Args:
        tensor: A torch tensor on any device.

    Returns:
        numpy.ndarray with the tensor's values.
    """
    return tensor.detach().cpu().numpy()
def _topk_labels(image_features, weights, classnames, k):
    """Score ``image_features`` against one classifier and keep the best ``k``.

    Similarities are scaled by 100 and softmaxed over all classes.

    Returns:
        (scores, labels): the top-``k`` probabilities as a NumPy array, best
        first, and the matching names looked up in ``classnames``.
    """
    probs = (100.0 * image_features @ weights.T).softmax(dim=-1)
    top_scores, top_idx = (drop_gpu(t) for t in probs[0].topk(k))
    return top_scores, [classnames[i] for i in top_idx]


def zeroshot_classifier(image):
    """Run every CLIP zero-shot classifier from ``model_dict`` on one image.

    Args:
        image: Input image accepted by ``model_dict['clip_preprocess']``
            (the Gradio input is configured as PIL).

    Returns:
        A 13-tuple. The first element is the normalised CLIP image embedding.
        It is followed by (scores, classes) pairs for OpenImage (top 10),
        Tencent-ML (top 10), Places365 (top 10), image type, people / no
        people, and number of people. The last three rank all their classes.
    """
    md = model_dict
    pixels = md['clip_preprocess'](image).unsqueeze(0).to(md['device'])
    with torch.no_grad():
        feats = md['clip_model'].encode_image(pixels)
        feats /= feats.norm(dim=-1, keepdim=True)

        openimage = _topk_labels(feats, md['openimage_classifier_weights'],
                                 md['openimage_classnames'], 10)
        tencentml = _topk_labels(feats, md['tencentml_classifier_weights'],
                                 md['tencentml_classnames'], 10)
        place365 = _topk_labels(feats, md['place365_classifier_weights'],
                                md['place365_classnames'], 10)
        # The small hand-built classifiers rank every one of their classes.
        imgtype = _topk_labels(feats, md['imgtype_classifier_weights'],
                               md['imgtype_classnames'], len(md['imgtype_classnames']))
        ppl = _topk_labels(feats, md['ppl_classifier_weights'],
                           md['ppl_classnames'], len(md['ppl_classnames']))
        ifppl = _topk_labels(feats, md['ifppl_classifier_weights'],
                             md['ifppl_classnames'], len(md['ifppl_classnames']))
    return (feats, *openimage, *tencentml, *place365, *imgtype, *ppl, *ifppl)
def generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes):
    """Turn ranked zero-shot results into a BLOOM captioning prompt.

    Only the top image type and the top people label are used. When the
    people classifier's best label is 'people', the best count phrase (e.g.
    'are two people') replaces it; otherwise the label is prefixed with 'are'.
    The prompt lists the three best places, every Tencent-ML class, and the
    two best OpenImage classes.

    Each *_classes argument is a list of labels ordered best first.

    Returns:
        str: the multi-line prompt, ending with the caption cue.
    """
    img_type = imgtype_classes[0]
    ppl_result = ifppl_classes[0] if ppl_classes[0] == 'people' else 'are %s' % ppl_classes[0]
    objects = ', '.join(f'{c}' for c in [*tencentml_classes, *openimage_classes[:2]])
    places = place365_classes
    prompt_caption = '\n'.join((
        'I am an intelligent image captioning bot.',
        f'This image is a {img_type}. There {ppl_result}.',
        f'I think this photo was taken at a {places[0]}, {places[1]}, or {places[2]}.',
        f'I think there might be a {objects} in this {img_type}.',
        'A creative short caption I can generate to describe this image is:',
    ))
    # Alternative (unused) keyword-search prompt kept for reference:
    #prompt_search = f'''Let's list keywords that include the following description.
    #This image is a {img_type}. There {ppl_result}.
    #I think this photo was taken at a {sorted_places[0]}, {sorted_places[1]}, or {sorted_places[2]}.
    #I think there might be a {object_list} in this {img_type}.
    #Relevant keywords which we can list and are separated with comma are:'''
    return prompt_caption
def generate_captions(prompt, num_captions=3, timeout=120):
    """Ask the hosted BLOOM Inference API to continue ``prompt``.

    Sends ``num_captions`` identical requests. Decoding is fixed to greedy
    below, so repeated requests presumably return the same text.

    Args:
        prompt: Text prompt, e.g. the output of generate_prompt().
        num_captions: Number of requests to send (one caption each).
        timeout: Seconds to wait for each HTTP request. ``requests`` waits
            indefinitely when no timeout is given.

    Returns:
        list[str]: one caption per request. Each is the generated text with
        the prompt removed, cut at the first '.', and terminated with '.'.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        RuntimeError: if the body is not the expected
            ``[{"generated_text": ...}]`` list.
    """
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    max_length = 16
    seed = 42
    sample_or_greedy = 'Greedy'
    input_sentence = prompt
    if sample_or_greedy == "Sample":
        parameters = {
            "max_new_tokens": max_length,
            "top_p": 0.7,
            "do_sample": True,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }
    else:
        parameters = {
            "max_new_tokens": max_length,
            "do_sample": False,
            "seed": seed,
            "early_stopping": False,
            "length_penalty": 0.0,
            "eos_token_id": None,
        }
    payload = {"inputs": input_sentence, "parameters": parameters, "options": {"use_cache": False}}
    bloom_results = []
    for _ in range(num_captions):
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        # Surface API failures (bad token, model still loading, ...) as an
        # HTTP error instead of an opaque KeyError/TypeError on the JSON below.
        response.raise_for_status()
        output = response.json()
        try:
            full_text = output[0]['generated_text']
        except (KeyError, IndexError, TypeError) as err:
            raise RuntimeError(f'Unexpected BLOOM API response: {output!r}') from err
        # Keep only the first sentence of the continuation.
        generated_text = full_text.replace(prompt, '').split('.')[0] + '.'
        bloom_results.append(generated_text)
    return bloom_results
def sorting_texts(image_features, captions):
    """Rank candidate captions by CLIP similarity to the image.

    Args:
        image_features: Normalised CLIP image embedding, as returned first by
            zeroshot_classifier().
        captions: List of candidate caption strings.

    Returns:
        (scores, sorted_captions): the softmax probabilities over the captions
        (similarities scaled by 100), best first, as a NumPy array, and the
        captions reordered to match.
    """
    device = model_dict['device']
    encoder = model_dict['clip_model']
    with torch.no_grad():
        caption_emb = encoder.encode_text(clip.tokenize(captions).to(device))
        caption_emb /= caption_emb.norm(dim=-1, keepdim=True)
        logits = 100.0 * image_features @ caption_emb.T
        ranked_scores, ranked_idx = map(drop_gpu, logits.softmax(dim=-1)[0].topk(len(captions)))
    return ranked_scores, [captions[i] for i in ranked_idx]
def postprocess_results(scores, classes):
    """Pair each score with its label as a list of small dicts.

    Every score is rounded to 4 decimal places through '%.4f' formatting and
    turned into a plain Python float. All scores are converted before pairing;
    the pairs then stop at the shorter of the two sequences.

    Returns:
        list of {'score': float, 'output': label} dicts.
    """
    rounded = [float('%.4f' % float(s)) for s in scores]
    return [{'score': s, 'output': label} for s, label in zip(rounded, classes)]
def image_captioning(image):
    """Caption ``image`` using CLIP zero-shot reasoning plus a BLOOM completion.

    Pipeline: classify the image with CLIP, build a text prompt from the
    results, ask BLOOM for one caption, and re-rank the caption(s) with CLIP.

    Args:
        image: Input image (PIL, per the Gradio input below).

    Returns:
        dict with:
          * 'inference_time': seconds spent on CLIP classification, and on the
            BLOOM step (prompt building plus the API request);
          * 'generated_captions': captions with their CLIP re-ranking scores;
          * 'reasoning': the per-classifier scores/labels used for the prompt.
    """
    # perf_counter is monotonic. Unlike time.time(), the measured durations
    # cannot jump or go negative if the system clock is adjusted mid-request.
    start_time = time.perf_counter()
    (image_features,
     openimage_scores, openimage_classes,
     tencentml_scores, tencentml_classes,
     place365_scores, place365_classes,
     imgtype_scores, imgtype_classes,
     ppl_scores, ppl_classes,
     ifppl_scores, ifppl_classes) = zeroshot_classifier(image)
    end_zeroshot = time.perf_counter()
    prompt_caption = generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes)
    generated_captions = generate_captions(prompt_caption, num_captions=1)
    end_bloom = time.perf_counter()
    caption_scores, sorted_captions = sorting_texts(image_features, generated_captions)
    output_dict = {}
    output_dict['inference_time'] = {'CLIP inference': end_zeroshot - start_time,
                                     'BLOOM request': end_bloom - end_zeroshot}
    output_dict['generated_captions'] = postprocess_results(caption_scores, sorted_captions)
    output_dict['reasoning'] = {'openimage_results': postprocess_results(openimage_scores, openimage_classes),
                                'tencentml_results': postprocess_results(tencentml_scores, tencentml_classes),
                                'place365_results': postprocess_results(place365_scores, place365_classes),
                                'imgtype_results': postprocess_results(imgtype_scores, imgtype_classes),
                                'ppl_results': postprocess_results(ppl_scores, ppl_classes),
                                'ifppl_results': postprocess_results(ifppl_scores, ifppl_classes)}
    return output_dict
if __name__ == '__main__':
    print('\tinit models')
    # Assigned at module scope: zeroshot_classifier() and sorting_texts() read
    # `model_dict` as a global. (A `global` statement at module scope is a
    # no-op, so none is needed.)
    model_dict = load_models()
    # define gradio demo
    # NOTE(review): gr.inputs / gr.outputs are the legacy Gradio namespaces
    # (deprecated in 3.x, removed in 4.x); pin gradio<4 or migrate to
    # gr.Image / gr.JSON.
    inputs = [gr.inputs.Image(type="pil", label="Image")]
    outputs = gr.outputs.JSON()
    title = "Socratic models for image captioning with BLOOM"
    # Markdown rendered above the demo. The blank line after the numbered
    # list keeps the following paragraph from being folded into item 3.
    description = """
## Details
**Without any fine-tuning**, we can do image captioning using visual-language models (e.g., CLIP, SLIP, ...) and large language models (e.g., GPT, BLOOM, ...).
In this demo, I chose BLOOM as the language model and CLIP ViT-L/14 as the visual-language model.
An image caption is generated as follows:
1. Classify whether there are people, where the location is, and what objects are in the input image using the visual-language model.
2. Build a prompt from the classification results.
3. Send the prompt to the BLOOM API.

This demo differs slightly from the original method proposed in the Socratic Models paper:
I use OpenImage class names in addition to Tencent ML class names, and I adopt BLOOM as the large language model.
If you want a version of this demo that uses GPT-3 from OpenAI, check https://github.com/geonm/socratic-models-demo.
The demo runs on CPU.
"""
    article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.00598'>Socratic Models: Composing Zero-Shot Multimodal Reasoning with Language</a></p>"
    examples = ['k21-1.jpg']
    gr.Interface(image_captioning,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 article=article,
                 examples=examples,
                 #examples_per_page=50,
                 ).launch()