# Socratic-Models demo: zero-shot image captioning with CLIP (visual-language
# model) and BLOOM (large language model, via the Hugging Face Inference API).
import os
import torch
import gradio as gr
import time
import clip
import requests
import csv
import json
import wget
# Precomputed CLIP ViT-L/14 zero-shot classifier heads, fetched into ./prompts.
url_dict = {'clip_ViTL14_openimage_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_openimage_classifier_weights.pt',
            'clip_ViTL14_place365_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_place365_classifier_weights.pt',
            'clip_ViTL14_tencentml_classifier_weights.pt': 'https://raw.githubusercontent.com/geonm/socratic-models-demo/master/prompts/clip_ViTL14_tencentml_classifier_weights.pt'}
os.makedirs('./prompts', exist_ok=True)
for filename, url in url_dict.items():
    destination = os.path.join('./prompts', filename)
    # Skip files already on disk: re-running the app would otherwise make
    # wget.download save duplicates like "name (1).pt" next to the original.
    if not os.path.exists(destination):
        wget.download(url, out=destination)
# Force CPU inference (this Space has no GPU).
os.environ['CUDA_VISIBLE_DEVICES'] = ''
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
HF_TOKEN = os.environ["HF_TOKEN"]  # required secret; raises KeyError if unset
def load_openimage_classnames(csv_path):
    """Load OpenImage class names from a CSV file.

    Each row's last column is taken as the class name; keys are the 0-based
    row index (matching the classifier-weight row order).

    Args:
        csv_path: path to the class-names CSV.

    Returns:
        dict mapping row index (int) -> class name (str).
    """
    # `with` closes the handle (the original leaked it); newline='' is the
    # csv-module recommended way to open input files.
    with open(csv_path, newline='') as csv_data:
        return {idx: row[-1] for idx, row in enumerate(csv.reader(csv_data))}
def load_tencentml_classnames(txt_path):
    """Load class names from a plain-text file, one name per line.

    Also used for the Place365 class list (same one-name-per-line format).

    Args:
        txt_path: path to the class-names text file.

    Returns:
        dict mapping 0-based line index (int) -> stripped class name (str).
    """
    # `with` closes the handle (the original leaked it); iterating the file
    # directly avoids materializing all lines with readlines().
    with open(txt_path) as txt_data:
        return {idx: line.strip() for idx, line in enumerate(txt_data)}
def build_simple_classifier(clip_model, text_list, template, device):
    """Build a small zero-shot classifier head from a list of phrases.

    Each phrase is expanded through `template`, tokenized, and encoded with
    CLIP's text encoder; features are L2-normalized so a dot product with
    normalized image features gives cosine similarity.

    Args:
        clip_model: loaded CLIP model (provides encode_text).
        text_list: class phrases, e.g. ['no people', 'people'].
        template: callable str -> str that wraps a phrase into a prompt.
        device: torch device string the tokens are moved to.

    Returns:
        (text_features, classnames) where classnames maps index -> phrase.
    """
    prompts = [template(phrase) for phrase in text_list]
    with torch.no_grad():
        tokens = clip.tokenize(prompts).to(device)
        features = clip_model.encode_text(tokens)
        features = features / features.norm(dim=-1, keepdim=True)
    return features, dict(enumerate(text_list))
def load_models():
    """Load CLIP plus every zero-shot classifier head into one dict.

    Returns a dict with the CLIP model/preprocess, the precomputed
    OpenImage / Tencent-ML / Place365 heads, three small on-the-fly heads
    (image type, people presence, people count), and the device string.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print('\tLoading CLIP ViT-L/14')
    clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)

    print('\tLoading precomputed zeroshot classifier')

    def _load_weights(path):
        # Heads were exported as tensors; force float32 on the chosen device.
        return torch.load(path, map_location=device).type(torch.FloatTensor)

    print('\tBuilding simple zeroshot classifier')
    img_types = ['photo', 'cartoon', 'sketch', 'painting']
    ppl_texts = ['no people', 'people']
    ifppl_texts = ['is one person', 'are two people', 'are three people', 'are several people', 'are many people']
    imgtype_weights, imgtype_names = build_simple_classifier(clip_model, img_types, lambda c: f'This is a {c}.', device)
    ppl_weights, ppl_names = build_simple_classifier(clip_model, ppl_texts, lambda c: f'There are {c} in this photo.', device)
    ifppl_weights, ifppl_names = build_simple_classifier(clip_model, ifppl_texts, lambda c: f'There {c} in this photo.', device)

    return {
        'clip_model': clip_model,
        'clip_preprocess': clip_preprocess,
        'openimage_classifier_weights': _load_weights('./prompts/clip_ViTL14_openimage_classifier_weights.pt'),
        'openimage_classnames': load_openimage_classnames('./prompts/openimage-classnames.csv'),
        'tencentml_classifier_weights': _load_weights('./prompts/clip_ViTL14_tencentml_classifier_weights.pt'),
        'tencentml_classnames': load_tencentml_classnames('./prompts/tencent-ml-classnames.txt'),
        'place365_classifier_weights': _load_weights('./prompts/clip_ViTL14_place365_classifier_weights.pt'),
        'place365_classnames': load_tencentml_classnames('./prompts/place365-classnames.txt'),
        'imgtype_classifier_weights': imgtype_weights,
        'imgtype_classnames': imgtype_names,
        'ppl_classifier_weights': ppl_weights,
        'ppl_classnames': ppl_names,
        'ifppl_classifier_weights': ifppl_weights,
        'ifppl_classnames': ifppl_names,
        'device': device,
    }
def drop_gpu(tensor):
    """Move a tensor to host memory and return it as a numpy array.

    The original branched on torch.cuda.is_available(), which tests the
    machine rather than where *this* tensor lives. Tensor.cpu() is already
    a no-op for CPU-resident tensors, so the unconditional form is both
    simpler and correct for either device.
    """
    return tensor.cpu().numpy()
def _topk_classes(image_features, weights, classnames, k):
    """Score image features against one classifier head; return (scores, top-k names)."""
    sim = (100.0 * image_features @ weights.T).softmax(dim=-1)
    scores, indices = [drop_gpu(tensor) for tensor in sim[0].topk(k)]
    return scores, [classnames[idx] for idx in indices]


def zeroshot_classifier(image):
    """Run every zero-shot classifier head over one PIL image.

    Encodes the image with CLIP, then ranks it against the OpenImage,
    Tencent-ML and Place365 heads (top-10 each) and the three small heads
    (image type, people presence, people count; all classes ranked).

    Returns:
        (image_features, openimage_scores, openimage_classes,
         tencentml_scores, tencentml_classes, place365_scores,
         place365_classes, imgtype_scores, imgtype_classes,
         ppl_scores, ppl_classes, ifppl_scores, ifppl_classes)
    """
    image_input = model_dict['clip_preprocess'](image).unsqueeze(0).to(model_dict['device'])
    with torch.no_grad():
        image_features = model_dict['clip_model'].encode_image(image_input)
        image_features /= image_features.norm(dim=-1, keepdim=True)

    # The six heads all follow the same score-and-topk pattern; the original
    # inlined it six times.
    openimage_scores, openimage_classes = _topk_classes(
        image_features, model_dict['openimage_classifier_weights'], model_dict['openimage_classnames'], 10)
    tencentml_scores, tencentml_classes = _topk_classes(
        image_features, model_dict['tencentml_classifier_weights'], model_dict['tencentml_classnames'], 10)
    place365_scores, place365_classes = _topk_classes(
        image_features, model_dict['place365_classifier_weights'], model_dict['place365_classnames'], 10)
    imgtype_scores, imgtype_classes = _topk_classes(
        image_features, model_dict['imgtype_classifier_weights'], model_dict['imgtype_classnames'],
        len(model_dict['imgtype_classnames']))
    ppl_scores, ppl_classes = _topk_classes(
        image_features, model_dict['ppl_classifier_weights'], model_dict['ppl_classnames'],
        len(model_dict['ppl_classnames']))
    ifppl_scores, ifppl_classes = _topk_classes(
        image_features, model_dict['ifppl_classifier_weights'], model_dict['ifppl_classnames'],
        len(model_dict['ifppl_classnames']))

    return image_features, openimage_scores, openimage_classes, tencentml_scores, tencentml_classes,\
           place365_scores, place365_classes, imgtype_scores, imgtype_classes,\
           ppl_scores, ppl_classes, ifppl_scores, ifppl_classes
def generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes):
    """Compose the BLOOM captioning prompt from zero-shot classification results.

    Uses the top image type, a people phrase (count phrase when people were
    detected, otherwise "are no people"), the three most likely places, all
    Tencent-ML classes, and the top-2 OpenImage classes.
    """
    img_type = imgtype_classes[0]
    if ppl_classes[0] == 'people':
        ppl_result = ifppl_classes[0]
    else:
        ppl_result = 'are %s' % ppl_classes[0]
    sorted_places = place365_classes
    # Comma-join all Tencent-ML classes plus the two best OpenImage classes.
    object_list = ', '.join(list(tencentml_classes) + list(openimage_classes[:2]))
    prompt_caption = f'''I am an intelligent image captioning bot.
This image is a {img_type}. There {ppl_result}.
I think this photo was taken at a {sorted_places[0]}, {sorted_places[1]}, or {sorted_places[2]}.
I think there might be a {object_list} in this {img_type}.
A creative short caption I can generate to describe this image is:'''
    return prompt_caption
def generate_captions(prompt, num_captions=3):
    """Request caption completions for `prompt` from the hosted BLOOM API.

    Issues `num_captions` independent requests (cache disabled) and trims
    each completion to its first sentence.

    Returns:
        list of generated caption strings.
    """
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    sample_or_greedy = 'Greedy'  # hardcoded; the sampling branch is kept for easy switching
    parameters = {
        "max_new_tokens": 16,
        "seed": 42,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    if sample_or_greedy == "Sample":
        parameters.update({"top_p": 0.7, "do_sample": True})
    else:
        parameters["do_sample"] = False
    payload = {"inputs": prompt, "parameters": parameters, "options": {"use_cache": False}}

    bloom_results = []
    for _ in range(num_captions):
        response = requests.post(API_URL, headers=headers, json=payload)
        output = response.json()
        # Strip the echoed prompt, then keep only the first sentence.
        generated_text = output[0]['generated_text'].replace(prompt, '').split('.')[0] + '.'
        bloom_results.append(generated_text)
    return bloom_results
def sorting_texts(image_features, captions):
    """Rank candidate captions by CLIP similarity to the image.

    Args:
        image_features: L2-normalized CLIP image features (1 x dim).
        captions: candidate caption strings.

    Returns:
        (scores, sorted_captions) in descending similarity order.
    """
    with torch.no_grad():
        tokens = clip.tokenize(captions).to(model_dict['device'])
        caption_features = model_dict['clip_model'].encode_text(tokens)
        caption_features = caption_features / caption_features.norm(dim=-1, keepdim=True)
        sim = (100.0 * image_features @ caption_features.T).softmax(dim=-1)
    scores, order = [drop_gpu(tensor) for tensor in sim[0].topk(len(captions))]
    return scores, [captions[idx] for idx in order]
def postprocess_results(scores, classes):
    """Pair scores with class names for the JSON output.

    Args:
        scores: iterable of numeric scores (floats or numpy scalars).
        classes: iterable of class-name strings, aligned with `scores`.

    Returns:
        list of {'score': float rounded to 4 decimals, 'output': name} dicts.
    """
    # round() replaces the original's '%.4f' string round-trip; float() also
    # converts numpy scalars so the result is JSON-serializable.
    return [{'score': round(float(score), 4), 'output': cls}
            for score, cls in zip(scores, classes)]
def image_captioning(image):
    """Full captioning pipeline for one PIL image (the Gradio handler).

    Classifies the image with CLIP, builds a prompt, asks BLOOM for a
    caption, then re-ranks the caption(s) against the image.

    Returns:
        dict with 'inference_time', 'generated_captions', and the raw
        per-head 'reasoning' results.
    """
    t_start = time.time()
    (image_features, openimage_scores, openimage_classes, tencentml_scores,
     tencentml_classes, place365_scores, place365_classes, imgtype_scores,
     imgtype_classes, ppl_scores, ppl_classes, ifppl_scores,
     ifppl_classes) = zeroshot_classifier(image)
    t_zeroshot = time.time()

    prompt_caption = generate_prompt(openimage_classes, tencentml_classes, place365_classes, imgtype_classes, ppl_classes, ifppl_classes)
    generated_captions = generate_captions(prompt_caption, num_captions=1)
    t_bloom = time.time()

    caption_scores, sorted_captions = sorting_texts(image_features, generated_captions)
    return {
        'inference_time': {'CLIP inference': t_zeroshot - t_start,
                           'BLOOM request': t_bloom - t_zeroshot},
        'generated_captions': postprocess_results(caption_scores, sorted_captions),
        'reasoning': {
            'openimage_results': postprocess_results(openimage_scores, openimage_classes),
            'tencentml_results': postprocess_results(tencentml_scores, tencentml_classes),
            'place365_results': postprocess_results(place365_scores, place365_classes),
            'imgtype_results': postprocess_results(imgtype_scores, imgtype_classes),
            'ppl_results': postprocess_results(ppl_scores, ppl_classes),
            'ifppl_results': postprocess_results(ifppl_scores, ifppl_classes),
        },
    }
if __name__ == '__main__':
    print('\tinit models')
    # model_dict is read as a module-level global by the functions above; the
    # original's `global model_dict` statement was a no-op at module scope.
    model_dict = load_models()

    # Gradio demo definition.
    inputs = [gr.inputs.Image(type="pil", label="Image")]
    outputs = gr.outputs.JSON()
    title = "Socratic models for image captioning with BLOOM"
    description = """
## Details
**Without any fine-tuning**, we can do image captioning using Visual-Language models (e.g., CLIP, SLIP, ...) and Large language models (e.g., GPT, BLOOM, ...).
In this demo, I choose BLOOM as the language model and CLIP ViT-L/14 as the visual-language model.
The order of generating image caption is as follows:
1. Classify whether there are people, where the location is, and what objects are in the input image using the visual-language model.
2. Then, build a prompt using classified results.
3. Request BLOOM API with the prompt.
This demo is slightly different from the original method proposed in the socratic model paper.
I used not only tencent ml class names, but also OpenImage class names and I adopted BLOOM as the large language model.
If you want the demo using GPT3 from OpenAI, check https://github.com/geonm/socratic-models-demo.
Demo is running on CPU.
"""
    article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2204.00598'>Socratic Models: Composing Zero-Shot Multimodal Reasoning with Language</a></p>"
    examples = ['k21-1.jpg']
    gr.Interface(image_captioning,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 article=article,
                 examples=examples,
                 ).launch()