diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..0937192c9758f640ea61641884bf4fc87977e780 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +test_img/img18.jpg filter=lfs diff=lfs merge=lfs -text +test_img/img22.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/DejaVuSansCondensed-Bold.ttf b/DejaVuSansCondensed-Bold.ttf new file mode 100644 index 0000000000000000000000000000000000000000..437f2f5c0e5b44efcb98f31c1a279458b354fb83 Binary files /dev/null and b/DejaVuSansCondensed-Bold.ttf differ diff --git a/Image/demo1.svg b/Image/demo1.svg new file mode 100644 index 0000000000000000000000000000000000000000..2e81e5f1d350194acd34d72cbfc6908b2cbbe33d --- /dev/null +++ b/Image/demo1.svg @@ -0,0 +1 @@ +Click PromptthereisachairthatissittingonawoodenfloorEnglish, Positive,Facutal桌子前有一把椅子Chinese, Positive,FacutalthereisawhitecabinetwithabasketontopofitEnglish, Positive,FacutalTheblandwhitecabinetistoppedwithanunremarkablebasket.English, Negative,FacutalthereisaplantthatissittinginaplotonthefloorEnglish, Positive,FacutalAlovelyplantsitsinapotonthefloor,addingatouchofnaturetotheroom.English, Positive,Imagination \ No newline at end of file diff --git a/Image/demo2.svg b/Image/demo2.svg new file mode 100644 index 0000000000000000000000000000000000000000..ecf37926ee287c39590c4ea2ba9972edf32137f1 --- /dev/null +++ b/Image/demo2.svg @@ -0,0 +1 @@ +thereisabrownbearthatisstandingupwithitsheadturnedthereisabrownbearsittingonablacksurfacethereisabrownbearthatislayingdownonthegroundthereisalargebearandasmallbearsittingtogetherthereareagroupofbearsthataresittingtogether \ No newline at end of file diff --git a/Image/title.svg b/Image/title.svg new file mode 100644 index 
0000000000000000000000000000000000000000..87fcc5fe890c431b9ea6488172b5539b6a959695 --- /dev/null +++ b/Image/title.svg @@ -0,0 +1 @@ +Captioning Anything \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..23f07469823f9e9eb0df3c19b788f1ec828c4928 --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +BSD 3-Clause License + +Copyright (c) 2023, Teng Wang + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/app_old.py b/app_old.py new file mode 100644 index 0000000000000000000000000000000000000000..ddfdf1cb04ac3da31e1e46bbfea6a8b21eebd808 --- /dev/null +++ b/app_old.py @@ -0,0 +1,261 @@ +from io import BytesIO +import string +import gradio as gr +import requests +from caas import CaptionAnything +import torch +import json +import sys +import argparse +from caas import parse_augment +import os + +# download sam checkpoint if not downloaded +def download_checkpoint(url, folder, filename): + os.makedirs(folder, exist_ok=True) + filepath = os.path.join(folder, filename) + + if not os.path.exists(filepath): + response = requests.get(url, stream=True) + with open(filepath, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + return filepath +checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" +folder = "segmenter" +filename = "sam_vit_h_4b8939.pth" + +title = """

Caption-Anything

""" +description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. +
Code: GitHub repo: +""" + +examples = [ + ["test_img/img2.jpg", "[[1000, 700, 1]]"] +] + +args = parse_augment() + +def get_prompt(chat_input, click_state): + points = click_state[0] + labels = click_state[1] + inputs = json.loads(chat_input) + for input in inputs: + points.append(input[:2]) + labels.append(input[2]) + + prompt = { + "prompt_type":["click"], + "input_point":points, + "input_label":labels, + "multimask_output":"True", + } + return prompt + +def inference_seg_cap(image_input, chat_input, language, sentiment, factuality, length, state, click_state): + controls = {'length': length, + 'sentiment': sentiment, + 'factuality': factuality, + 'language': language} + prompt = get_prompt(chat_input, click_state) + print('prompt: ', prompt, 'controls: ', controls) + out = model.inference(image_input, prompt, controls) + state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))] + for k, v in out['generated_captions'].items(): + state = state + [(f'{k}: {v}', None)] + click_state[2].append(out['generated_captions']['raw_caption']) + image_output_mask = out['mask_save_path'] + image_output_crop = out['crop_save_path'] + return state, state, click_state, image_output_mask, image_output_crop + + +def upload_callback(image_input, state): + state = state + [('Image size: ' + str(image_input.size), None)] + return state + +# get coordinate in format [[x,y,positive/negative]] +def get_select_coords(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData): + print("point_prompt: ", point_prompt) + if point_prompt == 'Positive Point': + coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1])) + else: + coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1])) + return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state) + +def chat_with_points(chat_input, click_state, 
state): + points, labels, captions = click_state + point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: " + # "The image is of width {width} and height {height}." + + prev_visual_context = "" + pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1] + prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n' + chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input}) + response = model.text_refiner.llm(chat_prompt) + state = state + [(chat_input, response)] + return state, state + +def init_openai_api_key(api_key): + os.environ['OPENAI_API_KEY'] = api_key + global model + model = CaptionAnything(args) + +css=''' +#image_upload{min-height:200px} +#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 200px} +''' + +with gr.Blocks(css=css) as iface: + state = gr.State([]) + click_state = gr.State([[],[],[]]) + caption_state = gr.State([[]]) + gr.Markdown(title) + gr.Markdown(description) + + with gr.Column(): + openai_api_key = gr.Textbox( + placeholder="Input your openAI API key and press Enter", + show_label=False, + lines=1, + type="password", + ) + openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key]) + + with gr.Row(): + with gr.Column(scale=0.7): + image_input = gr.Image(type="pil", interactive=True, label="Image", elem_id="image_upload").style(height=260,scale=1.0) + + with gr.Row(scale=0.7): + point_prompt = gr.Radio( + choices=["Positive Point", "Negative Point"], + value="Positive 
Point", + label="Points", + interactive=True, + ) + + # with gr.Row(): + language = gr.Radio( + choices=["English", "Chinese", "French", "Spanish", "Arabic", "Portuguese","Cantonese"], + value="English", + label="Language", + interactive=True, + ) + sentiment = gr.Radio( + choices=["Positive", "Natural", "Negative"], + value="Natural", + label="Sentiment", + interactive=True, + ) + factuality = gr.Radio( + choices=["Factual", "Imagination"], + value="Factual", + label="Factuality", + interactive=True, + ) + length = gr.Slider( + minimum=5, + maximum=100, + value=10, + step=1, + interactive=True, + label="Length", + ) + + with gr.Column(scale=1.5): + with gr.Row(): + image_output_mask= gr.Image(type="pil", interactive=False, label="Mask").style(height=260,scale=1.0) + image_output_crop= gr.Image(type="pil", interactive=False, label="Cropped Image by Mask", show_progress=False).style(height=260,scale=1.0) + chatbot = gr.Chatbot(label="Chat Output",).style(height=450,scale=0.5) + + with gr.Row(): + with gr.Column(scale=0.7): + prompt_input = gr.Textbox(lines=1, label="Input Prompt (A list of points like : [[100, 200, 1], [200,300,0]])") + prompt_input.submit( + inference_seg_cap, + [ + image_input, + prompt_input, + language, + sentiment, + factuality, + length, + state, + click_state + ], + [chatbot, state, click_state, image_output_mask, image_output_crop], + show_progress=False + ) + + image_input.upload( + upload_callback, + [image_input, state], + [chatbot] + ) + + with gr.Row(): + clear_button = gr.Button(value="Clear Click", interactive=True) + clear_button.click( + lambda: ("", [[], [], []], None, None), + [], + [prompt_input, click_state, image_output_mask, image_output_crop], + queue=False, + show_progress=False + ) + + clear_button = gr.Button(value="Clear", interactive=True) + clear_button.click( + lambda: ("", [], [], [[], [], []], None, None), + [], + [prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop], + queue=False, + 
show_progress=False + ) + + submit_button = gr.Button( + value="Submit", interactive=True, variant="primary" + ) + submit_button.click( + inference_seg_cap, + [ + image_input, + prompt_input, + language, + sentiment, + factuality, + length, + state, + click_state + ], + [chatbot, state, click_state, image_output_mask, image_output_crop], + show_progress=False + ) + + # select coordinate + image_input.select( + get_select_coords, + inputs=[image_input,point_prompt,language,sentiment,factuality,length,state,click_state], + outputs=[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop], + show_progress=False + ) + + image_input.change( + lambda: ("", [], [[], [], []]), + [], + [chatbot, state, click_state], + queue=False, + ) + + with gr.Column(scale=1.5): + chat_input = gr.Textbox(lines=1, label="Chat Input") + chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state]) + + + examples = gr.Examples( + examples=examples, + inputs=[image_input, prompt_input], + ) + +iface.queue(concurrency_count=1, api_open=False, max_size=10) +iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share) diff --git a/caas.py b/caas.py new file mode 100644 index 0000000000000000000000000000000000000000..a2ac7c91b24682a7cd255b99eba2729e29d3aa0f --- /dev/null +++ b/caas.py @@ -0,0 +1,114 @@ +from captioner import build_captioner, BaseCaptioner +from segmenter import build_segmenter +from text_refiner import build_text_refiner +import os +import argparse +import pdb +import time +from PIL import Image + +class CaptionAnything(): + def __init__(self, args): + self.args = args + self.captioner = build_captioner(args.captioner, args.device, args) + self.segmenter = build_segmenter(args.segmenter, args.device, args) + if not args.disable_gpt: + self.init_refiner() + + + def init_refiner(self): + if os.environ.get('OPENAI_API_KEY', None): + self.text_refiner = 
build_text_refiner(self.args.text_refiner, self.args.device, self.args) + + def inference(self, image, prompt, controls, disable_gpt=False): + # segment with prompt + print("CA prompt: ", prompt, "CA controls",controls) + seg_mask = self.segmenter.inference(image, prompt)[0, ...] + mask_save_path = f'result/mask_{time.time()}.png' + if not os.path.exists(os.path.dirname(mask_save_path)): + os.makedirs(os.path.dirname(mask_save_path)) + new_p = Image.fromarray(seg_mask.astype('int') * 255.) + if new_p.mode != 'RGB': + new_p = new_p.convert('RGB') + new_p.save(mask_save_path) + print('seg_mask path: ', mask_save_path) + print("seg_mask.shape: ", seg_mask.shape) + # captioning with mask + if self.args.enable_reduce_tokens: + caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box) + else: + caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box) + # refining with TextRefiner + context_captions = [] + if self.args.context_captions: + context_captions.append(self.captioner.inference(image)) + if not disable_gpt and hasattr(self, "text_refiner"): + refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions) + else: + refined_caption = {'raw_caption': caption} + out = {'generated_captions': refined_caption, + 'crop_save_path': crop_save_path, + 'mask_save_path': mask_save_path, + 'context_captions': context_captions} + return out + +def parse_augment(): + parser = argparse.ArgumentParser() + parser.add_argument('--captioner', type=str, default="blip") + parser.add_argument('--segmenter', type=str, default="base") + parser.add_argument('--text_refiner', type=str, default="base") + parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth") + 
parser.add_argument('--seg_crop_mode', type=str, default="w_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning") + parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions") + parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption") + parser.add_argument('--regular_box', action="store_true", default = False, help="crop image with a regular box") + parser.add_argument('--device', type=str, default="cuda:0") + parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications") + parser.add_argument('--debug', action="store_true") + parser.add_argument('--gradio_share', action="store_true") + parser.add_argument('--disable_gpt', action="store_true") + parser.add_argument('--enable_reduce_tokens', action="store_true", default=False) + parser.add_argument('--disable_reuse_features', action="store_true", default=False) + args = parser.parse_args() + + if args.debug: + print(args) + return args + +if __name__ == "__main__": + args = parse_augment() + # image_path = 'test_img/img3.jpg' + image_path = 'test_img/img13.jpg' + prompts = [ + { + "prompt_type":["click"], + "input_point":[[500, 300], [1000, 500]], + "input_label":[1, 0], + "multimask_output":"True", + }, + { + "prompt_type":["click"], + "input_point":[[900, 800]], + "input_label":[1], + "multimask_output":"True", + } + ] + controls = { + "length": "30", + "sentiment": "positive", + # "imagination": "True", + "imagination": "False", + "language": "English", + } + + model = CaptionAnything(args) + for prompt in prompts: + print('*'*30) + print('Image path: ', image_path) + image = Image.open(image_path) + print(image) + print('Visual controls (SAM prompt):\n', prompt) + print('Language controls:\n', controls) + out = model.inference(image_path, prompt, controls) + + \ No newline at end of file diff --git 
a/captioner/README.md b/captioner/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e9b387fade5888f6f4330aecfc0d1cdbb1c51703 --- /dev/null +++ b/captioner/README.md @@ -0,0 +1,13 @@ +To run BLIP/BLIP2, you should install transformers from source! +``` +!pip install git+https://github.com/huggingface/transformers.git +``` +To run filter module, you should install CLIP repo as a Python package as follow: +``` +!pip install ftfy regex tqdm +!pip install git+https://github.com/openai/CLIP.git +``` +To accelerate BLIP2 with int8, you should install accelerate +``` +!pip install accelerate bitsandbytes +``` diff --git a/captioner/__init__.py b/captioner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70952876daafe2a9479aa90b3845e58de6376711 --- /dev/null +++ b/captioner/__init__.py @@ -0,0 +1,15 @@ +from .blip import BLIPCaptioner +from .blip2 import BLIP2Captioner +from .git import GITCaptioner +from .base_captioner import BaseCaptioner + + +def build_captioner(type, device, args=None): + if type == 'blip': + return BLIPCaptioner(device, enable_filter=args.clip_filter) + elif type == 'blip2': + return BLIP2Captioner(device, enable_filter=args.clip_filter) + elif type == 'git': + return GITCaptioner(device, enable_filter=args.clip_filter) + else: + raise NotImplementedError("") \ No newline at end of file diff --git a/captioner/base_captioner.py b/captioner/base_captioner.py new file mode 100644 index 0000000000000000000000000000000000000000..cb442d3cab7f67c3120a5e7ddc27deba097be177 --- /dev/null +++ b/captioner/base_captioner.py @@ -0,0 +1,199 @@ +import torch +from PIL import Image, ImageDraw, ImageOps +from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering +import json +import pdb +import cv2 +import numpy as np +from typing import Union +import time +import clip + +def boundary(inputs): + + col = inputs.shape[1] + inputs = inputs.reshape(-1) + lens = 
len(inputs) + + for i in range(lens): + if inputs[i] != False: + break + for j in range(lens): + if inputs[lens - 1 - j] != False: + break + start = i + end = lens - 1 - j + top = start // col + bottom = end // col + + return top, bottom + +def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]): + + if type(seg_mask) == str: + seg_mask = Image.open(seg_mask) + elif type(seg_mask) == np.ndarray: + seg_mask = Image.fromarray(seg_mask) + seg_mask = np.array(seg_mask) > 0 + size = max(seg_mask.shape[0], seg_mask.shape[1]) + top, bottom = boundary(seg_mask) + left, right = boundary(seg_mask.T) + return [left / size, top / size, right / size, bottom / size] + +def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]): + if type(seg_mask) == str: + seg_mask = cv2.imread(seg_mask, cv2.IMREAD_GRAYSCALE) + _, seg_mask = cv2.threshold(seg_mask, 127, 255, 0) + elif type(seg_mask) == np.ndarray: + assert seg_mask.ndim == 2 # only support single-channel segmentation mask + seg_mask = seg_mask.astype('uint8') + if seg_mask.dtype == 'bool': + seg_mask = seg_mask * 255 + contours, hierarchy = cv2.findContours(seg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours = np.concatenate(contours, axis=0) + rect = cv2.minAreaRect(contours) + box = cv2.boxPoints(rect) + if rect[-1] >= 45: + newstart = box.argmin(axis=0)[1] # leftmost + else: + newstart = box.argmax(axis=0)[0] # topmost + box = np.concatenate([box[newstart:], box[:newstart]], axis=0) + box = np.int0(box) + return box + +def get_w_h(rect_points): + w = np.linalg.norm(rect_points[0] - rect_points[1], ord=2).astype('int') + h = np.linalg.norm(rect_points[0] - rect_points[3], ord=2).astype('int') + return w, h + +def cut_box(img, rect_points): + w, h = get_w_h(rect_points) + dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0],], dtype="float32") + transform = cv2.getPerspectiveTransform(rect_points.astype("float32"), dst_pts) + cropped_img = cv2.warpPerspective(img, transform, (h, w)) + return cropped_img 
+ +class BaseCaptioner: + def __init__(self, device, enable_filter=False): + print(f"Initializing ImageCaptioning to {device}") + self.device = device + self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 + self.processor = None + self.model = None + self.enable_filter = enable_filter + if enable_filter: + self.filter, self.preprocess = clip.load('ViT-B/32', device) + self.threshold = 0.2 + + @torch.no_grad() + def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str): + + if type(image) == str: # input path + image = Image.open(image) + elif type(image) == np.ndarray: + image = Image.fromarray(image) + + image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224) + text = clip.tokenize(caption).to(self.device) # (1, 77) + image_features = self.filter.encode_image(image) # (1, 512) + text_features = self.filter.encode_text(text) # (1, 512) + image_features /= image_features.norm(dim = -1, keepdim = True) + text_features /= text_features.norm(dim = -1, keepdim = True) + similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item() + if similarity < self.threshold: + print('There seems to be nothing where you clicked.') + out = "" + else: + out = caption + print(f'Clip score of the caption is {similarity}') + return out + + + def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool=False): + raise NotImplementedError() + + def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool=False): + raise NotImplementedError() + + def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False): + if type(image) == str: # input path + image = Image.open(image) + elif type(image) == np.ndarray: + image = Image.fromarray(image) + + if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners + size = max(image.width, 
image.height) + x1, y1, x2, y2 = box + image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size))) + elif np.array(box).size == 8: # four corners of an irregular rectangle + image_crop = cut_box(np.array(image), box) + + crop_save_path = f'result/crop_{time.time()}.png' + Image.fromarray(image_crop).save(crop_save_path) + print(f'croped image saved in {crop_save_path}') + caption = self.inference(image_crop, filter) + return caption, crop_save_path + + + def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", filter=False, regular_box = False): + if type(image) == str: + image = Image.open(image) + if type(seg_mask) == str: + seg_mask = Image.open(seg_mask) + elif type(seg_mask) == np.ndarray: + seg_mask = Image.fromarray(seg_mask) + seg_mask = seg_mask.resize(image.size) + seg_mask = np.array(seg_mask) > 0 + + if crop_mode=="wo_bg": + image = np.array(image) * seg_mask[:,:,np.newaxis] + else: + image = np.array(image) + + if regular_box: + min_area_box = new_seg_to_box(seg_mask) + else: + min_area_box = seg_to_box(seg_mask) + return self.inference_box(image, min_area_box, filter) + + + def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", regular_box = False): + if type(image) == str: + image = Image.open(image) + if type(seg_mask) == str: + seg_mask = Image.open(seg_mask) + elif type(seg_mask) == np.ndarray: + seg_mask = Image.fromarray(seg_mask) + seg_mask = seg_mask.resize(image.size) + seg_mask = np.array(seg_mask) > 0 + + if crop_mode=="wo_bg": + image = np.array(image) * seg_mask[:,:,np.newaxis] + else: + image = np.array(image) + + if regular_box: + box = new_seg_to_box(seg_mask) + else: + box = seg_to_box(seg_mask) + + if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners + size = max(image.shape[0], image.shape[1]) + x1, y1, x2, y2 = 
box + image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size))) + elif np.array(box).size == 8: # four corners of an irregular rectangle + image_crop = cut_box(np.array(image), box) + crop_save_path = f'result/crop_{time.time()}.png' + Image.fromarray(image_crop).save(crop_save_path) + print(f'croped image saved in {crop_save_path}') + return crop_save_path + + +if __name__ == '__main__': + model = BaseCaptioner(device='cuda:0') + image_path = 'test_img/img2.jpg' + seg_mask = np.zeros((15,15)) + seg_mask[5:10, 5:10] = 1 + seg_mask = 'image/SAM/img10.jpg.raw_mask.png' + print(model.inference_seg(image_path, seg_mask)) + \ No newline at end of file diff --git a/captioner/blip.py b/captioner/blip.py new file mode 100644 index 0000000000000000000000000000000000000000..35ed20b3f85bc3de9fab56d2c1a2fe98aba7ca31 --- /dev/null +++ b/captioner/blip.py @@ -0,0 +1,66 @@ +import torch +from PIL import Image, ImageDraw, ImageOps +from transformers import BlipProcessor +from .modeling_blip import BlipForConditionalGeneration +import json +import pdb +import cv2 +import numpy as np +from typing import Union +from .base_captioner import BaseCaptioner +import torchvision.transforms.functional as F + + +class BLIPCaptioner(BaseCaptioner): + def __init__(self, device, enable_filter=False): + super().__init__(device, enable_filter) + self.device = device + self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 + self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large") + self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=self.torch_dtype).to(self.device) + + @torch.no_grad() + def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False): + if type(image) == str: # input path + image = Image.open(image) + inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype) + out = self.model.generate(**inputs, 
max_new_tokens=50) + captions = self.processor.decode(out[0], skip_special_tokens=True) + if self.enable_filter and filter: + captions = self.filter_caption(image, captions) + print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}") + return captions + + @torch.no_grad() + def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, regular_box = False): + crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, regular_box=regular_box) + if type(image) == str: # input path + image = Image.open(image) + inputs = self.processor(image, return_tensors="pt") + pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype) + _, _, H, W = pixel_values.shape + seg_mask = Image.fromarray(seg_mask.astype(float)) + seg_mask = seg_mask.resize((H, W)) + seg_mask = F.pil_to_tensor(seg_mask) > 0.5 + seg_mask = seg_mask.float() + pixel_masks = seg_mask.unsqueeze(0).to(self.device) + out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50) + captions = self.processor.decode(out[0], skip_special_tokens=True) + if self.enable_filter and filter: + captions = self.filter_caption(image, captions) + print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}") + return captions, crop_save_path + + +if __name__ == '__main__': + model = BLIPCaptioner(device='cuda:0') + # image_path = 'test_img/img2.jpg' + image_path = '/group/30042/wybertwang/project/woa_visgpt/chatARC/image/SAM/img10.jpg' + seg_mask = np.zeros((15,15)) + seg_mask[5:10, 5:10] = 1 + seg_mask = 'test_img/img10.jpg.raw_mask.png' + image_path = 'test_img/img2.jpg' + seg_mask = 'test_img/img2.jpg.raw_mask.png' + print(f'process image {image_path}') + print(model.inference_with_reduced_tokens(image_path, seg_mask)) + \ No newline at end of file diff --git a/captioner/blip2.py b/captioner/blip2.py new file mode 100644 index 
0000000000000000000000000000000000000000..1d85fd0950604feebfb9a5038673a5ba69d513e8 --- /dev/null +++ b/captioner/blip2.py @@ -0,0 +1,55 @@ +import torch +from PIL import Image, ImageDraw, ImageOps +from transformers import AutoProcessor, Blip2ForConditionalGeneration +import json +import pdb +import cv2 +import numpy as np +from typing import Union +from .base_captioner import BaseCaptioner + +class BLIP2Captioner(BaseCaptioner): + def __init__(self, device, dialogue: bool = False, enable_filter: bool = False): + super().__init__(device, enable_filter) + self.device = device + self.dialogue = dialogue + self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 + self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") + self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map = 'sequential', load_in_8bit=True) + @torch.no_grad() + def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False): + if type(image) == str: # input path + image = Image.open(image) + + if not self.dialogue: + inputs = self.processor(image, text = 'Ignore the black background! This is a photo of ', return_tensors="pt").to(self.device, self.torch_dtype) + out = self.model.generate(**inputs, max_new_tokens=50) + captions = self.processor.decode(out[0], skip_special_tokens=True) + if self.enable_filter and filter: + captions = self.filter_caption(image, captions) + print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}") + return captions + else: + context = [] + template = "Question: {} Answer: {}." 
+ while(True): + input_texts = input() + if input_texts == 'end': + break + prompt = " ".join([template.format(context[i][0], context[i][1]) for i in range(len(context))]) + " Question: " + input_texts + " Answer:" + inputs = self.processor(image, text = prompt, return_tensors="pt").to(self.device, self.torch_dtype) + out = self.model.generate(**inputs, max_new_tokens=50) + captions = self.processor.decode(out[0], skip_special_tokens=True).strip() + context.append((input_texts, captions)) + + return captions + +if __name__ == '__main__': + + dialogue = False + model = BLIP2Captioner(device='cuda:4', dialogue = dialogue, cache_dir = '/nvme-ssd/fjj/Caption-Anything/model_cache') + image_path = 'test_img/img2.jpg' + seg_mask = np.zeros((224,224)) + seg_mask[50:200, 50:200] = 1 + print(f'process image {image_path}') + print(model.inference_seg(image_path, seg_mask)) \ No newline at end of file diff --git a/captioner/git.py b/captioner/git.py new file mode 100644 index 0000000000000000000000000000000000000000..6694ad30dea801567b7cd05dc1f5488df9becf4e --- /dev/null +++ b/captioner/git.py @@ -0,0 +1,57 @@ +from transformers import GitProcessor, AutoProcessor +from .modeling_git import GitForCausalLM +from PIL import Image +import torch +from .base_captioner import BaseCaptioner +import numpy as np +from typing import Union +import torchvision.transforms.functional as F + + +class GITCaptioner(BaseCaptioner): + def __init__(self, device, enable_filter=False): + super().__init__(device, enable_filter) + self.device = device + self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 + self.processor = AutoProcessor.from_pretrained("microsoft/git-large") + self.model = GitForCausalLM.from_pretrained("microsoft/git-large", torch_dtype=self.torch_dtype).to(self.device) + + @torch.no_grad() + def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False): + if type(image) == str: # input path + image = Image.open(image) + pixel_values = 
self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype) + generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50) + generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + if self.enable_filter and filter: + captions = self.filter_caption(image, captions) + print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {generated_caption}") + return generated_caption + + @torch.no_grad() + def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, regular_box = False): + crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, regular_box=regular_box) + if type(image) == str: # input path + image = Image.open(image) + inputs = self.processor(images=image, return_tensors="pt") + pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype) + _, _, H, W = pixel_values.shape + seg_mask = Image.fromarray(seg_mask.astype(float)) + seg_mask = seg_mask.resize((H, W)) + seg_mask = F.pil_to_tensor(seg_mask) > 0.5 + seg_mask = seg_mask.float() + pixel_masks = seg_mask.unsqueeze(0).to(self.device) + out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50) + captions = self.processor.decode(out[0], skip_special_tokens=True) + if self.enable_filter and filter: + captions = self.filter_caption(image, captions) + print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}") + return captions, crop_save_path + +if __name__ == '__main__': + model = GITCaptioner(device='cuda:2', enable_filter=False) + image_path = 'test_img/img2.jpg' + seg_mask = np.zeros((224,224)) + seg_mask[50:200, 50:200] = 1 + print(f'process image {image_path}') + print(model.inference_with_reduced_tokens(image_path, seg_mask)) \ No newline at end of file diff --git a/captioner/modeling_blip.py b/captioner/modeling_blip.py new 
file mode 100644 index 0000000000000000000000000000000000000000..2b2303b05fa885ce425e2249ff97765cfe38ac20 --- /dev/null +++ b/captioner/modeling_blip.py @@ -0,0 +1,1476 @@ +# coding=utf-8 +# Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch BLIP model.""" + +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn.functional import normalize + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.blip.configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig +from transformers.models.blip.modeling_blip_text import BlipTextLMHeadModel, BlipTextModel +from .vit_pixel_masks_utils import ViTPatchMaskGenerator + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "Salesforce/blip-vqa-base" + +BLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "Salesforce/blip-vqa-base", + "Salesforce/blip-vqa-capfit-large", + "Salesforce/blip-image-captioning-base", + "Salesforce/blip-image-captioning-large", + "Salesforce/blip-itm-base-coco", + 
"Salesforce/blip-itm-large-coco", + "Salesforce/blip-itm-base-flikr", + "Salesforce/blip-itm-large-flikr", + # See all BLIP models at https://huggingface.co/models?filter=blip +] + + +# Copied from transformers.models.clip.modeling_clip.contrastive_loss +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip +def blip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class BlipForConditionalGenerationModelOutput(ModelOutput): + """ + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder. + + Args: + loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Languge modeling loss from the text decoder. + decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): + Prediction scores of the language modeling head of the text decoder model. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*): + The image embeddings obtained after applying the Vision Transformer model to the input image. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[Tuple[torch.FloatTensor]] = None + decoder_logits: Optional[Tuple[torch.FloatTensor]] = None + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class BlipTextVisionModelOutput(ModelOutput): + """ + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Languge modeling loss from the text decoder. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class BlipImageTextMatchingModelOutput(ModelOutput): + """ + Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the + last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity + scores. + + Args: + itm_score (`torch.FloatTensor`): + The image-text similarity scores. + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Languge modeling loss from the text decoder. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*): + Last layer hidden-state of the vision of the vision-only branch of the model. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + question_embeds (`torch.FloatTensor`): + The question embeddings obtained by the text projection layer. + """ + + itm_score: Optional[torch.FloatTensor] = None + loss: Optional[torch.FloatTensor] = None + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + vision_pooler_output: Optional[torch.FloatTensor] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + question_embeds: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class BlipOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. 
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`BlipTextModel`]. + vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`BlipVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class BlipVisionEmbeddings(nn.Module): + def __init__(self, config: BlipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.randn(1, 1, self.embed_dim), + ) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: 
torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip +class BlipTextEmbeddings(nn.Module): + def __init__(self, config: BlipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class BlipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // 
self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = nn.Dropout(config.attention_dropout) + + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim) + + self.projection = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + mixed_qkv = self.qkv(hidden_states) + mixed_qkv = ( + self.qkv(hidden_states) + .reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + query_states, key_states, value_states = ( + mixed_qkv[0], + mixed_qkv[1], + mixed_qkv[2], + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + + attention_scores = attention_scores * self.scale + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) + + new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,) + context_layer = context_layer.reshape(new_context_layer_shape) + + output = self.projection(context_layer) + + outputs = (output, attention_probs) if output_attentions else (output, None) + + return outputs + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip +class BlipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class BlipEncoderLayer(nn.Module): + def __init__(self, config: BlipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = BlipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = BlipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + head_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class BlipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BlipConfig + base_model_prefix = "blip" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_range + if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=factor) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + + if isinstance(module, BlipVisionEmbeddings): + if hasattr(self.config, "vision_config"): + factor = self.config.vision_config.initializer_range + nn.init.trunc_normal_( + module.position_embedding, + mean=0.0, + std=factor, + ) + + nn.init.trunc_normal_( + module.class_embedding, + mean=0.0, + std=factor, + ) + + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if 
isinstance(module, BlipEncoder): + module.gradient_checkpointing = value + + +BLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`BlipConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +BLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. 
+ + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +BLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +BLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoProcessor`]. See [`BlipProcessor.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`BlipImageProcessor`]. See [`BlipImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class BlipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`BlipEncoderLayer`]. + + Args: + config (`BlipConfig`): + The corresponding vision configuration for the `BlipEncoder`. 
+ """ + + def __init__(self, config: BlipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([BlipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class BlipVisionModel(BlipPreTrainedModel): + main_input_name = "pixel_values" + config_class = BlipVisionConfig + + def __init__(self, config: BlipVisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + self.embeddings = BlipVisionEmbeddings(config) + self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size) + self.encoder = BlipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + 
@add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=BlipVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_masks: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + B, N, D = hidden_states.shape + # print('Before mask:', hidden_states.shape) + if pixel_masks is not None: + assert pixel_masks.shape[0] == 1 + patch_masks = self.patch_mask_generator(pixel_masks) + # print(patch_masks.shape) + patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states) + hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D) + # print('After mask:', hidden_states.shape) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + 
hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings + + +@add_start_docstrings(BLIP_START_DOCSTRING) +class BlipModel(BlipPreTrainedModel): + config_class = BlipConfig + + def __init__(self, config: BlipConfig): + super().__init__(config) + + if not isinstance(config.text_config, BlipTextConfig): + raise ValueError( + "config.text_config is expected to be of type BlipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, BlipVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type BlipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.projection_dim = config.projection_dim + self.text_embed_dim = text_config.hidden_size + self.vision_embed_dim = vision_config.hidden_size + + self.text_model = BlipTextModel(text_config) + self.vision_model = BlipVisionModel(vision_config) + + self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) + self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) + self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`BlipTextModel`]. 
+ + Examples: + + ```python + >>> from transformers import AutoProcessor, BlipModel + + >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + >>> text_features = model.get_text_features(**inputs) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features + + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`BlipVisionModel`]. 
+ + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipModel + + >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> image_features = model.get_image_features(**inputs) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] # pooled_output + image_features = self.visual_projection(pooled_output) + + return image_features + + @add_start_docstrings_to_model_forward(BLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BlipOutput, config_class=BlipConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_masks: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BlipOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipModel + + >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor( + 
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True + ... ) + + >>> outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score + >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities + ```""" + # Use BLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + pixel_masks=pixel_masks, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + image_embeds = self.visual_projection(image_embeds) + + text_embeds = text_outputs[1] + text_embeds = self.text_projection(text_embeds) + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + loss = blip_loss(logits_per_text) + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss 
is not None else output + + return BlipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +@add_start_docstrings( + """ + BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass + `input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. Otherwise, + the decoder starts generating text from the [BOS] (beginning-of-sequence) token. will start generating the caption + from the text input. If no text input is provided, the decoder will start with the [BOS] token only. + """, + BLIP_START_DOCSTRING, +) +class BlipForConditionalGeneration(BlipPreTrainedModel): + config_class = BlipConfig + _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"] + main_input_name = "pixel_values" + + def __init__(self, config: BlipConfig): + super().__init__(config) + + self.vision_model = BlipVisionModel(config.vision_config) + + self.text_decoder = BlipTextLMHeadModel(config.text_config) + + self.decoder_input_ids = config.text_config.bos_token_id + self.decoder_pad_token_id = config.text_config.pad_token_id + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BlipForConditionalGenerationModelOutput, config_class=BlipVisionConfig) + def forward( + self, + pixel_values: torch.FloatTensor, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = 
None, + ) -> Union[Tuple, BlipForConditionalGenerationModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "A picture of" + + >>> inputs = processor(images=image, text=text, return_tensors="pt") + + >>> outputs = model(**inputs) + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + + outputs = self.text_decoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + labels=labels, + return_dict=return_dict, + reduction="mean", + ) + + if not return_dict: + outputs = (outputs[0], outputs[1], image_embeds, vision_outputs[0]) + vision_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return BlipForConditionalGenerationModelOutput( + loss=outputs.loss, + decoder_logits=outputs.logits, + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + pixel_values: torch.FloatTensor, + pixel_masks: torch.Tensor = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + **generate_kwargs, + ) -> torch.LongTensor: + r""" + Overrides *generate* function to be able to use the 
model as a conditional generator + + Parameters: + pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*: + Input image to be processed + input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*): + The sequence used as a prompt for the generation. + attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForConditionalGeneration + + >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model.generate(**inputs) + >>> print(processor.decode(outputs[0], skip_special_tokens=True)) + two cats are laying on a couch + ``` + """ + + batch_size = pixel_values.shape[0] + vision_outputs = self.vision_model( + pixel_values=pixel_values, + pixel_masks=pixel_masks, + ) + + image_embeds = vision_outputs[0] + + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) + + if isinstance(input_ids, list): + input_ids = torch.LongTensor(input_ids) + elif input_ids is None: + input_ids = ( + torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]]) + .repeat(batch_size, 1) + .to(image_embeds.device) + ) + + input_ids[:, 0] = self.config.text_config.bos_token_id + attention_mask = attention_mask[:, :-1] if attention_mask is not None else None + + outputs = self.text_decoder.generate( + input_ids=input_ids[:, :-1], + eos_token_id=self.config.text_config.sep_token_id, + 
pad_token_id=self.config.text_config.pad_token_id, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + **generate_kwargs, + ) + + return outputs + + +@add_start_docstrings( + """ + BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text + decoder. The vision encoder will encode the input image, the text encoder will encode the input question together + with the encoding of the image, and the text decoder will output the answer to the question. + """, + BLIP_START_DOCSTRING, +) +class BlipForQuestionAnswering(BlipPreTrainedModel): + config_class = BlipConfig + _keys_to_ignore_on_load_missing = [r"text_decoder.cls.predictions.decoder.bias"] + + def __init__(self, config: BlipConfig): + super().__init__(config) + + self.vision_model = BlipVisionModel(config.vision_config) + + self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False) + + self.text_decoder = BlipTextLMHeadModel(config.text_config) + + self.decoder_pad_token_id = config.text_config.pad_token_id + self.decoder_start_token_id = config.text_config.bos_token_id + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + # Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right + def _shift_right(self, input_ids): + pad_token_id = self.decoder_pad_token_id + + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.decoder_start_token_id + + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + 
@replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig) + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BlipTextVisionModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForQuestionAnswering + + >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> # training + >>> text = "How many cats are in the picture?" + >>> label = "2" + >>> inputs = processor(images=image, text=text, return_tensors="pt") + >>> labels = processor(text=label, return_tensors="pt").input_ids + + >>> inputs["labels"] = labels + >>> outputs = model(**inputs) + >>> loss = outputs.loss + >>> loss.backward() + + >>> # inference + >>> text = "How many cats are in the picture?" + >>> inputs = processor(images=image, text=text, return_tensors="pt") + >>> outputs = model.generate(**inputs) + >>> print(processor.decode(outputs[0], skip_special_tokens=True)) + 2 + ```""" + if labels is None and decoder_input_ids is None: + raise ValueError( + "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with" + " `BlipForQuestionAnswering`. 
if you are training the model make sure that `labels` is passed, if you" + " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`" + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long) + + question_embeds = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=return_dict, + ) + + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + if labels is not None and decoder_input_ids is None: + # get decoder inputs from shifting lm labels to the right - this is used in training mode + decoder_input_ids = self._shift_right(labels) + # replace possible -100 values in labels by `pad_token_id` + labels = labels.masked_fill(labels == self.decoder_pad_token_id, -100) + + answer_output = self.text_decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=question_embeds, + encoder_attention_mask=attention_mask, + labels=labels, + return_dict=return_dict, + reduction="mean", + ) + + if labels is not None: + decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean() + else: + decoder_loss = None + + if not return_dict: + outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:] + return tuple(output for output in outputs if output is not None) + + return BlipTextVisionModelOutput( + loss=decoder_loss, + image_embeds=image_embeds, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + 
attentions=vision_outputs.attentions, + ) + + @torch.no_grad() + def generate( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + pixel_masks: torch.Tensor = None, + attention_mask: Optional[torch.LongTensor] = None, + **generate_kwargs, + ) -> torch.LongTensor: + r""" + Overrides *generate* function to be able to use the model as a conditional generator + + Parameters: + input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*): + The sequence used as a prompt for the generation. + pixel_values (*torch.FloatTensor* of shape *(batch_size, image_width, image_height)*: + Input image to be processed + attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for + tokens that are NOT MASKED, `0` for MASKED tokens. + **generate_kwargs: + Additional arguments passed to the *generate* function of the decoder + + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForQuestionAnswering + + >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") + >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "How many cats are in the picture?" 
+ + >>> inputs = processor(images=image, text=text, return_tensors="pt") + + >>> outputs = model.generate(**inputs) + >>> print(processor.decode(outputs[0], skip_special_tokens=True)) + 2 + ``` + """ + vision_outputs = self.vision_model( + pixel_values=pixel_values, + pixel_masks=pixel_masks + ) + + image_embeds = vision_outputs[0] + + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) + + if isinstance(input_ids, list): + input_ids = torch.LongTensor(input_ids) + + question_outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_attention_mask, + return_dict=False, + ) + + question_embeds = question_outputs[0] + + question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device) + + bos_ids = torch.full( + (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device + ) + + outputs = self.text_decoder.generate( + input_ids=bos_ids, + eos_token_id=self.config.text_config.sep_token_id, + pad_token_id=self.config.text_config.pad_token_id, + encoder_hidden_states=question_embeds, + encoder_attention_mask=question_attention_mask, + **generate_kwargs, + ) + + return outputs + + +@add_start_docstrings( + """ + BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of + image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to + the image. 
+ """, + BLIP_START_DOCSTRING, +) +class BlipForImageTextRetrieval(BlipPreTrainedModel): + config_class = BlipConfig + + def __init__(self, config: BlipConfig): + super().__init__(config) + + self.vision_model = BlipVisionModel(config.vision_config) + + self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False) + + # vision projection layer + self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size) + + # text projection layer + self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size) + + # image text matching head + self.itm_head = nn.Linear(config.text_config.hidden_size, 2) + + self.decoder_pad_token_id = ( + config.text_config.pad_token_id + if not hasattr(config, "decoder_pad_token_id") + else config.decoder_pad_token_id + ) + self.decoder_start_token_id = ( + config.text_config.bos_token_id + if not hasattr(config, "decoder_start_token_id") + else config.decoder_start_token_id + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig) + def forward( + self, + input_ids: torch.LongTensor, + pixel_values: torch.FloatTensor, + use_itm_head: Optional[bool] = True, + attention_mask: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BlipTextVisionModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, BlipForImageTextRetrieval + + >>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco") + >>> processor = 
AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "an image of a cat" + + >>> inputs = processor(images=image, text=text, return_tensors="pt") + >>> outputs = model(**inputs) + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[0] + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long) + + if use_itm_head: + question_embeds = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=return_dict, + ) + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + output = self.itm_head(question_embeds[:, 0, :]) + else: + question_embeds = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + return_dict=return_dict, + ) + question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state + + image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1) + + output = image_feat @ text_feat.t() + + if not return_dict: + outputs = (output, vision_outputs[0]) + vision_outputs[2:] + (question_embeds,) + return tuple(output for output in outputs if output is not None) + + return BlipImageTextMatchingModelOutput( + itm_score=output, + last_hidden_state=vision_outputs.last_hidden_state, + hidden_states=vision_outputs.hidden_states, + attentions=vision_outputs.attentions, + question_embeds=question_embeds, + ) diff --git a/captioner/modeling_git.py 
b/captioner/modeling_git.py new file mode 100644 index 0000000000000000000000000000000000000000..458d925036a5201d7c46783c71511874a3cc02d3 --- /dev/null +++ b/captioner/modeling_git.py @@ -0,0 +1,1587 @@ +# coding=utf-8 +# Copyright 2022 Microsoft Research and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch GIT model.""" + + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.file_utils import ModelOutput +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.models.git.configuration_git import GitConfig, GitVisionConfig +from .vit_pixel_masks_utils import ViTPatchMaskGenerator + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/git-base" +_CONFIG_FOR_DOC = "GitConfig" + +GIT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/git-base", + # See all GIT models 
at https://huggingface.co/models?filter=git +] + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git +class GitVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class GitEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values_length: int = 0, + ) -> torch.Tensor: + if input_ids is not None: + input_shape = input_ids.size() + 
else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + embeddings = self.word_embeddings(input_ids) + else: + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class GitSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1) + if config.num_image_with_embedding is not None: + self.image_patch_tokens *= config.num_image_with_embedding + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * 
config.max_position_embeddings - 1, self.attention_head_size) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + pixel_values_present: Optional[bool] = False, + image_token_num: Optional[int] = None + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + if image_token_num is not None: + cutoff = image_token_num + else: + cutoff = self.image_patch_tokens if pixel_values_present else 0 + if past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([key_layer[:, :, :cutoff, :], past_key_value[0], key_layer[:, :, -1:, :]], dim=2) + value_layer = torch.cat( + [value_layer[:, :, :cutoff, :], past_key_value[1], value_layer[:, :, -1:, :]], dim=2 + ) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component. + past_key_value = ( + key_layer[:, :, cutoff:, :], + value_layer[:, :, cutoff:, :], + ) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / 
math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in GitModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class GitSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class GitAttention(nn.Module): + # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->Git + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = GitSelfAttention(config, position_embedding_type=position_embedding_type) + 
self.output = GitSelfOutput(config) + self.pruned_heads = set() + + # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + pixel_values_present: Optional[bool] = False, + image_token_num: Optional[int] = None + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + past_key_value, + output_attentions, + pixel_values_present, + image_token_num + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class GitIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = 
config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class GitOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class GitLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = GitAttention(config) + self.intermediate = GitIntermediate(config) + self.output = GitOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + pixel_values_present: Optional[bool] = False, + image_token_num: Optional[bool] = None, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + pixel_values_present=pixel_values_present, + image_token_num=image_token_num + ) + attention_output = self_attention_outputs[0] + + 
# if decoder, the last output is tuple of self-attn cache + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class GitEncoder(nn.Module): + # Copied from transformers.models.bert.modeling_bert.BertEncoder.__init__ with Bert->Git + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([GitLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + pixel_values_present: Optional[bool] = False, + image_token_num: Optional[int] = None, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]: + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + past_key_value, + output_attentions, + pixel_values_present, + image_token_num, + + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class GitPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = GitConfig + base_model_prefix = "git" + supports_gradient_checkpointing = True + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, GitVisionEmbeddings): + nn.init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range) + nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range) + nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range) + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (GitEncoder, GitVisionEncoder)): + module.gradient_checkpointing = value + + +GIT_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GitConfig`]): Model configuration class with all the parameters of the model. 
+ Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GIT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See + [`CLIPImageProcessor.__call__`] for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git +class GitVisionEmbeddings(nn.Module): + def __init__(self, config: GitVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP +class GitVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = 
ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPAttention +class GitVisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, tgt_len, embed_dim = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scale + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + 
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + # apply the causal_attention_mask first + if causal_attention_mask is not None: + if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {causal_attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if output_attentions: + # this operation is a bit akward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->GitVision +class GitVisionEncoderLayer(nn.Module): + def __init__(self, config: GitVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = GitVisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = GitVisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + causal_attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. 
+ output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->GitVision, CLIPConfig +class GitVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`GitVisionEncoderLayer`]. + + Args: + config: GitVisionConfig + """ + + def __init__(self, config: GitVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + causal_attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +GIT_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class GitVisionTransformer(nn.Module): + # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.__init__ with CLIPEncoder->GitVisionEncoder, CLIP->Git + def __init__(self, config: GitVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = GitVisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.patch_mask_generator = ViTPatchMaskGenerator(config.patch_size) + self.encoder = GitVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_masks: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + B, N, D = hidden_states.shape + # print('Before mask:', hidden_states.shape) + if pixel_masks is not None: + 
assert pixel_masks.shape[0] == 1 + patch_masks = self.patch_mask_generator(pixel_masks) + # print(patch_masks.shape) + patch_masks = patch_masks.unsqueeze(-1).expand_as(hidden_states) + hidden_states = hidden_states.masked_select(patch_masks).view(B, -1, D) + # print('After mask:', hidden_states.shape) + hidden_states = self.pre_layrnorm(hidden_states) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + return (last_hidden_state,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=last_hidden_state, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The vision model from CLIP, used in GIT, without any head or projection on top.""", + GIT_START_DOCSTRING, +) +class GitVisionModel(GitPreTrainedModel): + config_class = GitVisionConfig + main_input_name = "pixel_values" + + # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git + def __init__(self, config: GitVisionConfig): + super().__init__(config) + self.vision_model = GitVisionTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_masks: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Returns: + + 
Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, GitVisionModel + + >>> processor = AutoProcessor.from_pretrained("microsoft/git-base") + >>> model = GitVisionModel.from_pretrained("microsoft/git-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + pixel_masks=pixel_masks, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class GitProjection(nn.Module): + def __init__(self, config: GitConfig): + super().__init__() + self.config = config + self.visual_projection = nn.Sequential( + nn.Linear(config.vision_config.hidden_size, config.hidden_size), + nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps), + ) + + def forward(self, embeddings: torch.Tensor) -> torch.Tensor: + return self.visual_projection(embeddings) + + +@add_start_docstrings( + "The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states" + " without any specific head on top.", + GIT_START_DOCSTRING, +) +class GitModel(GitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = GitEmbeddings(config) + self.image_encoder = GitVisionModel(config.vision_config) + self.encoder = GitEncoder(config) + + self.visual_projection = GitProjection(config) + + if config.num_image_with_embedding is not None: + self.img_temperal_embedding = nn.ParameterList( + nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size)) + for _ in 
range(config.num_image_with_embedding) + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + # Default mask is for forward direction. Flip for backward direction. + mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1) + mask = mask.masked_fill(mask == 1, float("-inf")) + return mask + + def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None): + num_tgt = tgt.shape[1] + num_memory = memory.shape[1] + device = tgt.device + dtype = tgt.dtype + top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype) + top_right = torch.full( + (num_memory, num_tgt + past_key_values_length), + float("-inf"), + device=tgt.device, + dtype=dtype, + ) + bottom_left = torch.zeros( + (num_tgt, num_memory), + dtype=dtype, + device=tgt_mask.device, + ) + + if past_key_values_length > 0: + tgt_mask = torch.zeros( + (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length), + dtype=dtype, + device=tgt_mask.device, + ) + + left = torch.cat((top_left, bottom_left), dim=0) + right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0) + + full_attention_mask = torch.cat((left, right), dim=1)[None, :] + + if memory_key_padding_mask is None: + memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device) + # if it is False, it means valid. 
That is, it is not a padding + if memory_key_padding_mask.dtype != torch.bool: + raise ValueError("Memory key padding mask must be a boolean tensor.") + zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype) + zero_negative_infinity[memory_key_padding_mask] = float("-inf") + full_attention_mask = full_attention_mask.expand( + (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt) + ) + full_attention_mask = full_attention_mask.clone() + origin_left = full_attention_mask[:, :, :num_memory] + update = zero_negative_infinity[:, None, :] + full_attention_mask[:, :, :num_memory] = origin_left + update + + # add axis for multi-head + full_attention_mask = full_attention_mask[:, None, :, :] + + return full_attention_mask + + @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_masks: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]: + r""" + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 
+ + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoProcessor, AutoModel + >>> import requests + >>> from PIL import Image + + >>> processor = AutoProcessor.from_pretrained("microsoft/git-base") + >>> model = AutoModel.from_pretrained("microsoft/git-base") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = "this is an image of two cats" + + >>> inputs = processor(text, images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length = input_shape[1] + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + 
+ # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + projected_visual_features = None + if pixel_values is not None: + if pixel_values.ndim == 4: + # here we assume pixel_values is of shape (batch_size, num_channels, height, width) + visual_features = self.image_encoder(pixel_values=pixel_values, pixel_masks=pixel_masks).last_hidden_state + + elif pixel_values.ndim == 5: + # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width) + visual_features = [] + for frame_idx in range(pixel_values.shape[1]): + visual_features_frame = self.image_encoder(pixel_values[:, frame_idx, :, :]).last_hidden_state + visual_features_frame += self.img_temperal_embedding[frame_idx] + visual_features.append(visual_features_frame) + + # finally, concatenate all features along sequence dimension + visual_features = torch.cat(visual_features, dim=1) + + else: + raise ValueError("pixel_values must be of rank 4 or 5") + + projected_visual_features = self.visual_projection(visual_features) + image_token_num = projected_visual_features.shape[1] + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if projected_visual_features is None: + projected_visual_features = torch.zeros( + (embedding_output.shape[0], 0, embedding_output.shape[2]), + dtype=embedding_output.dtype, + device=embedding_output.device, + ) + + # Repeat visual features to match embedding batch size. 
+ projected_visual_features = projected_visual_features.repeat( + embedding_output.size(0) // projected_visual_features.size(0), 1, 1 + ) + + # concatenate patch token and text token embeddings + hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1) + + # By default, an additive causal mask is created + # for masking the future (one direction). + tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device) + + # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len) + combined_attention_mask = self.create_attention_mask( + tgt=embedding_output, + memory=projected_visual_features, + tgt_mask=tgt_mask, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # if the user provides an attention mask, we add it to the default one + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]).to( + embedding_output.device + ) + if past_key_values_length > 0: + expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :] + else: + combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=combined_attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + pixel_values_present=pixel_values is not None, + image_token_num=image_token_num + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutputWithPast( + last_hidden_state=sequence_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """GIT 
Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING +) +class GitForCausalLM(GitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.git = GitModel(config) + self.output = nn.Linear(config.hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.output + + def set_output_embeddings(self, new_embeddings): + self.output = new_embeddings + + @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_masks: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). 
Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Examples: + + Image captioning example: + + ```python + >>> from transformers import AutoProcessor, AutoModelForCausalLM + >>> import requests + >>> from PIL import Image + + >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco") + >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values + + >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50) + >>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> print(generated_caption) + two cats sleeping on a pink blanket next to remotes. 
+ ``` + + Visual question answering (VQA) example: + + ```python + >>> from transformers import AutoProcessor, AutoModelForCausalLM + >>> from huggingface_hub import hf_hub_download + >>> from PIL import Image + + >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa") + >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa") + + >>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset") + >>> image = Image.open(file_path).convert("RGB") + + >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values + + >>> question = "what does the front of the bus say at the top?" + + >>> input_ids = processor(text=question, add_special_tokens=False).input_ids + >>> input_ids = [processor.tokenizer.cls_token_id] + input_ids + >>> input_ids = torch.tensor(input_ids).unsqueeze(0) + + >>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50) + >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True)) + ['what does the front of the bus say at the top? special'] + ``` + + Video captioning example: + + ```python + >>> import av + >>> import numpy as np + >>> from PIL import Image + >>> from huggingface_hub import hf_hub_download + >>> from transformers import AutoProcessor, AutoModelForCausalLM + + >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex") + >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex") + + >>> # set seed for reproducability + >>> np.random.seed(45) + + + >>> def read_video_pyav(container, indices): + ... ''' + ... Decode the video with PyAV decoder. + ... Args: + ... container (`av.container.input.InputContainer`): PyAV container. + ... indices (`List[int]`): List of frame indices to decode. + ... Returns: + ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). + ... ''' + ... frames = [] + ... container.seek(0) + ... 
start_index = indices[0] + ... end_index = indices[-1] + ... for i, frame in enumerate(container.decode(video=0)): + ... if i > end_index: + ... break + ... if i >= start_index and i in indices: + ... frames.append(frame) + ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + + >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... converted_len = int(clip_len * frame_sample_rate) + ... end_idx = np.random.randint(converted_len, seg_len) + ... start_idx = end_idx - converted_len + ... indices = np.linspace(start_idx, end_idx, num=clip_len) + ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64) + ... return indices + + + >>> # load video + >>> file_path = hf_hub_download( + ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset" + ... ) + >>> container = av.open(file_path) + + >>> # sample frames + >>> num_frames = model.config.num_image_with_embedding + >>> indices = sample_frame_indices( + ... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames + ... 
) + >>> frames = read_video_pyav(container, indices) + + >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values + + >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50) + + >>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True)) + Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.'] + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.git( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_masks=pixel_masks, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.output(sequence_output) + + loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens + shifted_logits = logits[:, num_image_tokens:-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + input_ids = 
input_ids[:, -1:] + + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + input_shape = input_ids.shape + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "pixel_values": kwargs.get("pixel_values", None), + "pixel_masks": kwargs.get("pixel_masks", None), + "past_key_values": past_key_values, + "use_cache": use_cache, + } + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/captioner/vit_pixel_masks_utils.py b/captioner/vit_pixel_masks_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ecccbe54b9d4cd468839d6fd8e8651884b9ab07a --- /dev/null +++ b/captioner/vit_pixel_masks_utils.py @@ -0,0 +1,17 @@ + +import torch +import torch.nn as nn + + +class ViTPatchMaskGenerator(nn.Module): + def __init__(self, patch_size) -> None: + super(ViTPatchMaskGenerator, self).__init__() + self.patch_size = patch_size + self.pool = nn.MaxPool2d(kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_masks): + patch_mask = self.pool(pixel_masks) + patch_mask = patch_mask.bool().flatten(1) + cls_token_mask = patch_mask.new_ones([patch_mask.shape[0], 1]).bool() + patch_mask = torch.cat([cls_token_mask, patch_mask], dim=-1) + return patch_mask diff --git a/image_editing_utils.py b/image_editing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b143aeb08c9056dde2d2d1ecef2c658f2ea133 --- /dev/null +++ b/image_editing_utils.py @@ -0,0 +1,68 @@ +from PIL import Image, ImageDraw, ImageFont +import copy +import numpy as np + +def wrap_text(text, font, max_width): + lines = [] + words = text.split(' ') + current_line = '' + + for word in words: + if font.getsize(current_line 
+ word)[0] <= max_width: + current_line += word + ' ' + else: + lines.append(current_line) + current_line = word + ' ' + + lines.append(current_line) + return lines + +def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.ttf', font_size_ratio=0.033): + # Load the image + if type(image) == np.ndarray: + image = Image.fromarray(image) + + image = copy.deepcopy(image) + width, height = image.size + + # Calculate max_text_width and font_size based on image dimensions and total number of characters + total_chars = len(text) + max_text_width = int(0.33 * width) + font_size = int(height * font_size_ratio) + + # Load the font + font = ImageFont.truetype(font_path, font_size) + + # Wrap the text to fit within the max_text_width + lines = wrap_text(text, font, max_text_width) + text_width, text_height = font.getsize(lines[0]) + text_height = text_height * len(lines) + + # Define bubble frame dimensions + padding = 10 + bubble_width = text_width + 2 * padding + bubble_height = text_height + 2 * padding + + # Create a new image for the bubble frame + bubble = Image.new('RGBA', (bubble_width, bubble_height), (255, 255, 255, 0)) + + # Draw the bubble frame on the new image + draw = ImageDraw.Draw(bubble) + draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2) + + # Draw the wrapped text line by line + y_text = padding + for line in lines: + draw.text((padding, y_text), line, font=font, fill=(255, 255, 255, 255)) + y_text += font.getsize(line)[1] + + # Calculate the bubble frame position + x, y = point + if x + bubble_width > width: + x = width - bubble_width + if y + bubble_height > height: + y = height - bubble_height + + # Paste the bubble frame onto the image + image.paste(bubble, (x, y), bubble) + return image \ No newline at end of file diff --git a/segmenter/__init__.py b/segmenter/__init__.py new file mode 100644 index 
def build_segmenter(type, device, args=None):
    """Build a segmenter by name.

    Only ``type == 'base'`` is handled; any other value yields ``None``.
    ``args`` must provide ``segmenter_checkpoint`` and
    ``disable_reuse_features`` for the ``'base'`` segmenter.
    """
    if type == 'base':
        return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features)
    return None


class BaseSegmenter:
    """Wrapper around a Segment Anything model driven by prompt dictionaries.

    A prompt ("control") is a dict with a ``"prompt_type"`` list containing
    ``"click"``, ``"box"`` and/or ``"everything"`` plus the matching
    ``"input_point"`` / ``"input_label"`` / ``"input_box"`` /
    ``"input_boxes"`` entries.
    """

    def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True):
        """Load the SAM checkpoint onto ``device`` and create the predictors.

        Args:
            device: Torch device string, e.g. ``'cuda'`` or ``'cpu'``.
            checkpoint: Path to the SAM checkpoint file.
            model_type: Key into ``sam_model_registry``.
            reuse_feature: When true, ``set_image`` caches the image embedding
                and ``inference`` reuses it instead of re-encoding the image.
        """
        print(f"Initializing BaseSegmenter to {device}")
        self.device = device
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.processor = None
        self.model_type = model_type
        self.checkpoint = checkpoint
        self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
        self.model.to(device=self.device)
        self.reuse_feature = reuse_feature
        self.predictor = SamPredictor(self.model)
        self.mask_generator = SamAutomaticMaskGenerator(self.model)
        self.image_embedding = None
        self.image = None

    @torch.no_grad()
    def set_image(self, image: Union[np.ndarray, Image.Image, str]):
        """Store ``image`` (array, PIL image or file path) as a numpy array.

        When ``reuse_feature`` is enabled the image is also encoded once and
        its embedding cached for later ``inference`` calls.
        """
        if isinstance(image, str):  # input path
            image = Image.open(image)
        # isinstance (not an exact type match) so that the concrete classes
        # returned by Image.open are converted as well.
        if isinstance(image, Image.Image):
            image = np.array(image)
        self.image = image
        if self.reuse_feature:
            self.predictor.set_image(image)
            self.image_embedding = self.predictor.get_image_embedding()
            print(self.image_embedding.shape)

    @staticmethod
    def _wants_multimask(control):
        """Tell whether the prompt asks for multiple candidate masks."""
        if 'mutimask_output' in control:
            # Historical misspelling: its mere presence enabled the mode.
            return True
        return control.get('multimask_output') not in (None, False, 'False', 'false')

    @torch.no_grad()
    def inference(self, image, control):
        """Run the prompt ``control`` against ``image`` and return the masks.

        Args:
            image: Image as a numpy array. It is encoded here unless
                ``reuse_feature`` is enabled, in which case ``set_image``
                must have been called first.
            control: Prompt dictionary (see the class docstring).

        Returns:
            The masks produced by the underlying predictor / mask generator.
        """
        prompt_type = control['prompt_type']

        if 'everything' in prompt_type:
            masks = self.mask_generator.generate(image)
            if not masks:
                # Nothing segmented: return an empty stack instead of letting
                # numpy raise on an empty list.
                return np.zeros((0,) + tuple(image.shape[:2]), dtype=bool)
            return np.stack([mask["segmentation"] for mask in masks])

        if not self.reuse_feature:
            self.set_image(image)
            self.predictor.set_image(self.image)
        else:
            assert self.image_embedding is not None
            self.predictor.features = self.image_embedding

        if self._wants_multimask(control):
            masks, scores, logits = self.predictor.predict(
                point_coords=np.array(control['input_point']),
                point_labels=np.array(control['input_label']),
                multimask_output=True,
            )
            return masks

        if 'input_boxes' in control:
            transformed_boxes = self.predictor.transform.apply_boxes_torch(
                torch.tensor(control["input_boxes"], device=self.predictor.device),
                image.shape[:2]
            )
            masks, _, _ = self.predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            return masks.squeeze(1).cpu().numpy()

        use_click = 'click' in prompt_type
        input_point = np.array(control['input_point']) if use_click else None
        input_label = np.array(control['input_label']) if use_click else None
        input_box = np.array(control['input_box']) if 'box' in prompt_type else None

        masks, scores, logits = self.predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            box=input_box,
            multimask_output=False,
        )

        # A label of 0 marks a negative click: run a second pass seeded with
        # the best logits of the first. Box-only prompts carry no labels, so
        # the lookup must tolerate a missing key.
        if 0 in control.get('input_label', ()):
            mask_input = logits[np.argmax(scores), :, :]
            masks, scores, logits = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                box=input_box,
                mask_input=mask_input[None, :, :],
                multimask_output=False,
            )

        return masks


if __name__ == "__main__":
    # Minimal smoke test: one positive and one negative click.
    image_path = 'test_img/img2.jpg'
    prompts = [
        {
            "prompt_type": ["click"],
            "input_point": [[1000, 600], [1325, 625]],
            "input_label": [1, 0],
        },
    ]

    init_time = time.time()
    segmenter = BaseSegmenter(
        device='cuda',
        checkpoint='segmenter/sam_vit_h_4b8939.pth',
        model_type='vit_h',
        reuse_feature=True
    )
    print(f'init time: {time.time() - init_time}')

    infer_time = time.time()
    for i, prompt in enumerate(prompts):
        print(f'{prompt["prompt_type"]} mode')
        image = Image.open(image_path)
        segmenter.set_image(np.array(image))
        masks = segmenter.inference(np.array(image), prompt)
        Image.fromarray(masks[0]).save('seg.png')
        print(masks.shape)

    print(f'infer time: {time.time() - infer_time}')
+``` + +``` +pip install opencv-python pycocotools matplotlib onnxruntime onnx +``` +### Download the checkpoint: + +https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth + +### Inference + +The prompts are in json format: + +``` +prompts = [ + { + "prompt_type":["click"], + "input_point":[[500, 375]], + "input_label":[1], + "multimask_output":"True", + }, + { + "prompt_type":["click"], + "input_point":[[500, 375], [1125, 625]], + "input_label":[1, 0], + }, + { + "prompt_type":["click", "box"], + "input_box":[425, 600, 700, 875], + "input_point":[[575, 750]], + "input_label": [0] + }, + { + "prompt_type":["box"], + "input_boxes": [ + [75, 275, 1725, 850], + [425, 600, 700, 875], + [1375, 550, 1650, 800], + [1240, 675, 1400, 750], + ] + }, + { + "prompt_type":["everything"] + }, + ] +``` + +In `base_segmenter.py`: +``` +segmenter = BaseSegmenter( + device='cuda', + checkpoint='sam_vit_h_4b8939.pth', + model_type='vit_h' + ) + +for i, prompt in enumerate(prompts): + masks = segmenter.inference(image_path, prompt) +``` + +Outputs are masks (True and False numpy Matrix), shape: (num of masks, height, weight) \ No newline at end of file diff --git a/test_img/img1.jpg b/test_img/img1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..83c0c9eb9f5026fdb7a7f49fba081d4764ce0515 Binary files /dev/null and b/test_img/img1.jpg differ diff --git a/test_img/img1.jpg.raw_mask.png b/test_img/img1.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..ba811712737fa16ca0fd79aa981ff8b6f65d6d5f Binary files /dev/null and b/test_img/img1.jpg.raw_mask.png differ diff --git a/test_img/img10.jpg b/test_img/img10.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b51fde5fbe4d06c4295270b100f8861bbb02a870 Binary files /dev/null and b/test_img/img10.jpg differ diff --git a/test_img/img10.jpg.raw_mask.png b/test_img/img10.jpg.raw_mask.png new file mode 100644 index 
0000000000000000000000000000000000000000..9f9145d5c6f0c671d2c0d44f860ccc20aaf8e33f Binary files /dev/null and b/test_img/img10.jpg.raw_mask.png differ diff --git a/test_img/img11.jpg b/test_img/img11.jpg new file mode 100644 index 0000000000000000000000000000000000000000..698333f481ea34d1ebb379f4d5802939072c83db Binary files /dev/null and b/test_img/img11.jpg differ diff --git a/test_img/img12.jpg b/test_img/img12.jpg new file mode 100644 index 0000000000000000000000000000000000000000..20a3789bad40238cc90cca7b8e0049aaad1e1dbd Binary files /dev/null and b/test_img/img12.jpg differ diff --git a/test_img/img12.jpg.raw_mask.png b/test_img/img12.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..7c857a906c303eb038ce7af6eb37d69762301871 Binary files /dev/null and b/test_img/img12.jpg.raw_mask.png differ diff --git a/test_img/img13.jpg b/test_img/img13.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9374e1fa87e3103869e727a8c56fb22525adb715 Binary files /dev/null and b/test_img/img13.jpg differ diff --git a/test_img/img13.jpg.raw_mask.png b/test_img/img13.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..23f9dcebc3b52026ab6ce27fba85b48059b1bb8c Binary files /dev/null and b/test_img/img13.jpg.raw_mask.png differ diff --git a/test_img/img14.jpg b/test_img/img14.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f60ad955110a5238e80ef93af7bbce03a4322e48 Binary files /dev/null and b/test_img/img14.jpg differ diff --git a/test_img/img14.jpg.raw_mask.png b/test_img/img14.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..da46cc403fc9eedd021e728db7921c91c5f43e05 Binary files /dev/null and b/test_img/img14.jpg.raw_mask.png differ diff --git a/test_img/img15.jpg b/test_img/img15.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ab3ef5ec0c62965253ab782d0e0dbf02929588af Binary files /dev/null and 
b/test_img/img15.jpg differ diff --git a/test_img/img15.jpg.raw_mask.png b/test_img/img15.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..e5ebf143e53a9a74f471f8ee2b2209fc72854463 Binary files /dev/null and b/test_img/img15.jpg.raw_mask.png differ diff --git a/test_img/img16.jpg b/test_img/img16.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4871a9f5a7b300f34e99337097d2e178ad649ed9 Binary files /dev/null and b/test_img/img16.jpg differ diff --git a/test_img/img16.jpg.raw_mask.png b/test_img/img16.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..62739f20714ef4b48b24e24d77311d5a60e3268d Binary files /dev/null and b/test_img/img16.jpg.raw_mask.png differ diff --git a/test_img/img17.jpg b/test_img/img17.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1b5534d3978b826e88ba0e88c9a9953062cbe57a Binary files /dev/null and b/test_img/img17.jpg differ diff --git a/test_img/img18.jpg b/test_img/img18.jpg new file mode 100644 index 0000000000000000000000000000000000000000..db9215dfbaefa5f1c64c03dd1b928de1c6117ff8 --- /dev/null +++ b/test_img/img18.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e02c393a23aadd1304497e3a9b41144df166d1cfda33ea3e00eed94e27da3aa4 +size 1372251 diff --git a/test_img/img19.jpg b/test_img/img19.jpg new file mode 100644 index 0000000000000000000000000000000000000000..abbe797820425e442a2a6b99a22b327aee3e9961 Binary files /dev/null and b/test_img/img19.jpg differ diff --git a/test_img/img2.jpg b/test_img/img2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..583f69ec771a6f562e8dd9511b61fb9034a1af64 Binary files /dev/null and b/test_img/img2.jpg differ diff --git a/test_img/img2.jpg.raw_mask.png b/test_img/img2.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..d4d55444af67e7831aee27d3df86c5150f680e70 Binary files /dev/null and 
b/test_img/img2.jpg.raw_mask.png differ diff --git a/test_img/img20.jpg b/test_img/img20.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1c75bd821f2beb8cd72d56ad1e5f9064c96066c7 Binary files /dev/null and b/test_img/img20.jpg differ diff --git a/test_img/img21.jpg b/test_img/img21.jpg new file mode 100644 index 0000000000000000000000000000000000000000..98462cd6c0a8cfdbfe158f4484843d0a320d5dce Binary files /dev/null and b/test_img/img21.jpg differ diff --git a/test_img/img22.jpg b/test_img/img22.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6b898f4558d34b4a3fcd44dcffda58bbea2b942 --- /dev/null +++ b/test_img/img22.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5c5159bf7114d08967f95475176670043115b157bf700efa34190260cd917662 +size 1025438 diff --git a/test_img/img23.jpg b/test_img/img23.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8b070a469b4009a565f167784c552ce9886769e8 Binary files /dev/null and b/test_img/img23.jpg differ diff --git a/test_img/img24.jpg b/test_img/img24.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c90f0967fe7878cba26a72014e0bf377f0fd9c7d Binary files /dev/null and b/test_img/img24.jpg differ diff --git a/test_img/img25.jpg b/test_img/img25.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ad24ad0005f04a31fef9077793768bf68fc3276c Binary files /dev/null and b/test_img/img25.jpg differ diff --git a/test_img/img27.jpg b/test_img/img27.jpg new file mode 100644 index 0000000000000000000000000000000000000000..08cac0fa26a959dbd2a4fb33043a75cb3a1b6d06 Binary files /dev/null and b/test_img/img27.jpg differ diff --git a/test_img/img28.jpg b/test_img/img28.jpg new file mode 100644 index 0000000000000000000000000000000000000000..31c7c4a57e21b8c3cf82ee9783179c3410472ed6 Binary files /dev/null and b/test_img/img28.jpg differ diff --git a/test_img/img29.jpg b/test_img/img29.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..5fbab4d5eafcd33d33db45669cc1a9ce7432f111 Binary files /dev/null and b/test_img/img29.jpg differ diff --git a/test_img/img3.jpg b/test_img/img3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..deeafdbc1d4ac40426f75ee7395ecd82025d6e95 Binary files /dev/null and b/test_img/img3.jpg differ diff --git a/test_img/img30.jpg b/test_img/img30.jpg new file mode 100644 index 0000000000000000000000000000000000000000..060d18f3481662e618e1cf376281fc39bfd0b41d Binary files /dev/null and b/test_img/img30.jpg differ diff --git a/test_img/img31.jpg b/test_img/img31.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bef87cab085b7bd4f1025090d875bc09bc2e5c96 Binary files /dev/null and b/test_img/img31.jpg differ diff --git a/test_img/img32.jpg b/test_img/img32.jpg new file mode 100644 index 0000000000000000000000000000000000000000..aa916c29c7839093f4092e8c258824a771901cd6 Binary files /dev/null and b/test_img/img32.jpg differ diff --git a/test_img/img33.jpg b/test_img/img33.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d91468d1488d494904ea0068dc811ba78aa69339 Binary files /dev/null and b/test_img/img33.jpg differ diff --git a/test_img/img34.jpg b/test_img/img34.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f05ceaa1b35c6b74e196d979efadc5f4b79b6170 Binary files /dev/null and b/test_img/img34.jpg differ diff --git a/test_img/img4.jpg b/test_img/img4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a204f5a7567288216ec7e18a5223e677ab397b36 Binary files /dev/null and b/test_img/img4.jpg differ diff --git a/test_img/img4.jpg.raw_mask.png b/test_img/img4.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..dba4577e87cd6e52870525af2cdbde07940f100a Binary files /dev/null and b/test_img/img4.jpg.raw_mask.png differ diff --git a/test_img/img5.jpg b/test_img/img5.jpg new file mode 100644 index 
0000000000000000000000000000000000000000..80e2e7e4b9505a1528b8d319d6b2efcbde16a9cf Binary files /dev/null and b/test_img/img5.jpg differ diff --git a/test_img/img5.jpg.raw_mask.png b/test_img/img5.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..d1924854e703bd95f8587436bf9297ccf775a041 Binary files /dev/null and b/test_img/img5.jpg.raw_mask.png differ diff --git a/test_img/img6.jpg b/test_img/img6.jpg new file mode 100644 index 0000000000000000000000000000000000000000..35d44f6a08b0fab2b38efb68c65e7528d65aca48 Binary files /dev/null and b/test_img/img6.jpg differ diff --git a/test_img/img6.jpg.raw_mask.png b/test_img/img6.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..42fd658b5c69644b742d56ae567207e1898f6a7e Binary files /dev/null and b/test_img/img6.jpg.raw_mask.png differ diff --git a/test_img/img7.jpg b/test_img/img7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..679431b782257372a0bbe19ab701c308d114f0d7 Binary files /dev/null and b/test_img/img7.jpg differ diff --git a/test_img/img7.jpg.raw_mask.png b/test_img/img7.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..a15527829eac9c027c42e60c4702bc70e57db460 Binary files /dev/null and b/test_img/img7.jpg.raw_mask.png differ diff --git a/test_img/img8.jpg b/test_img/img8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..62ef2d4a1c9fb498fc3f2e3f8928fd3626832d0b Binary files /dev/null and b/test_img/img8.jpg differ diff --git a/test_img/img8.jpg.raw_mask.png b/test_img/img8.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..285410c86e57c3905b223ec392e61570a98c369b Binary files /dev/null and b/test_img/img8.jpg.raw_mask.png differ diff --git a/test_img/img9.jpg b/test_img/img9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..49acb36e2ad271dac1fc629cd10440c14954e70e Binary files /dev/null and 
b/test_img/img9.jpg differ diff --git a/test_img/img9.jpg.raw_mask.png b/test_img/img9.jpg.raw_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..127a194bbb9088338966e1db84fd9c4dca94afdf Binary files /dev/null and b/test_img/img9.jpg.raw_mask.png differ diff --git a/test_img/painter_input_image.jpg b/test_img/painter_input_image.jpg new file mode 100644 index 0000000000000000000000000000000000000000..deeafdbc1d4ac40426f75ee7395ecd82025d6e95 Binary files /dev/null and b/test_img/painter_input_image.jpg differ diff --git a/test_img/painter_input_mask.jpg b/test_img/painter_input_mask.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0720afed9caf1e4e8b1864a86a7004c43307d845 Binary files /dev/null and b/test_img/painter_input_mask.jpg differ diff --git a/test_img/painter_output_image.png b/test_img/painter_output_image.png new file mode 100644 index 0000000000000000000000000000000000000000..40b97bb859e559e82c03fff625c29f9b391723cc Binary files /dev/null and b/test_img/painter_output_image.png differ diff --git a/text_refiner/README.md b/text_refiner/README.md new file mode 100644 index 0000000000000000000000000000000000000000..813684e94db2941fbf10ea1fafd55a75a3560939 --- /dev/null +++ b/text_refiner/README.md @@ -0,0 +1,8 @@ +# Install +* python >= 3.8.1 + +```bash +pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html # CUDA version could be different +pip install openai pillow transformers +pip install langchain==0.0.101 +``` diff --git a/text_refiner/__init__.py b/text_refiner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c17dd160c3866ac0aa9327558e62ffdcfc43779d --- /dev/null +++ b/text_refiner/__init__.py @@ -0,0 +1,6 @@ +from text_refiner.text_refiner import TextRefiner + + +def build_text_refiner(type, device, args=None): + if type == 'base': + return TextRefiner(device) \ No newline at end of file 
class TextRefiner:
    """Rewrite a region caption with an LLM according to user controls."""

    def __init__(self, device):
        """Create the LLM client and the prompt templates.

        If the client cannot be created, ``self.llm`` is set to ``None`` and
        ``inference`` will refuse to run until a callable is assigned to it.
        """
        print(f"Initializing TextRefiner to {device}")
        try:
            self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0)
        except Exception:
            # Best effort: keep the object usable so a client can be attached
            # later, but make the missing client explicit.
            self.llm = None
            print('Openai api key is NOT given')
        self.prompt_tag = {
            "imagination": {"True": "could",
                            "False": "could not"}
        }
        # Short fragments joined into "The new text is ...".
        self.short_prompts = {
            "length": "around {length} words",
            "sentiment": "of {sentiment} sentiment",
            "language": "in {language}",
        }

        # Full-sentence instructions enabled by boolean controls.
        self.long_prompts = {
            "imagination": "The new sentence could extend the original description by using your imagination to create additional details, or think about what might have happened before or after the scene in the image, but should not conflict with the original sentence",
        }

        self.wiki_prompts = "I want you to act as a Wikipedia page. I will give you a sentence and you will parse the single main object in the sentence and provide a summary of that object in the format of a Wikipedia page. Your summary should be informative and factual, covering the most important aspects of the object. Start your summary with an introductory paragraph that gives an overview of the object. The overall length of the response should be around 100 words. You should not describe the parsing process and only provide the final summary. The sentence is \"{query}\"."

        self.control_prompts = "As a text reviser, you will convert an image description into a new sentence or long paragraph. The new text is {prompts}. {long_prompts} The sentence is \"{query}\" (give me the revised sentence only)"

    def parse(self, response):
        """Strip surrounding whitespace from the refined-caption response."""
        return response.strip()

    def parse2(self, response):
        """Strip surrounding whitespace from the wiki response."""
        return self.parse(response)

    def prepare_input(self, query, short_prompts, long_prompts):
        """Fill the control template with the selected prompt fragments."""
        prompt = self.control_prompts.format(
            prompts=', '.join(short_prompts),
            long_prompts='. '.join(long_prompts),
            query=query,
        )
        print('prompt: ', prompt)
        return prompt

    def inference(self, query: str, controls: dict, context: list = None):
        """Refine ``query`` and build a wiki-style summary of its main object.

        Args:
            query: The caption of the region of interest, generated by the
                captioner.
            controls: Control signals, e.g.
                ``{"length": 5, "sentiment": "positive"}``. Keys found in
                ``short_prompts`` are formatted with their value; keys found
                in ``long_prompts`` are added when their value is true-like;
                any other key is ignored.
            context: Unused; kept for interface compatibility.

        Returns:
            A dict with ``raw_caption``, ``caption`` and ``wiki`` entries.

        Raises:
            RuntimeError: If no LLM client is available.
        """
        if getattr(self, 'llm', None) is None:
            raise RuntimeError(
                'TextRefiner has no LLM client; provide an OpenAI API key before calling inference'
            )

        prompts = []
        long_prompts = []
        for control, value in controls.items():
            if control in self.short_prompts:
                prompts.append(self.short_prompts[control].format(**{control: value}))
            elif control in self.long_prompts and value in [True, "True", "true"]:
                long_prompts.append(self.long_prompts[control])

        refine_prompt = self.prepare_input(query, prompts, long_prompts)
        response = self.parse(self.llm(refine_prompt))

        prompt_wiki = self.wiki_prompts.format(query=query)
        response_wiki = self.parse2(self.llm(prompt_wiki))

        out = {
            'raw_caption': query,
            'caption': response,
            'wiki': response_wiki
        }
        print(out)
        return out


if __name__ == "__main__":
    model = TextRefiner(device='cpu')
    controls = {
        "length": "30",
        "sentiment": "negative",
        "imagination": "False",
        "language": "English",
    }
    model.inference(query='a cat is sleeping', controls=controls)
def colormap(rgb=True):
    """Return the colour table as a float32 array of shape (N, 3) in 0-255.

    Rows are RGB triples; pass ``rgb=False`` to get them in BGR order.
    """
    color_list = np.array(
        [
            0.000, 0.000, 0.000,
            1.000, 1.000, 1.000,
            1.000, 0.498, 0.313,
            0.392, 0.581, 0.929,
            0.000, 0.447, 0.741,
            0.850, 0.325, 0.098,
            0.929, 0.694, 0.125,
            0.494, 0.184, 0.556,
            0.466, 0.674, 0.188,
            0.301, 0.745, 0.933,
            0.635, 0.078, 0.184,
            0.300, 0.300, 0.300,
            0.600, 0.600, 0.600,
            1.000, 0.000, 0.000,
            1.000, 0.500, 0.000,
            0.749, 0.749, 0.000,
            0.000, 1.000, 0.000,
            0.000, 0.000, 1.000,
            0.667, 0.000, 1.000,
            0.333, 0.333, 0.000,
            0.333, 0.667, 0.000,
            0.333, 1.000, 0.000,
            0.667, 0.333, 0.000,
            0.667, 0.667, 0.000,
            0.667, 1.000, 0.000,
            1.000, 0.333, 0.000,
            1.000, 0.667, 0.000,
            1.000, 1.000, 0.000,
            0.000, 0.333, 0.500,
            0.000, 0.667, 0.500,
            0.000, 1.000, 0.500,
            0.333, 0.000, 0.500,
            0.333, 0.333, 0.500,
            0.333, 0.667, 0.500,
            0.333, 1.000, 0.500,
            0.667, 0.000, 0.500,
            0.667, 0.333, 0.500,
            0.667, 0.667, 0.500,
            0.667, 1.000, 0.500,
            1.000, 0.000, 0.500,
            1.000, 0.333, 0.500,
            1.000, 0.667, 0.500,
            1.000, 1.000, 0.500,
            0.000, 0.333, 1.000,
            0.000, 0.667, 1.000,
            0.000, 1.000, 1.000,
            0.333, 0.000, 1.000,
            0.333, 0.333, 1.000,
            0.333, 0.667, 1.000,
            0.333, 1.000, 1.000,
            0.667, 0.000, 1.000,
            0.667, 0.333, 1.000,
            0.667, 0.667, 1.000,
            0.667, 1.000, 1.000,
            1.000, 0.000, 1.000,
            1.000, 0.333, 1.000,
            1.000, 0.667, 1.000,
            0.167, 0.000, 0.000,
            0.333, 0.000, 0.000,
            0.500, 0.000, 0.000,
            0.667, 0.000, 0.000,
            0.833, 0.000, 0.000,
            1.000, 0.000, 0.000,
            0.000, 0.167, 0.000,
            0.000, 0.333, 0.000,
            0.000, 0.500, 0.000,
            0.000, 0.667, 0.000,
            0.000, 0.833, 0.000,
            0.000, 1.000, 0.000,
            0.000, 0.000, 0.167,
            0.000, 0.000, 0.333,
            0.000, 0.000, 0.500,
            0.000, 0.000, 0.667,
            0.000, 0.000, 0.833,
            0.000, 0.000, 1.000,
            0.143, 0.143, 0.143,
            0.286, 0.286, 0.286,
            0.429, 0.429, 0.429,
            0.571, 0.571, 0.571,
            0.714, 0.714, 0.714,
            0.857, 0.857, 0.857
        ]
    ).astype(np.float32)
    color_list = color_list.reshape((-1, 3)) * 255
    if not rgb:
        color_list = color_list[:, ::-1]
    return color_list


# Module-level palette: list of [R, G, B] integer triples indexed by colour id
# (0 is black, 1 is white).
color_list = colormap().astype('uint8').tolist()


def gauss_filter(kernel_size, sigma):
    """Return a ``kernel_size`` x ``kernel_size`` Gaussian kernel summing to 1."""
    max_idx = kernel_size // 2
    idx = np.linspace(-max_idx, max_idx, kernel_size)
    Y, X = np.meshgrid(idx, idx)
    kernel = np.exp(-(X**2 + Y**2) / (2*sigma**2))
    kernel /= kernel.sum()

    return kernel


def vis_add_mask(image, mask, color, alpha, kernel_size):
    """Blend ``color`` into ``image`` where the blurred ``mask`` is low.

    ``mask`` is expected on a 0-255 scale (it is divided by 255 after the
    blur). ``image`` is modified in place and also returned.
    """
    color = np.array(color)
    mask = mask.astype('float')
    mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)

    # After scaling, mask is in [0, alpha]: pixels at alpha keep the image,
    # pixels at 0 are mixed with `color` at strength alpha.
    for i in range(3):
        image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)

    return image


def vis_add_mask_wo_blur(image, mask, color, alpha):
    """Same blend as ``vis_add_mask`` but with ``mask`` used as given.

    ``image`` is modified in place and also returned.
    """
    color = np.array(color)
    mask = mask.astype('float')
    for i in range(3):
        image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
    return image


def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
    """Darken the background of a masked object and outline its contour.

    Input:
    input_image: numpy array; it is painted in place
    input_mask: numpy array, any value > 0 is foreground; left unmodified
    background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
    background_blur_radius: radius of background blur, must be odd number
    contour_width: width of mask contour, must be odd number
    contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
    contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted

    Output:
    painted_image: numpy array
    """
    assert input_image.shape[:2] == input_mask.shape, 'different shape'
    assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'

    # Binarise into a fresh uint8 array (0: background, 255: foreground) so
    # the caller's mask is not overwritten and boolean masks work as well.
    mask = np.where(input_mask > 0, 255, 0).astype(np.uint8)

    # mask background (black)
    painted_image = vis_add_mask(input_image, mask, color_list[0], background_alpha, background_blur_radius)
    # contour extraction
    contour_mask = cv2.Canny(mask, 100, 200)
    # widen contour
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
    contour_mask = cv2.dilate(contour_mask, kernel)
    # vis_add_mask tints low-mask pixels, so invert to tint the contour itself
    painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)

    return painted_image


if __name__ == '__main__':

    background_alpha = 0.7  # transparency of background 1: all black, 0: do nothing
    background_blur_radius = 35  # radius of background blur, must be odd number
    contour_width = 7  # contour width, must be odd number
    contour_color = 3  # id in color map, 0: black, 1: white, >1: others
    contour_alpha = 1  # transparency of contour, 0: no contour highlighted

    # load input image and mask
    input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
    input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))

    # paint
    painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)

    # save
    painted_image = Image.fromarray(painted_image)
    painted_image.save('./test_img/painter_output_image.png')