Spaces:
Runtime error
Runtime error
import os | |
# Install the vendored packages into the *running* interpreter before they are
# imported below. `sys.executable -m pip` guarantees the install targets this
# process's environment (a bare `pip` on PATH may belong to another Python), and
# `check=True` makes a failed install stop start-up instead of silently continuing
# with a missing or stale package.
import subprocess
import sys

for _package_dir in ("open_flamingo", "transformers"):
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "."],
        cwd=_package_dir,
        check=True,
    )
import numpy as np | |
import torch | |
from PIL import Image | |
from open_flamingo.train.distributed import init_distributed_device, world_info_from_env | |
import string | |
import cv2 | |
import gradio as gr | |
import torch | |
from PIL import Image | |
from huggingface_hub import hf_hub_download, login | |
from open_flamingo.src.factory import create_model_and_transforms | |
# Build the OPT-350m Flamingo variant with visual-grounding (location) tokens.
# The factory returns the model, the image preprocessing transform, the tokenizer
# and the number of visual embedding slots per image.
flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
    "ViT-L-14",
    "datacomp_xl_s13b_b90k",
    "facebook/opt-350m",
    "facebook/opt-350m",
    add_visual_grounding=True,
    location_token_num=1000,
    add_visual_token=True,
    use_format_v2=True,
)

checkpoint_path = hf_hub_download("chendl/mm", "checkpoint_opt350m_v2.pt")
# NOTE(review): torch.load unpickles the file and can therefore execute arbitrary
# code; this is only acceptable because the checkpoint comes from a fixed Hub repo
# that must be trusted.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Drop the "module." prefix from parameter names (presumably added by a
# (Distributed)DataParallel wrapper during training).
model_state_dict = {key.replace("module.", ""): value for key, value in checkpoint.items()}

# Previous checkpoints carry some unnecessary vision-encoder weights that would make
# the strict load below fail. `pop(..., None)` tolerates checkpoints in which only
# some of these keys are present (the former `del` raised KeyError there).
_UNNECESSARY_VISION_KEYS = (
    "vision_encoder.logit_scale",
    "vision_encoder.visual.proj",
    "vision_encoder.visual.ln_post.weight",
    "vision_encoder.visual.ln_post.bias",
)
if "vision_encoder.logit_scale" in model_state_dict:
    for _unused_key in _UNNECESSARY_VISION_KEYS:
        model_state_dict.pop(_unused_key, None)

flamingo.load_state_dict(model_state_dict, strict=True)
def get_outputs(
    model,
    batch_images,
    attention_mask,
    max_generation_length,
    min_generation_length,
    num_beams,
    length_penalty,
    input_ids,
    image_start_index_list=None,
    image_nums=None,
    bad_words_ids=None,
):
    """Run ``model.generate`` without autograd and return only the new token ids.

    Args:
        model: object exposing ``generate(images, input_ids, **kwargs)``.
        batch_images: image tensor, passed positionally as the first argument.
        attention_mask: attention mask matching ``input_ids``.
        max_generation_length: forwarded as ``max_new_tokens``.
        min_generation_length: forwarded as ``min_length``.
        num_beams: beam width forwarded to ``generate``.
        length_penalty: length penalty forwarded to ``generate``.
        input_ids: prompt token ids; indexed as ``input_ids[0]``, so assumed to be
            (batch, seq) — TODO confirm.
        image_start_index_list: per-sample image start positions, forwarded as-is.
        image_nums: per-sample image counts, forwarded as-is.
        bad_words_ids: token-id sequences ``generate`` must not produce.

    Returns:
        The generated ids with the first ``len(input_ids[0])`` columns (the prompt)
        sliced off.
    """
    # NOTE(review): in Hugging Face ``generate`` ``min_length`` includes the prompt
    # tokens, so a small value likely has no effect here; ``min_new_tokens`` may have
    # been intended — confirm before changing.
    generation_kwargs = {
        "attention_mask": attention_mask,
        "max_new_tokens": max_generation_length,
        "min_length": min_generation_length,
        "num_beams": num_beams,
        "length_penalty": length_penalty,
        "image_start_index_list": image_start_index_list,
        "image_nums": image_nums,
        "bad_words_ids": bad_words_ids,
    }
    prompt_length = len(input_ids[0])
    with torch.inference_mode():
        generated = model.generate(batch_images, input_ids, **generation_kwargs)
        return generated[:, prompt_length:]
import functools


@functools.lru_cache(maxsize=1)
def _loc_token_index():
    """Return a dict mapping each ``<loc_i>`` token id to its bin index ``i`` (0-999).

    Building it costs 1000 tokenizer calls, so it is computed once and reused; it
    depends only on the module-level ``tokenizer``, which this file assigns once.
    """
    return {
        int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]): i
        for i in range(1000)
    }


def _build_prompt(idx, text, vis_embed_size, loc_index):
    """Return ``(prompt, bad_words_ids, max_new_tokens)`` for task ``idx``.

    The image is represented by ``vis_embed_size`` pad tokens between the image
    start/end markers. Grounding (``idx == 1``) appends ``<|#obj#|>`` ... ``<|#loc#|>``
    around the query with no banned tokens and a 5-token budget; every other task
    bans all location tokens and allows up to 100 new tokens.
    """
    image_part = f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>"
    query = text.rstrip('.')
    if idx == 1:
        return f"{image_part}<|#obj#|>{query}<|#loc#|>", None, 5
    return f"{image_part}{query}", [[token_id] for token_id in loc_index], 100


def _extract_box(token_ids, loc_index):
    """Return the bin indices of the first four location tokens in ``token_ids`` (fewer if absent)."""
    box = []
    for token_id in token_ids:
        bin_index = loc_index.get(int(token_id))
        if bin_index is None:
            continue
        box.append(bin_index)
        if len(box) == 4:
            break
    return box


def _draw_box(image, box):
    """Return an RGB copy of ``image`` with ``box`` drawn as a 2-px magenta rectangle.

    ``box`` holds two corners ``(x1, y1, x2, y2)`` in thousandths of the image
    width/height.
    """
    # Converting via PIL (rather than cv2.cvtColor on the raw array) also handles
    # grayscale / palette uploads, whose arrays are 2-D.
    rgb = image.convert("RGB")
    width, height = rgb.size
    canvas = np.array(rgb)
    top_left = (int(box[0] * width / 1000), int(box[1] * height / 1000))
    bottom_right = (int(box[2] * width / 1000), int(box[3] * height / 1000))
    # (255, 0, 255) is the same colour in RGB and BGR order, so no channel swap is needed.
    cv2.rectangle(canvas, top_left, bottom_right, color=(255, 0, 255), thickness=2)
    return Image.fromarray(canvas)


def generate(
    idx,
    image,
    text,
    vis_embed_size=256,
    rank=0,
    world_size=1,
):
    """Run one demo query on ``image`` and format the result for the Gradio UI.

    Args:
        idx: task selector used by the tabs: ``0`` = captioning / counting,
            ``1`` = grounding, ``2`` = question answering / custom.
        image: uploaded PIL image.
        text: user prompt; trailing periods are stripped before prompting.
        vis_embed_size: number of pad tokens reserved for the image in the prompt.
            NOTE(review): callers rely on the default 256; the ``vis_embed_size``
            returned by ``create_model_and_transforms`` is not used — confirm they agree.
        rank, world_size: unused; kept for interface compatibility.

    Returns:
        For ``idx == 1`` a ``(text, image)`` pair whose image has the predicted box
        drawn when four location tokens were generated (otherwise the input image
        unchanged). ``idx == 2`` returns ``"Question: ... Answer: ..."``; any other
        ``idx`` returns ``"Output:..."``.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    flamingo.eval()

    loc_index = _loc_token_index()
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    prompt, bad_words_ids, max_generation_length = _build_prompt(idx, text, vis_embed_size, loc_index)

    # The model sees a 224x224 RGB copy; the untouched upload is kept for drawing.
    model_image = image.convert("RGB").resize((224, 224))
    batch_images = image_processor(model_image).unsqueeze(0).unsqueeze(1).unsqueeze(0)

    encodings = tokenizer(
        [prompt],
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    # Positions just past each <|#image#|> token, one list per sample.
    media_positions = (input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1
    image_start_index_list = [[position] for position in media_positions.tolist()]
    image_nums = [1] * len(input_ids)

    outputs = get_outputs(
        model=flamingo,
        batch_images=batch_images,
        attention_mask=attention_mask,
        max_generation_length=max_generation_length,
        min_generation_length=4,
        num_beams=1,
        length_penalty=1.0,
        input_ids=input_ids,
        bad_words_ids=bad_words_ids,
        image_start_index_list=image_start_index_list,
        image_nums=image_nums,
    )

    out_image = image
    box = _extract_box(outputs[0], loc_index)
    if len(box) == 4:
        out_image = _draw_box(image, box)

    gen_text = tokenizer.batch_decode(outputs)
    if idx == 1:
        return f"Output:{gen_text}", out_image
    if idx == 2:
        return f"Question: {text.strip()} Answer: {gen_text}"
    return f"Output:{gen_text}"
def _add_single_image_tab(title, task_idx, input_label, note=None):
    """Add a tab with one image, one text box and a text output wired to ``generate(task_idx, ...)``.

    Must be called inside an active ``gr.Blocks`` context. ``note``, if given, is
    rendered as Markdown at the top of the tab.
    """
    with gr.Tab(title):
        if note is not None:
            gr.Markdown(note)
        with gr.Row():
            query_image = gr.Image(type="pil")
        with gr.Row():
            prompt_box = gr.Textbox(lines=1, label=input_label)
        text_output = gr.Textbox(value="Output:", label="Model output")
        run_btn = gr.Button("Run model")

        # Each helper call has its own `task_idx`, so there is no late-binding issue.
        def on_click_fn(img, txt):
            return generate(task_idx, img, txt)

        run_btn.click(on_click_fn, inputs=[query_image, prompt_box], outputs=[text_output])


with gr.Blocks() as demo:
    gr.Markdown(
        """
🍜 Object Centric Pretraining Demo
In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones presented to improve the demo experience.
The model is trained on an interleaved mixture of text, images and bounding box and is able to generate text conditioned on sequences of images/text.
"""
    )
    with gr.Accordion("See terms and conditions"):
        gr.Markdown(
            """**Please read the following information carefully before proceeding.** This demo does NOT store any personal information on its users, and it does NOT store user queries."""
        )

    _add_single_image_tab("📷 Image Captioning", 0, "Chat Input")

    # Grounding has a second image output (the input with the predicted box drawn).
    with gr.Tab("🦓 Grounding"):
        with gr.Row():
            with gr.Column(scale=1):
                grounding_image = gr.Image(type="pil")
            with gr.Column(scale=1):
                grounding_out_image = gr.Image(type="pil")
        with gr.Row():
            grounding_input = gr.Textbox(lines=1, label="Chat Input")
        grounding_output = gr.Textbox(value="Output:", label="Model output")
        grounding_btn = gr.Button("Run model")

        def on_grounding_click(img, text):
            return generate(1, img, text)

        grounding_btn.click(
            on_grounding_click,
            inputs=[grounding_image, grounding_input],
            outputs=[grounding_output, grounding_out_image],
        )

    _add_single_image_tab("🔢 Counting objects", 0, "Chat Input")
    _add_single_image_tab("🕵️ Visual Question Answering", 2, "Question")
    _add_single_image_tab(
        "🌎 Custom",
        2,
        "Question",
        note="""### Customize the demonstration by uploading your own images and text samples.
### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**""",
    )

# `queue(concurrency_count=...)` was removed in Gradio 4. Calling `queue()` with no
# arguments keeps the default of one worker on both Gradio 3.x and 4.x, so the
# single shared model is still used sequentially.
demo.queue()
demo.launch()