diff --git a/app.py b/app.py index c185365f71f3912b4d5ca1f075abc21040c935e5..63eb8ba17e3402f2912ad73c3517cce77c9681a8 100644 --- a/app.py +++ b/app.py @@ -1,3 +1,119 @@ import gradio as gr -gr.load("models/q-future/one-align").launch() \ No newline at end of file +import argparse +import datetime +import json +import os +import time + +import gradio as gr +import requests +from PIL import Image + +from q_align.model.builder import load_pretrained_model + +from q_align.conversation import (default_conversation, conv_templates, + SeparatorStyle) +from q_align.constants import LOGDIR +from q_align.utils import (build_logger, server_error_msg, + violates_moderation, moderation_msg) + +from q_align.evaluate.scorer import QAlignScorer, QAlignAestheticScorer, QAlignVideoScorer + +def load_video(video_file): + from decord import VideoReader + vr = VideoReader(video_file) + + # Get video frame rate + fps = vr.get_avg_fps() + + # Calculate frame indices for 1fps + frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))] + frames = vr.get_batch(frame_indices).asnumpy() + return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))] + + +pretrained="q-future/one-align" +device="cuda:0" +tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device) + +iqa_scorer = QAlignScorer(tokenizer=tokenizer, model=model, image_processor=image_processor) +iaa_scorer = QAlignAestheticScorer(tokenizer=tokenizer, model=model, image_processor=image_processor) +vqa_scorer = QAlignVideoScorer(tokenizer=tokenizer, model=model, image_processor=image_processor) + +scorers = {"Image Aesthetics (IAA)": iaa_scorer, "Image Quality (IQA)": iqa_scorer, "Video Quality (VQA)": vqa_scorer} + +LEVELS = ["excellent (5)", "good (4)", "fair (3)", "poor (2)", "bad (1)"] +scores = [5,4,3,2,1] +def image_classifier(input_img, input_vid, scorer_type): + if scorer_type is None: + scorer_type = "Image Quality (IQA)" + this_scorer = scorers[scorer_type] + if input_vid is not None: + input_ = load_video(input_vid) + elif input_img is not None: + input_ = [input_img] + else: + raise gr.Error("Please upload an image or a video first.") + if "Video" in scorer_type: + input_ = [input_] + probs = this_scorer(input_).mean(0).tolist() + prob_dict = {LEVEL: prob for LEVEL, prob in zip(LEVELS, probs)} + score = sum([prob * score for score, prob in zip(scores, probs)]) + return prob_dict, score + +title_markdown = (""" + +

Q-Align: Teaching LMMs for Visual Scoring via Discrete Text-Defined Levels

+ +

One Unified Model for Visual Scoring.

+ +
+ Haoning Wu1*+, + Zicheng Zhang2*, + Weixia Zhang2, + Chaofeng Chen1, + Liang Liao1, + Chunyi Li2, +
+ + +
+ Yixuan Gao2, + Annan Wang1, + Erli Zhang1, + Wenxiu Sun3, + Qiong Yan3, + Xiongkuo Min2, + Guangtao Zhai2#, + Weisi Lin1# +
+ +
+ 1Nanyang Technological University, 2Shanghai Jiao Tong University, 3SenseTime Research +
+
+*Equal contribution. +Project Lead. #Corresponding author(s). +
+ +

If you like OneScorer, please give us a star ✨ on GitHub for the latest updates.

+ +
+
+ + + + +
+
+ +""") + + +input_img = gr.Image(type='pil', label="Upload an Image") +input_vid = gr.Video(label="Upload a Video (will INGORE the image if a video is uploaded)", info="If a video is uploaded, the image uploaded will be ignored.") + +labels = gr.Label(label="Probabilities of rating levels:") +number = gr.Number(label="Output score:", info="Range in [1,5]. Higher is better.") +demo = gr.Interface(fn=image_classifier, inputs=[input_img, input_vid, gr.Radio(["Image Aesthetics (IAA)", "Image Quality (IQA)", "Video Quality (VQA)"], label="Task", info="Which Scorer will you need?"),], outputs=[labels, number], title="OneScorer", description=title_markdown) +demo.launch(share=True) \ No newline at end of file diff --git a/q_align/.ipynb_checkpoints/utils-checkpoint.py b/q_align/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..378dbc81cce765dea2cd138ae50f242fc557c2c9 --- /dev/null +++ b/q_align/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,128 @@ +import datetime +import logging +import logging.handlers +import os +import sys + +import requests + + + +from q_align.constants import LOGDIR + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." + +handler = None + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when='D', utc=True) + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = '' + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = '' + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. 
+ if line[-1] == '\n': + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = '' + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. + """ + url = "https://api.openai.com/v1/moderations" + headers = {"Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" \ No newline at end of file diff --git a/q_align/__init__.py b/q_align/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e4234303e2ad423b55342b0000d9dfc5b88df3 --- /dev/null +++ b/q_align/__init__.py @@ -0,0 +1 @@ +from .model import MPLUGOwl2LlamaForCausalLM \ No newline at end of file diff --git a/q_align/__pycache__/__init__.cpython-310.pyc b/q_align/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7db7c91534d4d9fcba53458c73de8b7b93319c0 Binary files /dev/null and b/q_align/__pycache__/__init__.cpython-310.pyc differ diff --git a/q_align/__pycache__/__init__.cpython-311.pyc b/q_align/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43eabac5150ca8ac52e0fe20c593ea49d166b30c Binary files /dev/null and b/q_align/__pycache__/__init__.cpython-311.pyc differ diff --git a/q_align/__pycache__/__init__.cpython-39.pyc b/q_align/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9711d49034200f623d0a4231eca427b704a13af9 Binary files /dev/null and b/q_align/__pycache__/__init__.cpython-39.pyc differ diff --git a/q_align/__pycache__/constants.cpython-310.pyc b/q_align/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40d379600f04196fc4637597294b0a48dd6c4edc Binary files /dev/null and b/q_align/__pycache__/constants.cpython-310.pyc differ diff --git a/q_align/__pycache__/constants.cpython-311.pyc b/q_align/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed26139984658aa167829ebb17321ac8eabb7f79 Binary files /dev/null and b/q_align/__pycache__/constants.cpython-311.pyc differ diff --git a/q_align/__pycache__/constants.cpython-39.pyc b/q_align/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a78884864b830cccdb26eb71c54b6f74944533e5 Binary files /dev/null and b/q_align/__pycache__/constants.cpython-39.pyc differ diff --git a/q_align/__pycache__/conversation.cpython-310.pyc b/q_align/__pycache__/conversation.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0e3817f60cf760318094a82981eb13d9f39b975d Binary files /dev/null and b/q_align/__pycache__/conversation.cpython-310.pyc differ diff --git a/q_align/__pycache__/conversation.cpython-311.pyc b/q_align/__pycache__/conversation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..570ae579548d6bcbe60e21df289d19c68aa03644 Binary files /dev/null and b/q_align/__pycache__/conversation.cpython-311.pyc differ diff --git a/q_align/__pycache__/conversation.cpython-39.pyc b/q_align/__pycache__/conversation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29aa67455004b8251e5a63a880ff144831a37a87 Binary files /dev/null and b/q_align/__pycache__/conversation.cpython-39.pyc differ diff --git a/q_align/__pycache__/mm_utils.cpython-310.pyc b/q_align/__pycache__/mm_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..519d0457d58e50de78d109b3004e2464a6d67d2c Binary files /dev/null and b/q_align/__pycache__/mm_utils.cpython-310.pyc differ diff --git a/q_align/__pycache__/mm_utils.cpython-311.pyc b/q_align/__pycache__/mm_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47702f8c81d7079d18936516385cc5a4ddad53bb Binary files /dev/null and b/q_align/__pycache__/mm_utils.cpython-311.pyc differ diff --git a/q_align/__pycache__/mm_utils.cpython-39.pyc b/q_align/__pycache__/mm_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a14b149ddd57f0c91444b14f6719c28b8bac6b9f Binary files /dev/null and b/q_align/__pycache__/mm_utils.cpython-39.pyc differ diff --git a/q_align/__pycache__/utils.cpython-311.pyc b/q_align/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abd682a6a83a001bf72d71cd273b0272972fb6c7 Binary files /dev/null and b/q_align/__pycache__/utils.cpython-311.pyc differ diff --git a/q_align/constants.py b/q_align/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..b632a10f2c053c72fae8d61a1fa10fa52aa0f4e8 --- /dev/null +++ b/q_align/constants.py @@ -0,0 +1,9 @@ +CONTROLLER_HEART_BEAT_EXPIRATION = 30 +WORKER_HEART_BEAT_INTERVAL = 15 + +LOGDIR = "./demo_logs" + +# Model Constants +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = -200 +DEFAULT_IMAGE_TOKEN = "<|image|>" diff --git a/q_align/conversation.py b/q_align/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f6f3f518f8b4426493ff7dacb0f6b70ae56c09 --- /dev/null +++ b/q_align/conversation.py @@ -0,0 +1,301 @@ +import dataclasses +from enum import auto, Enum +from typing import List, Tuple +from q_align.constants import DEFAULT_IMAGE_TOKEN + +class SeparatorStyle(Enum): + """Different separator style.""" + SINGLE = auto() + TWO = auto() + TWO_NO_SYS = auto() + MPT = auto() + PLAIN = auto() + LLAMA_2 = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + roles: List[str] + messages: List[List[str]] + offset: int + sep_style: SeparatorStyle = SeparatorStyle.SINGLE + sep: str = "###" + sep2: str = None + version: str = "Unknown" + + skip_next: bool = False + + def get_prompt(self): + messages = self.messages + if len(messages) > 0 and type(messages[0][1]) is tuple: + messages = self.messages.copy() + init_role, init_msg = messages[0].copy() + # init_msg = init_msg[0].replace("", "").strip() + # if 'mmtag' in self.version: + # messages[0] = (init_role, init_msg) + # 
messages.insert(0, (self.roles[0], "")) + # messages.insert(1, (self.roles[1], "Received.")) + # else: + # messages[0] = (init_role, "\n" + init_msg) + init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip() + messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg) + + if self.sep_style == SeparatorStyle.SINGLE: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + self.sep + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.TWO_NO_SYS: + seps = [self.sep, self.sep2] + ret = "" + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + elif self.sep_style == SeparatorStyle.MPT: + ret = self.system + self.sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + elif self.sep_style == SeparatorStyle.LLAMA_2: + wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" + wrap_inst = lambda msg: f"[INST] {msg} [/INST]" + ret = "" + + for i, (role, message) in enumerate(messages): + if i == 0: + assert message, "first message should not be none" + assert role == self.roles[0], "first message should come from user" + if message: + if type(message) is tuple: + message, _, _ = message + if i == 0: message = wrap_sys(self.system) + message + if i % 2 == 0: + message = wrap_inst(message) + ret += self.sep + message + else: + ret += " " + message + " " + self.sep2 + else: + ret += "" + ret = ret.lstrip(self.sep) + elif self.sep_style == SeparatorStyle.PLAIN: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(messages): + if message: + if type(message) is tuple: + message, _, _ = message + ret += message + seps[i % 2] + else: + ret += "" + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + return ret + + def append_message(self, role, message): + self.messages.append([role, message]) + + def get_images(self, return_pil=False): + images = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + from PIL import Image + msg, image, image_process_mode = msg + if image_process_mode == "Pad": + def expand2square(pil_img, background_color=(122, 116, 104)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image) + elif image_process_mode in ["Default", "Crop"]: + pass + elif image_process_mode == "Resize": + image = image.resize((336, 336)) + else: + raise ValueError(f"Invalid image_process_mode: {image_process_mode}") + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / 
aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if longest_edge != max(image.size): + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + if return_pil: + images.append(image) + else: + buffered = BytesIO() + image.save(buffered, format="PNG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + images.append(img_b64_str) + return images + + def to_gradio_chatbot(self): + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + if type(msg) is tuple: + import base64 + from io import BytesIO + msg, image, image_process_mode = msg + max_hw, min_hw = max(image.size), min(image.size) + aspect_ratio = max_hw / min_hw + max_len, min_len = 800, 400 + shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) + longest_edge = int(shortest_edge * aspect_ratio) + W, H = image.size + if H > W: + H, W = longest_edge, shortest_edge + else: + H, W = shortest_edge, longest_edge + image = image.resize((W, H)) + buffered = BytesIO() + image.save(buffered, format="JPEG") + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'user upload image' + msg = img_str + msg.replace('<|image|>', '').strip() + ret.append([msg, None]) + else: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def copy(self): + return Conversation( + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + version=self.version) + + def dict(self): + if len(self.get_images()) > 0: + return { + "system": self.system, + "roles": self.roles, + "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + return { + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "sep": self.sep, + "sep2": self.sep2, + } + + +conv_vicuna_v0 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ("Human", "What are the key differences between renewable and non-renewable energy sources?"), + ("Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. 
Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") + ), + offset=2, + sep_style=SeparatorStyle.SINGLE, + sep="###", +) + +conv_vicuna_v1 = Conversation( + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO, + sep=" ", + sep2="", +) + +conv_mplug_owl2 = Conversation( + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("USER", "ASSISTANT"), + version="v1", + messages=(), + offset=0, + sep_style=SeparatorStyle.TWO_NO_SYS, + sep=" ", + sep2="", +) + +# default_conversation = conv_vicuna_v1 +default_conversation = conv_mplug_owl2 +conv_templates = { + "default": conv_vicuna_v0, + "v0": conv_vicuna_v0, + "v1": conv_vicuna_v1, + "vicuna_v1": conv_vicuna_v1, + "mplug_owl2": conv_mplug_owl2, +} + + +if __name__ == "__main__": + print(default_conversation.get_prompt()) \ No newline at end of file diff --git a/q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..64f17e29dacae2c145c0b10d1fe0dcdc27b8f6b9 --- /dev/null +++ b/q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py @@ -0,0 +1,164 @@ +import argparse +import torch + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.conversation import conv_templates, SeparatorStyle +from q_align.model.builder import load_pretrained_model +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + +from scipy.stats import spearmanr, pearsonr + + +import json +from tqdm import tqdm +from collections import defaultdict + +import os + +def wa5(logits): + import numpy as np + logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]]) + probs = np.exp(logprobs) / np.sum(np.exp(logprobs)) + return np.inner(probs, np.array([1,0.75,0.5,0.25,0.])) + + + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. 
+ """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_image(image_file): + if image_file.startswith('http://') or image_file.startswith('https://'): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert('RGB') + else: + image = Image.open(image_file).convert('RGB') + return image + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + + import json + + + image_path = "playground/data/" + + + json_prefix = "playground/data/test_jsons/" + jsons = [ + json_prefix + "test_ava.json", + ] + + os.makedirs(f"results/{args.model_path}/", exist_ok=True) + + + conv_mode = "mplug_owl2" + + inp = "How would you rate the aesthetics of this image?" + + conv = conv_templates[conv_mode].copy() + inp = DEFAULT_IMAGE_TOKEN + inp + conv.append_message(conv.roles[0], inp) + image = None + + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + " The aesthetics of the image is" + + toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"] + print(toks) + ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] + print(ids_) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device) + + for json_ in jsons: + with open(json_) as f: + iqadata = json.load(f) + + image_tensors = [] + batch_data = [] + prs, gts = [], [] + for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))): + filename = llddata["image"] + llddata["logits"] = defaultdict(float) + + + + image = load_image(image_path + filename) + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device) + + image_tensors.append(image_tensor) + batch_data.append(llddata) + + if i % 8 == 7 or i == len(iqadata) - 1: + with torch.inference_mode(): + output_logits = model(input_ids.repeat(len(image_tensors), 1), + images=torch.cat(image_tensors, 0))["logits"][:,-1] + + for j, xllddata in enumerate(batch_data): + for tok, id_ in zip(toks, ids_): + xllddata["logits"][tok] += output_logits[j,id_].item() + xllddata["score"] = wa5(xllddata["logits"]) + # print(llddata) + prs.append(xllddata["score"]) + gts.append(xllddata["gt_score"]) + json_ = json_.replace("combined/", "combined-") + with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf: + json.dump(xllddata, wf) + + image_tensors = [] + batch_data = [] + + #if i > 0 and i % 200 == 0: + # print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0]) + print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0]) + + +if __name__ == "__main__": + parser = 
argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/one-align") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--image-aspect-ratio", type=str, default='pad') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..1aaf1a7ac9ec4ee21a48ee96bfc92fe742559c24 --- /dev/null +++ b/q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py @@ -0,0 +1,150 @@ +import argparse +import torch + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.conversation import conv_templates, SeparatorStyle +from q_align.model.builder import load_pretrained_model +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + +from decord import VideoReader + + +import json +from tqdm import tqdm +from collections import defaultdict + +import os + + + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_video(video_file): + vr = VideoReader(video_file) + + # Get video frame rate + fps = vr.get_avg_fps() + + # Calculate frame indices for 1fps + frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))] + frames = vr.get_batch(frame_indices).asnumpy() + return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))] + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + + import json + + + image_paths = [ + "playground/data/", + "playground/data/", + "playground/data/KoNViD_1k_videos/", + "playground/data/maxwell/", + ] + + json_prefix = "playground/data/test_jsons/" + jsons = [ + json_prefix + "test_lsvq.json", + json_prefix + "test_lsvq_1080p.json", + json_prefix + "konvid.json", + json_prefix + "maxwell_test.json", + ] + + os.makedirs(f"results/{args.model_path}/", exist_ok=True) + + + conv_mode = "mplug_owl2" + + inp = "How would you rate the quality of this image?" 
+ + conv = conv_templates[conv_mode].copy() + inp = inp + "\n" + DEFAULT_IMAGE_TOKEN + conv.append_message(conv.roles[0], inp) + image = None + + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + " The quality of the image is" + + toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"] + print(toks) + ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] + print(ids_) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device) + + for image_path, json_ in zip(image_paths, jsons): + with open(json_) as f: + iqadata = json.load(f) + try: + for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))): + filename = llddata["img_path"] + llddata["logits"] = defaultdict(float) + + image = load_video(image_path + filename) + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image] + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device) + + + if True: + with torch.inference_mode(): + output_logits = model(input_ids.repeat(image_tensor.shape[0], 1), + images=image_tensor)["logits"][:,-1] + + for tok, id_ in zip(toks, ids_): + llddata["logits"][tok] += output_logits.mean(0)[id_].item() + # print(llddata) + json_ = json_.replace("combined/", "combined-") + with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf: + json.dump(llddata, wf) + except: + continue + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/q-align-image") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--image-aspect-ratio", type=str, default='pad') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0695d696cdad4b277aab3a753dc4b9dce89af1 --- /dev/null +++ b/q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py @@ -0,0 +1,156 @@ +import argparse +import torch + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.conversation import conv_templates, SeparatorStyle +from q_align.model.builder import load_pretrained_model +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image + +import requests +from PIL import Image +from io import 
BytesIO +from transformers import TextStreamer + +import json +from tqdm import tqdm +from collections import defaultdict + +import os + + + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_image(image_file): + if image_file.startswith('http://') or image_file.startswith('https://'): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert('RGB') + else: + image = Image.open(image_file).convert('RGB') + return image + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + + import json + + + image_path = "playground/data/" + + + json_prefix = "playground/data/test_jsons/" + jsons = [ + json_prefix + "test_imagerewarddb.json", + json_prefix + "test_koniq.json", + json_prefix + "test_spaq.json", + json_prefix + "test_kadid.json", + json_prefix + "livec.json", + json_prefix + "agi.json", + json_prefix + "live.json", + json_prefix + "csiq.json", + ] + + os.makedirs(f"results/{args.model_path}/", exist_ok=True) + + + conv_mode = "mplug_owl2" + + inp = "Evaluate the image quality of the following image."#"How would you rate the quality of this image?" + + conv = conv_templates[conv_mode].copy() + inp = inp + "\n" + DEFAULT_IMAGE_TOKEN + conv.append_message(conv.roles[0], inp) + image = None + + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + " The quality of the image is" + + toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"] + print(toks) + ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] + print(ids_) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device) + + for json_ in jsons: + with open(json_) as f: + iqadata = json.load(f) + + image_tensors = [] + batch_data = [] + + for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))): + if True: + try: + filename = llddata["image"] + except: + filename = llddata["img_path"] + llddata["logits"] = defaultdict(float) + + image = load_image(image_path + filename) + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device) + + image_tensors.append(image_tensor) + batch_data.append(llddata) + + if i % 8 == 7 or i == len(iqadata) - 1: + with torch.inference_mode(): + output_logits = model(input_ids.repeat(len(image_tensors), 1), + images=torch.cat(image_tensors, 0))["logits"][:,-1] + + for j, xllddata in enumerate(batch_data): + for tok, id_ in zip(toks, ids_): + xllddata["logits"][tok] += 
output_logits[j,id_].item() + # print(llddata) + json_ = json_.replace("combined/", "combined-") + with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf: + json.dump(xllddata, wf) + + image_tensors = [] + batch_data = [] + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/one-align") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--image-aspect-ratio", type=str, default='pad') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..4d68f83df09a0f8696a49bede9844d7e7d6465dd --- /dev/null +++ b/q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py @@ -0,0 +1,155 @@ +from PIL import Image + +import torch.nn as nn +import torch + +from typing import List + +from q_align.model.builder import load_pretrained_model + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +def load_video(video_file): + from decord import VideoReader + vr = VideoReader(video_file) + + # Get video frame rate + fps = vr.get_avg_fps() + + # Calculate frame indices for 1fps + frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))] + frames = vr.get_batch(frame_indices).asnumpy() + return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))] + + +class QAlignScorer(nn.Module): + def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None): + super().__init__() + if model is None: + tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device) + prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is" + + self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]] + self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device) + + self.tokenizer = tokenizer + self.model = model + self.image_processor = image_processor + self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + def expand2square(self, pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + def forward(self, image: List[Image.Image]): + image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image] + with torch.inference_mode(): + image_tensor = 
self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device) + output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1), + images=image_tensor)["logits"][:,-1, self.preferential_ids_] + + return torch.softmax(output_logits, -1) #@ self.weight_tensor + + +class QAlignAestheticScorer(nn.Module): + def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None): + super().__init__() + if model is None: + tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device) + prompt = "USER: How would you rate the aesthetics of this image?\n<|image|>\nASSISTANT: The aesthetics of the image is" + + self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]] + self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device) + + self.tokenizer = tokenizer + self.model = model + self.image_processor = image_processor + self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + def expand2square(self, pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + def forward(self, image: List[Image.Image]): + image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image] + with torch.inference_mode(): + image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device) + output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1), + images=image_tensor)["logits"][:,-1, self.preferential_ids_] + + return torch.softmax(output_logits, -1) #@ self.weight_tensor + +class QAlignVideoScorer(nn.Module): + def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None): + super().__init__() + if model is None: + tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device) + prompt = "USER: How would you rate the quality of this video?\n<|image|>\nASSISTANT: The quality of the video is" + + self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]] + self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device) + + self.tokenizer = tokenizer + self.model = model + self.image_processor = image_processor + self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + def expand2square(self, pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + def forward(self, video: List[List[Image.Image]]): + video = [[self.expand2square(frame, tuple(int(x*255) for x in 
self.image_processor.image_mean)) for frame in vid] for vid in video] + with torch.inference_mode(): + video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video] + output_logits = self.model(self.input_ids.repeat(len(video_tensors), 1), + images=video_tensors)["logits"][:,-1, self.preferential_ids_] + return torch.softmax(output_logits, -1) #@ self.weight_tensor + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/one-align") + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg") + parser.add_argument("--aesthetic", action="store_true") + parser.add_argument("--video", action="store_true") + args = parser.parse_args() + + if args.video: + scorer = QAlignVideoScorer(pretrained=args.model_path, device=args.device) + print(scorer([load_video(args.img_path)]).tolist()) + else: + scorer = QAlignScorer(pretrained=args.model_path, device=args.device) if not args.aesthetic else QAlignAestheticScorer(pretrained=args.model_path, device=args.device) + print(scorer([Image.open(args.img_path)]).tolist()) + diff --git a/q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb69fea96cf97ed9fd41554460026f0caa34614 --- /dev/null +++ b/q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py @@ -0,0 +1,167 @@ +import argparse +import torch + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.conversation import conv_templates, SeparatorStyle +from q_align.model.builder import load_pretrained_model +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + + +from scipy.stats import spearmanr, pearsonr + +import json +from tqdm import tqdm +from collections import defaultdict + +import os + +def wa5(logits): + import numpy as np + logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]]) + probs = np.exp(logprobs) / np.sum(np.exp(logprobs)) + return np.inner(probs, np.array([1,0.75,0.5,0.25,0.])) + + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. 
+ """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_video(video_file): + from decord import VideoReader + vr = VideoReader(video_file) + + # Get video frame rate + fps = vr.get_avg_fps() + + # Calculate frame indices for 1fps + frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))] + frames = vr.get_batch(frame_indices).asnumpy() + return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))] + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + + import json + + + image_paths = [ + #"playground/data/", + #"playground/data/", + "playground/data/KoNViD_1k_videos/", + "playground/data/maxwell/", + + ] + + json_prefix = "playground/data/test_jsons/" + jsons = [ + #json_prefix + "test_lsvq.json", + #json_prefix + "test_lsvq_1080p.json", + json_prefix + "konvid.json", + json_prefix + "maxwell_test.json", + ] + + os.makedirs(f"results/{args.model_path}/", exist_ok=True) + + + conv_mode = "mplug_owl2" + + inp = "How would you rate the quality of this video?" + + conv = conv_templates[conv_mode].copy() + inp = inp + "\n" + DEFAULT_IMAGE_TOKEN + conv.append_message(conv.roles[0], inp) + image = None + + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + " The quality of the video is" + + toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"] + print(toks) + ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] + print(ids_) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device) + + for image_path, json_ in zip(image_paths, jsons): + with open(json_) as f: + iqadata = json.load(f) + prs, gts = [], [] + for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))): + try: + try: + filename = llddata["img_path"] + except: + filename = llddata["image"] + llddata["logits"] = defaultdict(float) + + image = load_video(image_path + filename) + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image] + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device) + + if True: + with torch.inference_mode(): + output_logits = model(input_ids, + images=[image_tensor])["logits"][:,-1] + for tok, id_ in zip(toks, ids_): + llddata["logits"][tok] += output_logits.mean(0)[id_].item() + llddata["score"] = wa5(llddata["logits"]) + # print(llddata) + prs.append(llddata["score"]) + gts.append(llddata["gt_score"]) + # print(llddata) + json_ = json_.replace("combined/", "combined-") + with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf: + json.dump(llddata, wf) + + if i > 0 and i % 200 == 0: + print(spearmanr(prs,gts)[0], 
pearsonr(prs,gts)[0]) + except: + continue + print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/one-align") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--image-aspect-ratio", type=str, default='pad') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/q_align/evaluate/__pycache__/scorer.cpython-311.pyc b/q_align/evaluate/__pycache__/scorer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b421b86aed039452e511142014200bb403f9e659 Binary files /dev/null and b/q_align/evaluate/__pycache__/scorer.cpython-311.pyc differ diff --git a/q_align/evaluate/eval.py b/q_align/evaluate/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..008b7b07abfec6c41a1162a582dae2ccdae5f597 --- /dev/null +++ b/q_align/evaluate/eval.py @@ -0,0 +1,138 @@ +import argparse +import torch + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.conversation import conv_templates, SeparatorStyle +from q_align.model.builder import load_pretrained_model +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + +import json +from tqdm import tqdm +from collections import defaultdict + +import os + + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_image(image_file): + if image_file.startswith('http://') or image_file.startswith('https://'): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert('RGB') + else: + image = Image.open(image_file).convert('RGB') + return image + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + + import json + + + image_path = "playground/data/" + + json_prefix = "playground/data/labels/mos_simple/" + jsons = [ + json_prefix + "test_flive.json", + json_prefix + "combined/kadid_ref.json", + json_prefix + "combined/livec.json", + json_prefix + "test_koniq.json", + json_prefix + "test_spaq.json", + json_prefix + "combined/agi.json", + json_prefix + "combined/kadid.json", + ] + + os.makedirs(f"results/{args.model_path}/", exist_ok=True) + + + conv_mode = "mplug_owl2" + + inp = "How would you rate the quality of this image?" 
+ + conv = conv_templates[conv_mode].copy() + inp = inp + "\n" + DEFAULT_IMAGE_TOKEN + conv.append_message(conv.roles[0], inp) + image = None + + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + " The quality of the image is" + + toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"] + print(toks) + ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] + print(ids_) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() + + for json_ in jsons: + with open(json_) as f: + iqadata = json.load(f) + + image_tensors = [] + batch_data = [] + + for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))): + #print(f"Evaluating image {i}") + #print(prompt) + filename = llddata["image"] + llddata["logits"] = defaultdict(float) + + image = load_image(image_path + filename) + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() + + image_tensors.append(image_tensor) + batch_data.append(llddata) + + if i % 8 == 7 or i == len(iqadata) - 1: + with torch.inference_mode(): + output_logits = model(input_ids.repeat(len(image_tensors), 1), + images=torch.cat(image_tensors, 0))["logits"][:,-1] + + for j, xllddata in enumerate(batch_data): + for tok, id_ in zip(toks, ids_): + xllddata["logits"][tok] += output_logits[j,id_].item() + # print(llddata) + json_ = json_.replace("combined/", "combined-") + # print(f"results/mix-mplug-owl-2-boost_iqa_wu_v2/{json_.split('/')[-1]}") + with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf: + json.dump(xllddata, wf) + + image_tensors = [] + batch_data = [] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/q-align-koniq-spaq-v0") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--image-aspect-ratio", type=str, default='pad') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/q_align/evaluate/iaa_eval.py b/q_align/evaluate/iaa_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..64f17e29dacae2c145c0b10d1fe0dcdc27b8f6b9 --- /dev/null +++ b/q_align/evaluate/iaa_eval.py @@ -0,0 +1,164 @@ +import argparse +import torch + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.conversation import conv_templates, SeparatorStyle +from q_align.model.builder import load_pretrained_model +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + +from scipy.stats import spearmanr, pearsonr + + +import json +from tqdm import tqdm +from collections import defaultdict + +import os + +def wa5(logits): + import numpy as np + logprobs = 
np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]]) + probs = np.exp(logprobs) / np.sum(np.exp(logprobs)) + return np.inner(probs, np.array([1,0.75,0.5,0.25,0.])) + + + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_image(image_file): + if image_file.startswith('http://') or image_file.startswith('https://'): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert('RGB') + else: + image = Image.open(image_file).convert('RGB') + return image + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + + import json + + + image_path = "playground/data/" + + + json_prefix = "playground/data/test_jsons/" + jsons = [ + json_prefix + "test_ava.json", + ] + + os.makedirs(f"results/{args.model_path}/", exist_ok=True) + + + conv_mode = "mplug_owl2" + + inp = "How would you rate the aesthetics of this image?" + + conv = conv_templates[conv_mode].copy() + inp = DEFAULT_IMAGE_TOKEN + inp + conv.append_message(conv.roles[0], inp) + image = None + + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + " The aesthetics of the image is" + + toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"] + print(toks) + ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] + print(ids_) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device) + + for json_ in jsons: + with open(json_) as f: + iqadata = json.load(f) + + image_tensors = [] + batch_data = [] + prs, gts = [], [] + for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))): + filename = llddata["image"] + llddata["logits"] = defaultdict(float) + + + + image = load_image(image_path + filename) + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device) + + image_tensors.append(image_tensor) + batch_data.append(llddata) + + if i % 8 == 7 or i == len(iqadata) - 1: + with torch.inference_mode(): + output_logits = model(input_ids.repeat(len(image_tensors), 1), + images=torch.cat(image_tensors, 0))["logits"][:,-1] + + for j, xllddata in enumerate(batch_data): + for tok, id_ in zip(toks, ids_): + xllddata["logits"][tok] += output_logits[j,id_].item() + xllddata["score"] = wa5(xllddata["logits"]) + # print(llddata) + prs.append(xllddata["score"]) + gts.append(xllddata["gt_score"]) + json_ = json_.replace("combined/", "combined-") + with 
open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf: + json.dump(xllddata, wf) + + image_tensors = [] + batch_data = [] + + #if i > 0 and i % 200 == 0: + # print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0]) + print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/one-align") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--image-aspect-ratio", type=str, default='pad') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/q_align/evaluate/iqa4vqa_eval.py b/q_align/evaluate/iqa4vqa_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..1aaf1a7ac9ec4ee21a48ee96bfc92fe742559c24 --- /dev/null +++ b/q_align/evaluate/iqa4vqa_eval.py @@ -0,0 +1,150 @@ +import argparse +import torch + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.conversation import conv_templates, SeparatorStyle +from q_align.model.builder import load_pretrained_model +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + +from decord import VideoReader + + +import json +from tqdm import tqdm +from collections import defaultdict + +import os + + + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_video(video_file): + vr = VideoReader(video_file) + + # Get video frame rate + fps = vr.get_avg_fps() + + # Calculate frame indices for 1fps + frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))] + frames = vr.get_batch(frame_indices).asnumpy() + return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))] + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + + import json + + + image_paths = [ + "playground/data/", + "playground/data/", + "playground/data/KoNViD_1k_videos/", + "playground/data/maxwell/", + ] + + json_prefix = "playground/data/test_jsons/" + jsons = [ + json_prefix + "test_lsvq.json", + json_prefix + "test_lsvq_1080p.json", + json_prefix + "konvid.json", + json_prefix + "maxwell_test.json", + ] + + os.makedirs(f"results/{args.model_path}/", exist_ok=True) + + + conv_mode = "mplug_owl2" + + inp = "How would you rate the quality of this image?" 
+ + conv = conv_templates[conv_mode].copy() + inp = inp + "\n" + DEFAULT_IMAGE_TOKEN + conv.append_message(conv.roles[0], inp) + image = None + + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + " The quality of the image is" + + toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"] + print(toks) + ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] + print(ids_) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device) + + for image_path, json_ in zip(image_paths, jsons): + with open(json_) as f: + iqadata = json.load(f) + try: + for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))): + filename = llddata["img_path"] + llddata["logits"] = defaultdict(float) + + image = load_video(image_path + filename) + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image] + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device) + + + if True: + with torch.inference_mode(): + output_logits = model(input_ids.repeat(image_tensor.shape[0], 1), + images=image_tensor)["logits"][:,-1] + + for tok, id_ in zip(toks, ids_): + llddata["logits"][tok] += output_logits.mean(0)[id_].item() + # print(llddata) + json_ = json_.replace("combined/", "combined-") + with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf: + json.dump(llddata, wf) + except: + continue + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/q-align-image") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--image-aspect-ratio", type=str, default='pad') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/q_align/evaluate/iqa_eval.py b/q_align/evaluate/iqa_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0695d696cdad4b277aab3a753dc4b9dce89af1 --- /dev/null +++ b/q_align/evaluate/iqa_eval.py @@ -0,0 +1,156 @@ +import argparse +import torch + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.conversation import conv_templates, SeparatorStyle +from q_align.model.builder import load_pretrained_model +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + +import json +from tqdm import tqdm +from 
collections import defaultdict + +import os + + + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_image(image_file): + if image_file.startswith('http://') or image_file.startswith('https://'): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert('RGB') + else: + image = Image.open(image_file).convert('RGB') + return image + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + + import json + + + image_path = "playground/data/" + + + json_prefix = "playground/data/test_jsons/" + jsons = [ + json_prefix + "test_imagerewarddb.json", + json_prefix + "test_koniq.json", + json_prefix + "test_spaq.json", + json_prefix + "test_kadid.json", + json_prefix + "livec.json", + json_prefix + "agi.json", + json_prefix + "live.json", + json_prefix + "csiq.json", + ] + + os.makedirs(f"results/{args.model_path}/", exist_ok=True) + + + conv_mode = "mplug_owl2" + + inp = "Evaluate the image quality of the following image."#"How would you rate the quality of this image?" + + conv = conv_templates[conv_mode].copy() + inp = inp + "\n" + DEFAULT_IMAGE_TOKEN + conv.append_message(conv.roles[0], inp) + image = None + + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + " The quality of the image is" + + toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"] + print(toks) + ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] + print(ids_) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device) + + for json_ in jsons: + with open(json_) as f: + iqadata = json.load(f) + + image_tensors = [] + batch_data = [] + + for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))): + if True: + try: + filename = llddata["image"] + except: + filename = llddata["img_path"] + llddata["logits"] = defaultdict(float) + + image = load_image(image_path + filename) + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device) + + image_tensors.append(image_tensor) + batch_data.append(llddata) + + if i % 8 == 7 or i == len(iqadata) - 1: + with torch.inference_mode(): + output_logits = model(input_ids.repeat(len(image_tensors), 1), + images=torch.cat(image_tensors, 0))["logits"][:,-1] + + for j, xllddata in enumerate(batch_data): + for tok, id_ in zip(toks, ids_): + xllddata["logits"][tok] += output_logits[j,id_].item() + # print(llddata) + json_ = json_.replace("combined/", 
"combined-") + with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf: + json.dump(xllddata, wf) + + image_tensors = [] + batch_data = [] + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/one-align") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--image-aspect-ratio", type=str, default='pad') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/q_align/evaluate/scorer.py b/q_align/evaluate/scorer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d68f83df09a0f8696a49bede9844d7e7d6465dd --- /dev/null +++ b/q_align/evaluate/scorer.py @@ -0,0 +1,155 @@ +from PIL import Image + +import torch.nn as nn +import torch + +from typing import List + +from q_align.model.builder import load_pretrained_model + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +def load_video(video_file): + from decord import VideoReader + vr = VideoReader(video_file) + + # Get video frame rate + fps = vr.get_avg_fps() + + # Calculate frame indices for 1fps + frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))] + frames = vr.get_batch(frame_indices).asnumpy() + return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))] + + +class QAlignScorer(nn.Module): + def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None): + super().__init__() + if model is None: + tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device) + prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is" + + self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]] + self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device) + + self.tokenizer = tokenizer + self.model = model + self.image_processor = image_processor + self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + def expand2square(self, pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + def forward(self, image: List[Image.Image]): + image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image] + with torch.inference_mode(): + image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device) + output_logits = 
self.model(self.input_ids.repeat(image_tensor.shape[0], 1), + images=image_tensor)["logits"][:,-1, self.preferential_ids_] + + return torch.softmax(output_logits, -1) #@ self.weight_tensor + + +class QAlignAestheticScorer(nn.Module): + def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None): + super().__init__() + if model is None: + tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device) + prompt = "USER: How would you rate the aesthetics of this image?\n<|image|>\nASSISTANT: The aesthetics of the image is" + + self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]] + self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device) + + self.tokenizer = tokenizer + self.model = model + self.image_processor = image_processor + self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + def expand2square(self, pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + def forward(self, image: List[Image.Image]): + image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image] + with torch.inference_mode(): + image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device) + output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1), + images=image_tensor)["logits"][:,-1, self.preferential_ids_] + + return torch.softmax(output_logits, -1) #@ self.weight_tensor + +class QAlignVideoScorer(nn.Module): + def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None): + super().__init__() + if model is None: + tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device) + prompt = "USER: How would you rate the quality of this video?\n<|image|>\nASSISTANT: The quality of the video is" + + self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]] + self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device) + + self.tokenizer = tokenizer + self.model = model + self.image_processor = image_processor + self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) + + def expand2square(self, pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + def forward(self, video: List[List[Image.Image]]): + video = [[self.expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in video] + with torch.inference_mode(): + video_tensors = 
[self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video] + output_logits = self.model(self.input_ids.repeat(len(video_tensors), 1), + images=video_tensors)["logits"][:,-1, self.preferential_ids_] + return torch.softmax(output_logits, -1) #@ self.weight_tensor + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/one-align") + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg") + parser.add_argument("--aesthetic", action="store_true") + parser.add_argument("--video", action="store_true") + args = parser.parse_args() + + if args.video: + scorer = QAlignVideoScorer(pretrained=args.model_path, device=args.device) + print(scorer([load_video(args.img_path)]).tolist()) + else: + scorer = QAlignScorer(pretrained=args.model_path, device=args.device) if not args.aesthetic else QAlignAestheticScorer(pretrained=args.model_path, device=args.device) + print(scorer([Image.open(args.img_path)]).tolist()) + diff --git a/q_align/evaluate/vqa_eval.py b/q_align/evaluate/vqa_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..7cb69fea96cf97ed9fd41554460026f0caa34614 --- /dev/null +++ b/q_align/evaluate/vqa_eval.py @@ -0,0 +1,167 @@ +import argparse +import torch + +from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN +from q_align.conversation import conv_templates, SeparatorStyle +from q_align.model.builder import load_pretrained_model +from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria + +from PIL import Image + +import requests +from PIL import Image +from io import BytesIO +from transformers import TextStreamer + + +from scipy.stats import spearmanr, pearsonr + +import json +from tqdm import tqdm +from collections import defaultdict + +import os + +def wa5(logits): + import numpy as np + logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]]) + probs = np.exp(logprobs) / np.sum(np.exp(logprobs)) + return np.inner(probs, np.array([1,0.75,0.5,0.25,0.])) + + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. 
+ """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def load_video(video_file): + from decord import VideoReader + vr = VideoReader(video_file) + + # Get video frame rate + fps = vr.get_avg_fps() + + # Calculate frame indices for 1fps + frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))] + frames = vr.get_batch(frame_indices).asnumpy() + return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))] + + +def main(args): + # Model + disable_torch_init() + + model_name = get_model_name_from_path(args.model_path) + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) + + + import json + + + image_paths = [ + #"playground/data/", + #"playground/data/", + "playground/data/KoNViD_1k_videos/", + "playground/data/maxwell/", + + ] + + json_prefix = "playground/data/test_jsons/" + jsons = [ + #json_prefix + "test_lsvq.json", + #json_prefix + "test_lsvq_1080p.json", + json_prefix + "konvid.json", + json_prefix + "maxwell_test.json", + ] + + os.makedirs(f"results/{args.model_path}/", exist_ok=True) + + + conv_mode = "mplug_owl2" + + inp = "How would you rate the quality of this video?" + + conv = conv_templates[conv_mode].copy() + inp = inp + "\n" + DEFAULT_IMAGE_TOKEN + conv.append_message(conv.roles[0], inp) + image = None + + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + " The quality of the video is" + + toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"] + print(toks) + ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]] + print(ids_) + + input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device) + + for image_path, json_ in zip(image_paths, jsons): + with open(json_) as f: + iqadata = json.load(f) + prs, gts = [], [] + for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))): + try: + try: + filename = llddata["img_path"] + except: + filename = llddata["image"] + llddata["logits"] = defaultdict(float) + + image = load_video(image_path + filename) + def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image] + image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device) + + if True: + with torch.inference_mode(): + output_logits = model(input_ids, + images=[image_tensor])["logits"][:,-1] + for tok, id_ in zip(toks, ids_): + llddata["logits"][tok] += output_logits.mean(0)[id_].item() + llddata["score"] = wa5(llddata["logits"]) + # print(llddata) + prs.append(llddata["score"]) + gts.append(llddata["gt_score"]) + # print(llddata) + json_ = json_.replace("combined/", "combined-") + with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf: + json.dump(llddata, wf) + + if i > 0 and i % 200 == 0: + print(spearmanr(prs,gts)[0], 
pearsonr(prs,gts)[0]) + except: + continue + print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, default="q-future/one-align") + parser.add_argument("--model-base", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda:0") + parser.add_argument("--conv-mode", type=str, default=None) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--load-8bit", action="store_true") + parser.add_argument("--load-4bit", action="store_true") + parser.add_argument("--debug", action="store_true") + parser.add_argument("--image-aspect-ratio", type=str, default='pad') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/q_align/mm_utils.py b/q_align/mm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ee1ee281029576bf0f23384ebb79a34e2f34d578 --- /dev/null +++ b/q_align/mm_utils.py @@ -0,0 +1,112 @@ +from PIL import Image +from io import BytesIO +import base64 + +import torch +from transformers import StoppingCriteria +from q_align.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN +from icecream import ic + + +def load_image_from_base64(image): + return Image.open(BytesIO(base64.b64decode(image))) + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def process_images(images, image_processor, model_cfg=None): + if model_cfg is not None: + image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) + else: + image_aspect_ratio = 'resize' + new_images = [] + if image_aspect_ratio == 'pad': + for image in images: + image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + new_images.append(image) + elif image_aspect_ratio == 'resize': + for image in images: + max_edge = max(image.size) + image = image.resize((max_edge, max_edge)) + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + new_images.append(image) + else: + return image_processor(images, return_tensors='pt')['pixel_values'] + if all(x.shape == new_images[0].shape for x in new_images): + new_images = torch.stack(new_images, dim=0) + return new_images + + +def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): + prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + if return_tensors is not None: + if return_tensors == 'pt': + return torch.tensor(input_ids, dtype=torch.long) + raise 
ValueError(f'Unsupported tensor type: {return_tensors}') + return input_ids + + +def get_model_name_from_path(model_path): + model_path = model_path.strip("/") + model_paths = model_path.split("/") + if model_paths[-1].startswith('checkpoint-'): + return model_paths[-2] + "_" + model_paths[-1] + else: + return model_paths[-1] + + + + +class KeywordsStoppingCriteria(StoppingCriteria): + def __init__(self, keywords, tokenizer, input_ids): + self.keywords = keywords + self.keyword_ids = [] + self.max_keyword_len = 0 + for keyword in keywords: + cur_keyword_ids = tokenizer(keyword).input_ids + if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: + cur_keyword_ids = cur_keyword_ids[1:] + if len(cur_keyword_ids) > self.max_keyword_len: + self.max_keyword_len = len(cur_keyword_ids) + self.keyword_ids.append(torch.tensor(cur_keyword_ids)) + self.tokenizer = tokenizer + self.start_len = input_ids.shape[1] + + def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO + offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) + self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] + for keyword_id in self.keyword_ids: + if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): + return True + outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] + for keyword in self.keywords: + if keyword in outputs: + return True + return False \ No newline at end of file diff --git a/q_align/model/__init__.py b/q_align/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d6f0775a0abb2c3e220343a4feb05c70c2c7779 --- /dev/null +++ b/q_align/model/__init__.py @@ -0,0 +1,2 @@ +from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM +from .configuration_mplug_owl2 import MPLUGOwl2Config \ No newline at end of file diff --git a/q_align/model/__pycache__/__init__.cpython-310.pyc b/q_align/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2385b4122537f406b59b779bbee7db97400c7970 Binary files /dev/null and b/q_align/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/q_align/model/__pycache__/__init__.cpython-311.pyc b/q_align/model/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a0f4f90f4ca15acc6fce979cdde7c09a3854252 Binary files /dev/null and b/q_align/model/__pycache__/__init__.cpython-311.pyc differ diff --git a/q_align/model/__pycache__/__init__.cpython-39.pyc b/q_align/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58122c87b9917b19fc6d37f5b4d9d6e56165fe5f Binary files /dev/null and b/q_align/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/q_align/model/__pycache__/builder.cpython-310.pyc b/q_align/model/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..105a98b56eb0ec9844c7e91ebe97a3859c6da159 Binary files /dev/null and b/q_align/model/__pycache__/builder.cpython-310.pyc differ diff --git a/q_align/model/__pycache__/builder.cpython-311.pyc b/q_align/model/__pycache__/builder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..433a8f78c22f1bb8cbed44a8212c7a135b6984cf Binary files /dev/null and b/q_align/model/__pycache__/builder.cpython-311.pyc differ diff --git 
a/q_align/model/__pycache__/builder.cpython-39.pyc b/q_align/model/__pycache__/builder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29c5b7064f5ff98c67a0ce57a3a00f334451d0bb Binary files /dev/null and b/q_align/model/__pycache__/builder.cpython-39.pyc differ diff --git a/q_align/model/__pycache__/configuration_mplug_owl2.cpython-310.pyc b/q_align/model/__pycache__/configuration_mplug_owl2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..635940c1e254c29f31bf99267093652e79be3cde Binary files /dev/null and b/q_align/model/__pycache__/configuration_mplug_owl2.cpython-310.pyc differ diff --git a/q_align/model/__pycache__/configuration_mplug_owl2.cpython-311.pyc b/q_align/model/__pycache__/configuration_mplug_owl2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25719ada8a0c968af25010fbd5e5210e41970fe1 Binary files /dev/null and b/q_align/model/__pycache__/configuration_mplug_owl2.cpython-311.pyc differ diff --git a/q_align/model/__pycache__/configuration_mplug_owl2.cpython-39.pyc b/q_align/model/__pycache__/configuration_mplug_owl2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5900c0e16a8406f291054065b730695ad319ed2 Binary files /dev/null and b/q_align/model/__pycache__/configuration_mplug_owl2.cpython-39.pyc differ diff --git a/q_align/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc b/q_align/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5696d2b89bfcbca12fb90f41d4659948835bc8e5 Binary files /dev/null and b/q_align/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc differ diff --git a/q_align/model/__pycache__/modeling_attn_mask_utils.cpython-311.pyc b/q_align/model/__pycache__/modeling_attn_mask_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f214501d903dd1fce9af34f44eb926aa7f9c151 Binary files /dev/null and b/q_align/model/__pycache__/modeling_attn_mask_utils.cpython-311.pyc differ diff --git a/q_align/model/__pycache__/modeling_attn_mask_utils.cpython-39.pyc b/q_align/model/__pycache__/modeling_attn_mask_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7205da42d472fbe7c21978b1a00eb1c80d48c13e Binary files /dev/null and b/q_align/model/__pycache__/modeling_attn_mask_utils.cpython-39.pyc differ diff --git a/q_align/model/__pycache__/modeling_llama2.cpython-310.pyc b/q_align/model/__pycache__/modeling_llama2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf3b452eb120657e9fdc922c2d9af36c4779a486 Binary files /dev/null and b/q_align/model/__pycache__/modeling_llama2.cpython-310.pyc differ diff --git a/q_align/model/__pycache__/modeling_llama2.cpython-311.pyc b/q_align/model/__pycache__/modeling_llama2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12c41d1db6becf19b2c92e71fcbcc644df70ff40 Binary files /dev/null and b/q_align/model/__pycache__/modeling_llama2.cpython-311.pyc differ diff --git a/q_align/model/__pycache__/modeling_llama2.cpython-39.pyc b/q_align/model/__pycache__/modeling_llama2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf4d5c16c742f3034faa01d7c025693d9145f074 Binary files /dev/null and b/q_align/model/__pycache__/modeling_llama2.cpython-39.pyc differ diff --git a/q_align/model/__pycache__/modeling_mplug_owl2.cpython-310.pyc 
b/q_align/model/__pycache__/modeling_mplug_owl2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57e4fcfc33a50f58b7c7f9728799eed476f64acc Binary files /dev/null and b/q_align/model/__pycache__/modeling_mplug_owl2.cpython-310.pyc differ diff --git a/q_align/model/__pycache__/modeling_mplug_owl2.cpython-311.pyc b/q_align/model/__pycache__/modeling_mplug_owl2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d8a41ee61a9e40ffc92f2390b655242ad5224ff Binary files /dev/null and b/q_align/model/__pycache__/modeling_mplug_owl2.cpython-311.pyc differ diff --git a/q_align/model/__pycache__/modeling_mplug_owl2.cpython-39.pyc b/q_align/model/__pycache__/modeling_mplug_owl2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7f37ea2a8e9a32491babc6fa09766cd462475e4 Binary files /dev/null and b/q_align/model/__pycache__/modeling_mplug_owl2.cpython-39.pyc differ diff --git a/q_align/model/__pycache__/visual_encoder.cpython-310.pyc b/q_align/model/__pycache__/visual_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5afa32d185606d4e67fa234c37fa7d3dd1a95ea9 Binary files /dev/null and b/q_align/model/__pycache__/visual_encoder.cpython-310.pyc differ diff --git a/q_align/model/__pycache__/visual_encoder.cpython-311.pyc b/q_align/model/__pycache__/visual_encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc26e07486aa49227b346f1940189272ff91b0d7 Binary files /dev/null and b/q_align/model/__pycache__/visual_encoder.cpython-311.pyc differ diff --git a/q_align/model/__pycache__/visual_encoder.cpython-39.pyc b/q_align/model/__pycache__/visual_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32fda8db614444b38bd7fb71951950478fd19487 Binary files /dev/null and b/q_align/model/__pycache__/visual_encoder.cpython-39.pyc differ diff --git a/q_align/model/builder.py b/q_align/model/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..5da53f5ff3d39265636f154923173c49186e7fa4 --- /dev/null +++ b/q_align/model/builder.py @@ -0,0 +1,118 @@ +# Copyright 2023 Haotian Liu +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
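# load_pretrained_model below returns (tokenizer, model, image_processor, context_len): it
# instantiates MPLUGOwl2LlamaForCausalLM (optionally in 8-bit/4-bit via bitsandbytes, or by
# loading and merging LoRA weights when a `model_base` is supplied), casts the vision tower to
# fp16 on the target device, and loads a CLIPImageProcessor from the same path.
# A typical call, mirroring the evaluation scripts and scorer in this repo (the arguments
# shown are the values they pass; treat the path and device as adjustable):
#   tokenizer, model, image_processor, context_len = load_pretrained_model(
#       "q-future/one-align", None, "mplug_owl2", device="cuda:0")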
+ + +import os +import warnings +import shutil + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig +from transformers.models.clip.image_processing_clip import CLIPImageProcessor +import torch +from q_align.model import * +from icecream import ic +def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"): + kwargs = {"device_map": device_map} + + if device != "cuda": + kwargs['device_map'] = {"": device} + + if load_8bit: + kwargs['load_in_8bit'] = True + elif load_4bit: + kwargs['load_in_4bit'] = True + kwargs['quantization_config'] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4' + ) + else: + kwargs['torch_dtype'] = torch.float16 + if 'mplug_owl2' in model_name.lower(): + # Load LLaVA model + if 'lora' in model_name.lower() and model_base is None: + warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') + if 'lora' in model_name.lower() and model_base is not None: + lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + print('Loading mPLUG-Owl2 from base model...') + model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) + token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features + if model.lm_head.weight.shape[0] != token_num: + model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) + model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) + + print('Loading additional mPLUG-Owl2 weights...') + if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): + non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') + else: + # this is probably from HF Hub + from huggingface_hub import hf_hub_download + def load_from_hf(repo_id, filename, subfolder=None): + cache_file = hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder) + return torch.load(cache_file, map_location='cpu') + non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') + non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} + if any(k.startswith('model.model.') for k in non_lora_trainables): + non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} + model.load_state_dict(non_lora_trainables, strict=False) + + from peft import PeftModel + print('Loading LoRA weights...') + model = PeftModel.from_pretrained(model, model_path) + print('Merging LoRA weights...') + model = model.merge_and_unload() + print('Model is loaded...') + elif model_base is not None: + # this may be mm projector only + print('Loading mPLUG-Owl2 from base model...') + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + cfg_pretrained = AutoConfig.from_pretrained(model_path) + model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) + else: + tokenizer = 
AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + else: + # Load language model + if model_base is not None: + # PEFT model + from peft import PeftModel + tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) + print(f"Loading LoRA weights from {model_path}") + model = PeftModel.from_pretrained(model, model_path) + print(f"Merging weights") + model = model.merge_and_unload() + print('Convert to FP16...') + model.to(torch.float16) + else: + use_fast = False + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) + + + vision_tower = model.get_model().vision_model + vision_tower.to(device=device, dtype=torch.float16) + image_processor = CLIPImageProcessor.from_pretrained(model_path) + + if hasattr(model.config, "max_sequence_length"): + context_len = model.config.max_sequence_length + else: + context_len = 2048 + + return tokenizer, model, image_processor, context_len \ No newline at end of file diff --git a/q_align/model/configuration_mplug_owl2.py b/q_align/model/configuration_mplug_owl2.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e31a61cc919a694ef62929c705d91143be63d1 --- /dev/null +++ b/q_align/model/configuration_mplug_owl2.py @@ -0,0 +1,332 @@ +# Copyright (c) Alibaba. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import copy +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES +from transformers.utils import logging +from transformers.models.auto import CONFIG_MAPPING + + +class LlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. 
+ + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + +class MplugOwlVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate + a + mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration defaults will yield a similar configuration to that of the mPLUG-Owl + [x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. 
+ intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 32): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + + + ```""" + + model_type = "mplug_owl_vision_model" + + def __init__( + self, + hidden_size=1024, + intermediate_size=4096, + projection_dim=768, + num_hidden_layers=24, + num_attention_heads=16, + num_channels=3, + image_size=448, + patch_size=14, + hidden_act="quick_gelu", + layer_norm_eps=1e-6, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + use_flash_attn=False, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.projection_dim = projection_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.use_flash_attn = use_flash_attn + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from MplugOwlConfig + if config_dict.get("model_type") == "mplug-owl": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
+ ) + + return cls.from_dict(config_dict, **kwargs) + + +class MplugOwlVisualAbstractorConfig(PretrainedConfig): + model_type = "mplug_owl_visual_abstract" + + def __init__( + self, + num_learnable_queries=64, + hidden_size=1024, + num_hidden_layers=6, + num_attention_heads=16, + intermediate_size=2816, + attention_probs_dropout_prob=0., + initializer_range=0.02, + layer_norm_eps=1e-6, + encoder_hidden_size=1024, + grid_size=None, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_learnable_queries = num_learnable_queries + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.encoder_hidden_size = encoder_hidden_size + self.grid_size = grid_size if grid_size else 32 + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the visual_abstractor config dict if we are loading from MplugOwlConfig + if config_dict.get("model_type") == "mplug-owl": + config_dict = config_dict["abstractor_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + + +DEFAULT_VISUAL_CONFIG = { + "visual_model": MplugOwlVisionConfig().to_dict(), + "visual_abstractor": MplugOwlVisualAbstractorConfig().to_dict() +} + +class MPLUGOwl2Config(LlamaConfig): + model_type = "mplug_owl2" + def __init__(self, visual_config=None, **kwargs): + if visual_config is None: + self.visual_config = DEFAULT_VISUAL_CONFIG + else: + self.visual_config = visual_config + + super().__init__( + **kwargs, + ) + +if __name__ == "__main__": + print(MplugOwlVisionConfig().to_dict()) \ No newline at end of file diff --git a/q_align/model/convert_mplug_owl2_weight_to_hf.py b/q_align/model/convert_mplug_owl2_weight_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..8288a9a6a9b5d7a1a4ec58af6a53b14d4b266580 --- /dev/null +++ b/q_align/model/convert_mplug_owl2_weight_to_hf.py @@ -0,0 +1,395 @@ +# Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
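# This script rewrites a Megatron-format mPLUG-Owl2 checkpoint into the HuggingFace layout:
# decoder weights are renamed layer by layer (including the multiway.0 / multiway.1 branches
# of the k/v projections and layernorms) and saved as per-layer pytorch_model-*.bin shards
# together with a weight-map index.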
+import argparse +import gc +import json +import math +import os +import shutil +import warnings + +import torch + +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer +from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig +from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM + +try: + from transformers import LlamaTokenizerFast +except ImportError as e: + warnings.warn(e) + warnings.warn( + "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" + ) + LlamaTokenizerFast = None + +""" +Sample usage: + +``` +python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \ + --input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B +``` + +Thereafter, models can be loaded via: + +```py +from transformers import LlamaForCausalLM, LlamaTokenizer + +model = LlamaForCausalLM.from_pretrained("/output/path") +tokenizer = LlamaTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). +""" + +llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80} +llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64} +llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016, + 70: 28672} # should be (2/3)*4*d, but it isn't exaclty that +llama_s2hidden = {7: 4096, 13: 5120, 32: 6656, 65: 8192, 70: 8192} + + +def compute_intermediate_size(n): + return int(math.ceil(n * 8 / 3) + 255) // 256 * 256 + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model(model_path, + input_base_path, + model_size, + num_input_shards=1, + num_output_shards=2, + skip_permute=True, + norm_eps=1e-05): + # if os.path.exists(model_path): + # shutil.rmtree(model_path) + os.makedirs(model_path, exist_ok=True) + # tmp_model_path = os.path.join(model_path, "tmp") + tmp_model_path = model_path + os.makedirs(tmp_model_path, exist_ok=True) + + num_shards = num_input_shards + n_layers = llama_s2layer[model_size] + n_heads = llama_s2heads[model_size] + n_heads_per_shard = n_heads // num_shards + n_dense = llama_s2dense[model_size] + n_hidden = llama_s2hidden[model_size] + hidden_per_head = n_hidden // n_heads + base = 10000.0 + inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head)) + + # permute for sliced rotary + def permute(w, skip_permute=skip_permute): + if skip_permute: + return w + return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden) + + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + # Load weights + if num_shards==1: + # Not sharded + # (The sharded implementation would also work, but this is simpler.) 
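+ # Resolve which Megatron checkpoint file to load: a `release/` directory, an
+ # `iter_XXXXXXX` directory passed directly as input_base_path, or the iteration
+ # recorded in `latest_checkpointed_iteration.txt`. The iteration-based paths fall
+ # back from model_optim_rng.pt to model_rng.pt when no optimizer state was saved.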
+ # /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt + if os.path.exists(os.path.join(input_base_path, 'release')): + filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt') + elif input_base_path.split('/')[-1].startswith('iter_'): + iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0')) + load_dir = '/'.join(input_base_path.split('/')[:-1]) + filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt') + if not os.path.exists(filename): + filename = filename.replace('model_optim_rng.pt', 'model_rng.pt') + else: + tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt') + with open(tracker_filename, 'r') as f: + metastring = f.read().strip() + iteration = 'iter_{:07d}'.format(int(metastring)) + filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt') + if not os.path.exists(filename): + filename = filename.replace('model_optim_rng.pt', 'model_rng.pt') + original_filename = filename + loaded = torch.load(filename, map_location="cpu")['model']['language_model'] + + else: + # Sharded + filenames = [] + for i in range(num_shards): + if os.path.exists(os.path.join(input_base_path, 'release')): + filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt') + else: + tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt') + with open(tracker_filename, 'r') as f: + metastring = f.read().strip() + iteration = 'iter_{:07d}'.format(int(metastring)) + filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt') + if not os.path.exists(filename): + filename = filename.replace('model_optim_rng.pt', 'model_rng.pt') + filenames.append(filename) + loaded = [ + torch.load(filenames[i], map_location="cpu")['model']['language_model'] + for i in range(num_shards) + ] + + print('Llama-Megatron Loaded!') + param_count = 0 + index_dict = {"weight_map": {}} + + print(f'Weighted Converting for {n_layers} layers...') + for layer_i in range(n_layers): + print(layer_i) + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + if num_shards == 1: + # Unsharded + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"], + f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"], + f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"], + f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"], + f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"], + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"], + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"], + f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": 
loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"], + f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"], + f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"], + f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"], + } + else: + raise NotImplemented +# else: +# # Sharded +# # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share +# # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is +# # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. + +# state_dict = { +# f"model.layers.{layer_i}.input_layernorm.weight": loaded[0]['encoder'][ +# f"layers.{layer_i}.input_layernorm.multiway.0.weight" +# ].clone(), +# f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0]['encoder'][ +# f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight" +# ].clone(), +# } + +# wqs, wks, wvs, ffn_w1s, ffn_w3s = [], [], [], [], [] +# for shard_idx in range(num_shards): +# wqs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"]) +# wks.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"]) +# wvs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"]) + +# state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( +# torch.cat( +# [ +# wq.view(n_heads_per_shard, hidden_per_head, n_hidden) +# for wq in range(wqs) +# ], +# dim=0, +# ).reshape(n_hidden, n_hidden) +# ) +# state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( +# torch.cat( +# [ +# wk.view(n_heads_per_shard, hidden_per_head, n_hidden) +# for wk in range(wks) +# ], +# dim=0, +# ).reshape(n_hidden, n_hidden) +# ) +# state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( +# [ +# wv.view(n_heads_per_shard, hidden_per_head, n_hidden) +# for wv in range(wvs) +# ], +# dim=0, +# ).reshape(n_hidden, n_hidden) + +# state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( +# [loaded[i]['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"] for i in range(num_shards)], dim=1 +# ) +# state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( +# [loaded[i]['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"] for i in range(num_shards)], dim=0 +# ) +# state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( +# [loaded[i]['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"] for i in range(num_shards)], dim=1 +# ) +# state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( +# [loaded[i]['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"] for i in range(num_shards)], dim=0 +# ) + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + print(f'Sharded file saved to {filename}') + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + if num_shards==1: + # Unsharded + state_dict = { + "model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'], + 
"model.norm.weight": loaded['encoder']['norm.weight'], + "lm_head.weight": loaded['encoder']['lm_head.weight'], + } + else: + state_dict = { + "model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'], + "model.norm.weight": loaded[0]['encoder']['norm.weight'], + "lm_head.weight": loaded[0]['encoder']['lm_head.weight'], + } + + + loaded_all = torch.load(original_filename, map_location="cpu")['model'] + # Vision Part + state_dict.update({ + "model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'], + "model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'], + "model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'], + "model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'], + "model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'], + "model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'], + "model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'], + }) + for v_layer_idx in range(24): + state_dict.update({ + f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'], + f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'], + f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'], + f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'], + f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'], + f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'], + f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'], + f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'], + f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'], + f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'], + f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'], + f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'], + }) + + # Abstractor Part + state_dict.update({ + "model.visual_abstractor.query_embeds": loaded_all['vision_abstractor']['learnable_queries'], + "model.visual_abstractor.visual_fc.bias": 
loaded_all['vision_abstractor']['visual_fc']['bias'], + "model.visual_abstractor.visual_fc.weight": loaded_all['vision_abstractor']['visual_fc']['weight'], + "model.visual_abstractor.vit_eos": loaded_all['vision_abstractor']['vit_eos'], + }) + for v_layer_idx in range(6): + state_dict.update({ + # f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.k_pos_embed": + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.weight"], + # f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.q_pos_embed": "pytorch_model-00004-of-00004.bin", + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.weight"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.weight"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.weight"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.weight"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.weight"], + + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.weight"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.weight"], + 
f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.weight"], + + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.weight"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.bias"], + f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.weight"], + }) + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + + config = MPLUGOwl2Config() + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + del loaded_all + gc.collect() + +def write_tokenizer(tokenizer_path, input_tokenizer_path): + # Initialize the tokenizer based on the `spm` model + tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast + print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") + tokenizer = tokenizer_class(input_tokenizer_path) + tokenizer.save_pretrained(tokenizer_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + help="Location of LLaMA_Megatron weights", + ) + parser.add_argument( + "--model_size", + type=int, + default=7, + choices=[7, 13, 30, 65, 70], + ) + parser.add_argument( + "--num_input_shards", + type=int, + default=1, + ) + parser.add_argument( + "--num_output_shards", + type=int, + default=1, + ) + parser.add_argument('--skip_permute', action='store_true') + + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + ) + + args = parser.parse_args() + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + model_size=args.model_size, + num_input_shards=args.num_input_shards, + num_output_shards=args.num_output_shards, + skip_permute=args.skip_permute + ) + + +if __name__ == "__main__": + main() diff --git a/q_align/model/modeling_attn_mask_utils.py b/q_align/model/modeling_attn_mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c2583a2dd5a09b1119c849ca00f954198d078799 --- /dev/null +++ b/q_align/model/modeling_attn_mask_utils.py @@ -0,0 +1,247 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple, Union + +import torch + + +class AttentionMaskConverter: + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + if self.sliding_window is not None and self.sliding_window <= 0: + raise ValueError( + f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" + ) + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: Optional[int] = None, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + if key_value_length is None: + raise ValueError( + "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." 
+ ) + + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask + + return expanded_4d_mask + + @staticmethod + def _make_causal_mask( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + @staticmethod + def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def _prepare_4d_causal_attention_mask( + attention_mask: Optional[torch.Tensor], + input_shape: Union[torch.Size, Tuple, List], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + attention_mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + inputs_embeds (`torch.Tensor`): + The embedded inputs as a torch Tensor. + past_key_values_length (`int`): + The length of the key value cache. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. 
+ """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = input_shape[-1] + past_key_values_length + + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = attn_mask_converter.to_4d( + attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype + ) + else: + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + + return attention_mask + + +def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)` + + Args: + mask (`torch.Tensor` or `None`): + A 2D attention mask of shape `(batch_size, key_value_length)` + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + tgt_len (`int`): + The target length or query length the created mask shall have. + """ + return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _create_4d_causal_attention_mask( + input_shape: Union[torch.Size, Tuple, List], + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` + + Args: + input_shape (`tuple(int)` or `list(int)` or `torch.Size`): + The input shape should be a tuple that defines `(batch_size, query_length)`. + dtype (`torch.dtype`): + The torch dtype the created mask shall have. + device (`int`): + The torch device the created mask shall have. + sliding_window (`int`, *optional*): + If the model uses windowed attention, a sliding window should be passed. 
+ """ + attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) + + key_value_length = past_key_values_length + input_shape[-1] + attention_mask = attn_mask_converter.to_causal_4d( + input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device + ) + + return attention_mask \ No newline at end of file diff --git a/q_align/model/modeling_llama2.py b/q_align/model/modeling_llama2.py new file mode 100644 index 0000000000000000000000000000000000000000..2dbc5738e4b456bfee4035c3cc1bba7576118e91 --- /dev/null +++ b/q_align/model/modeling_llama2.py @@ -0,0 +1,486 @@ +import math +import warnings +from functools import partial +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +import transformers +from transformers.models.llama.modeling_llama import * +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from .configuration_mplug_owl2 import LlamaConfig + +class MultiwayNetwork(nn.Module): + + def __init__(self, module_provider, num_multiway=2): + super(MultiwayNetwork, self).__init__() + + self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)]) + + def forward(self, hidden_states, multiway_indices): + + if len(self.multiway) == 1: + return self.multiway[0](hidden_states) + + output_hidden_states = torch.empty_like(hidden_states) + + for idx, subway in enumerate(self.multiway): + local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True) + hidden = hidden_states[local_indices].unsqueeze(1).contiguous() + if hidden.numel(): + output = subway(hidden) + if isinstance(output, tuple): + output = output[0] + output = output.squeeze(1) + output_hidden_states[local_indices] = output + + return output_hidden_states.contiguous() + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = MultiwayNetwork(module_provider=partial( + nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + ) + self.v_proj = MultiwayNetwork(module_provider=partial( + nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + modality_indicators: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, ) + key_states = self.k_proj(hidden_states, modality_indicators) + value_states = self.v_proj(hidden_states, modality_indicators) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" 
{attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP(config) + self.input_layernorm = MultiwayNetwork(module_provider=partial( + LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps + )) + self.post_attention_layernorm = MultiwayNetwork(module_provider=partial( + LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps + )) + + def forward( + self, + hidden_states: torch.Tensor, + modality_indicators: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states, modality_indicators) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + modality_indicators=modality_indicators, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +def model_forward( + self, + input_ids: torch.LongTensor = None, + modality_indicators: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + modality_indicators, + attention_mask, + position_ids, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + modality_indicators=modality_indicators, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def causal_model_forward( + self, + input_ids: torch.LongTensor = None, + modality_indicators: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + modality_indicators=modality_indicators, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +def replace_llama_modality_adaptive(): + transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig + transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention + transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer + transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward + transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward + + +if __name__ == "__main__": + replace_llama_modality_adaptive() + config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/') + model = transformers.LlamaForCausalLM(config) + print(model) \ No newline at end of file diff --git a/q_align/model/modeling_mplug_owl2.py b/q_align/model/modeling_mplug_owl2.py new file mode 100644 index 0000000000000000000000000000000000000000..e7154a67c056cfa4cdd5033b5ce46c5b62bf3cbb --- /dev/null +++ b/q_align/model/modeling_mplug_owl2.py @@ -0,0 +1,329 @@ +# Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
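The modality-adaptive layers patched in by `modeling_llama2.py` above route every token through one of two parallel branches, selected by a per-token modality indicator (0 for text tokens, 1 for visual tokens); `modeling_mplug_owl2.py` below is the part that produces those indicators when it splices image features into the input sequence. A minimal sketch of the routing, assuming the `q_align` package and its dependencies are installed (the tensor sizes here are illustrative only):

```python
from functools import partial

import torch
from torch import nn

from q_align.model.modeling_llama2 import MultiwayNetwork

# Two parallel Linear "experts", the same pattern used for k_proj/v_proj and the layer norms.
layer = MultiwayNetwork(module_provider=partial(nn.Linear, 8, 8))

hidden = torch.randn(1, 5, 8)                 # (batch, seq_len, hidden_size)
indicators = torch.tensor([[0, 0, 1, 1, 0]])  # 0 = text token, 1 = visual token

out = layer(hidden, indicators)               # each token is processed by its own branch
print(out.shape)                              # torch.Size([1, 5, 8])
```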
+ +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM +from transformers.modeling_outputs import CausalLMOutputWithPast + +from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig +from .visual_encoder import MplugOwlVisionModel, MplugOwlVisualAbstractorModel +from .modeling_llama2 import replace_llama_modality_adaptive +from q_align.constants import IMAGE_TOKEN_INDEX, IGNORE_INDEX +from icecream import ic + +class MPLUGOwl2MetaModel: + def __init__(self, config): + super(MPLUGOwl2MetaModel, self).__init__(config) + self.vision_model = MplugOwlVisionModel( + MplugOwlVisionConfig(**config.visual_config["visual_model"]) + ) + self.visual_abstractor = MplugOwlVisualAbstractorModel( + MplugOwlVisualAbstractorConfig(**config.visual_config["visual_abstractor"]), config.hidden_size + ) + + def get_vision_tower(self): + vision_model = getattr(self, 'vision_model', None) + if type(vision_model) is list: + vision_model = vision_model[0] + return vision_model + + def get_visual_abstractor(self): + visual_abstractor = getattr(self, 'visual_abstractor', None) + if type(visual_abstractor) is list: + visual_abstractor = visual_abstractor[0] + return visual_abstractor + + +class MPLUGOwl2MetaForCausalLM(ABC): + @abstractmethod + def get_model(self): + pass + + def encode_images(self, images): + image_features = self.get_model().vision_model(images).last_hidden_state + image_features = self.get_model().visual_abstractor(encoder_hidden_states=image_features).last_hidden_state + return image_features + + def prepare_inputs_labels_for_multimodal( + self, input_ids, attention_mask, past_key_values, labels, images + ): + if images is None or input_ids.shape[1] == 1: + if past_key_values is not None and images is not None and input_ids.shape[1] == 1: + attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) + multiway_indices = torch.zeros_like(input_ids).long().to(self.device) + return input_ids, multiway_indices, attention_mask, past_key_values, None, labels + + if type(images) is list or images.ndim == 5: + concat_images = torch.cat([image for image in images], dim=0) + image_features = self.encode_images(concat_images) + split_sizes = [image.shape[0] for image in images] + image_features = torch.split(image_features, split_sizes, dim=0) + image_features = [x.flatten(0, 1) for x in image_features] + else: + image_features = self.encode_images(images) + + new_input_embeds = [] + new_modality_indicators = [] + new_labels = [] if labels is not None else None + cur_image_idx = 0 + for batch_idx, cur_input_ids in enumerate(input_ids): + if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + # FIXME: this is a hacky fix, for deepspeed zero3 to work + half_len = cur_input_ids.shape[0] // 2 + cur_image_features = image_features[cur_image_idx] + cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len]) + cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:]) + cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0) + new_input_embeds.append(cur_input_embeds) + + cur_modality_indicators = 
torch.zeros(len(cur_input_embeds)).long().to(self.device) + new_modality_indicators.append(cur_modality_indicators) + if labels is not None: + new_labels.append(labels[batch_idx]) + cur_image_idx += 1 + continue + image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] + cur_new_input_embeds = [] + cur_modality_indicators = [] + if labels is not None: + cur_labels = labels[batch_idx] + cur_new_labels = [] + assert cur_labels.shape == cur_input_ids.shape + while image_token_indices.numel() > 0: + cur_image_features = image_features[cur_image_idx] + image_token_start = image_token_indices[0] + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start])) + cur_new_input_embeds.append(cur_image_features) + + # Add modality indicator + assert image_token_start == len(cur_input_ids[:image_token_start]) + cur_modality_indicators.append(torch.zeros(len(cur_input_ids[:image_token_start])).long()) + cur_modality_indicators.append(torch.ones(len(cur_image_features)).long()) + + if labels is not None: + cur_new_labels.append(cur_labels[:image_token_start]) + cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) + cur_labels = cur_labels[image_token_start+1:] + cur_image_idx += 1 + cur_input_ids = cur_input_ids[image_token_start+1:] + image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] + if cur_input_ids.numel() > 0: + cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) + cur_modality_indicators.append(torch.zeros(len(cur_input_ids)).long()) + if labels is not None: + cur_new_labels.append(cur_labels) + cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] + cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) + new_input_embeds.append(cur_new_input_embeds) + + # Modality + cur_modality_indicators = [x.to(device=self.device) for x in cur_modality_indicators] + cur_modality_indicators = torch.cat(cur_modality_indicators, dim=0) + new_modality_indicators.append(cur_modality_indicators) + + + if labels is not None: + cur_new_labels = torch.cat(cur_new_labels, dim=0) + new_labels.append(cur_new_labels) + + if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): + max_len = max(x.shape[0] for x in new_input_embeds) + + # Embedding + new_input_embeds_align = [] + for cur_new_embed in new_input_embeds: + cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) + new_input_embeds_align.append(cur_new_embed) + new_input_embeds = torch.stack(new_input_embeds_align, dim=0) + + # Modality + new_modality_indicators_align = [] + for cur_modality_indicator in new_modality_indicators: + cur_new_embed = torch.cat((cur_modality_indicator, torch.zeros(max_len - cur_modality_indicator.shape[0], dtype=cur_modality_indicator.dtype, device=cur_modality_indicator.device)), dim=0) + new_modality_indicators_align.append(cur_new_embed) + new_modality_indicators = torch.stack(new_modality_indicators_align, dim=0) + + # Label + if labels is not None: + new_labels_align = [] + _new_labels = new_labels + for cur_new_label in new_labels: + cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) + new_labels_align.append(cur_new_label) + new_labels = torch.stack(new_labels_align, dim=0) + + # 
Attention Mask + if attention_mask is not None: + new_attention_mask = [] + for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): + new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) + new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) + cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) + new_attention_mask.append(cur_new_attention_mask) + attention_mask = torch.stack(new_attention_mask, dim=0) + assert attention_mask.shape == new_labels.shape + else: + new_input_embeds = torch.stack(new_input_embeds, dim=0) + new_modality_indicators = torch.stack(new_modality_indicators, dim=0) + if labels is not None: + new_labels = torch.stack(new_labels, dim=0) + + if attention_mask is not None: + new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) + attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) + assert attention_mask.shape == new_input_embeds.shape[:2] + return None, new_modality_indicators, attention_mask, past_key_values, new_input_embeds, new_labels + + + +class MPLUGOwl2LlamaModel(MPLUGOwl2MetaModel, LlamaModel): + config_class = MPLUGOwl2Config + + def __init__(self, config: MPLUGOwl2Config): + super(MPLUGOwl2LlamaModel, self).__init__(config) + + +class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM): + config_class = MPLUGOwl2Config + + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = MPLUGOwl2LlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_model(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + # modality_indicators: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + input_ids, modality_indicators, attention_mask, past_key_values, inputs_embeds, labels = \ + self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + modality_indicators=modality_indicators, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + 
return_dict=return_dict + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model/pipeline parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "images": kwargs.get("images", None), + } + ) + return model_inputs + +AutoConfig.register("mplug_owl2", MPLUGOwl2Config) +AutoModelForCausalLM.register(MPLUGOwl2Config, MPLUGOwl2LlamaForCausalLM) + +replace_llama_modality_adaptive() + +if __name__ == "__main__": + config = MPLUGOwl2Config.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/') + from icecream import ic + # config = MPLUGOwl2Config() + model = MPLUGOwl2LlamaForCausalLM(config) + + images = torch.randn(2, 3, 448, 448) + input_ids = torch.cat([ + torch.ones(8).long(), torch.tensor([-1]*1).long(), torch.ones(8).long(), torch.tensor([-1]*1).long(), torch.ones(8).long() + ], dim=0).unsqueeze(0) + labels = input_ids.clone() + labels[labels < 0] = -100 + + # image_feature = model.encode_images(images) + # ic(image_feature.shape) + + output = model(images=images, input_ids=input_ids, labels=labels) + ic(output.loss) + ic(output.logits.shape) + + model.save_pretrained('/cpfs01/shared/public/test/tmp_owl') \ No newline at end of file diff --git a/q_align/model/utils.py b/q_align/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a47edb191f0f6587827cfab6a731ec106a523b0 --- /dev/null +++ b/q_align/model/utils.py @@ -0,0 +1,20 @@ +from transformers import AutoConfig + + +def auto_upgrade(config): + cfg = AutoConfig.from_pretrained(config) + if 'mplug_owl2' in config and 'mplug_owl2' not in cfg.model_type: + assert cfg.model_type == 'mplug_owl2' + print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") + print("You must upgrade the checkpoint to the new code base (this can be done automatically).") + confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") + if confirm.lower() in ["y", "yes"]: + print("Upgrading checkpoint...") + assert len(cfg.architectures) == 1 + setattr(cfg.__class__, "model_type", "mplug_owl2") + cfg.architectures[0] = 'LlavaLlamaForCausalLM' + cfg.save_pretrained(config) + print("Checkpoint upgraded.") + else: + print("Checkpoint upgrade aborted.") + exit(1) \ No newline at end of file diff --git a/q_align/model/visual_encoder.py b/q_align/model/visual_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..39f76207dd9a25f870c1cf35f7fbcba2bf342e0c --- /dev/null +++ b/q_align/model/visual_encoder.py @@ -0,0 +1,928 @@ +import math +from typing import Any, Optional, Tuple, Union + +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPastAndCrossAttentions +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint +from icecream import ic + +def get_abs_pos(abs_pos, tgt_size): + # abs_pos: L, C + # tgt_size: M + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + return F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size, tgt_size), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + else: + return abs_pos + +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + + +class MplugOwlVisionEmbeddings(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size)) + + self.patch_embed = nn.Conv2d( + in_channels=3, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size)) + + self.pre_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.size(0) + image_embeds = self.patch_embed(pixel_values) + image_embeds = image_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.cls_token.expand(batch_size, 1, -1).to(image_embeds.dtype) + embeddings = torch.cat([class_embeds, image_embeds], dim=1) + embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype) + embeddings = self.pre_layernorm(embeddings) + return embeddings + + + +class MplugOwlVisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + if self.head_dim * self.num_heads != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = nn.Dropout(config.attention_dropout) + + self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size) + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, seq_len, embed_dim = hidden_states.size() + + mixed_qkv = self.query_key_value(hidden_states) + + mixed_qkv = mixed_qkv.reshape(bsz, seq_len, self.num_heads, 3, embed_dim // self.num_heads).permute( + 3, 0, 2, 1, 4 + ) # [3, b, np, sq, hn] + query_states, key_states, value_states = ( + mixed_qkv[0], + mixed_qkv[1], + mixed_qkv[2], + ) + # if self.config.use_flash_attn and flash_attn_func is not None: + if False: + # [b*sq, np, hn] + query_states = query_states.permute(0, 2, 1, 3).contiguous() + query_states = query_states.view(query_states.size(0) * query_states.size(1), query_states.size(2), -1) + + key_states = key_states.permute(0, 2, 1, 3).contiguous() + key_states = key_states.view(key_states.size(0) * key_states.size(1), key_states.size(2), -1) + + value_states = value_states.permute(0, 2, 1, 3).contiguous() + value_states = value_states.view(value_states.size(0) * value_states.size(1), value_states.size(2), -1) + + cu_seqlens = torch.arange( + 0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=query_states.device + ) + + context_layer = flash_attn_func( + query_states, + key_states, + value_states, + cu_seqlens, + cu_seqlens, + seq_len, + seq_len, + self.dropout if self.training else 0.0, + softmax_scale=self.scale, + causal=False, + return_attn_probs=False, + ) + # [b*sq, np, hn] => [b, sq, np, hn] + context_layer = context_layer.view(bsz, seq_len, context_layer.size(1), context_layer.size(2)) + else: + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + + attention_scores = attention_scores * self.scale + + # Normalize the attention scores to probabilities. + attention_probs = torch.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
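+        # Dropout over the attention probabilities, followed by an optional per-head mask,
+        # a weighted sum over the value vectors, and a merge of all heads before the final
+        # output projection through self.dense.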
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) + + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + output = self.dense(context_layer) + + outputs = (output, attention_probs) if output_attentions else (output, None) + + return outputs + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class MplugOwlMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = QuickGELU() + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MplugOwlVisionEncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = MplugOwlVisionAttention(config) + self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + self.mlp = MplugOwlMLP(config) + self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + head_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + residual + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class MplugOwlVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`MplugOwlVisionEncoderLayer`]. + + Args: + config (`MplugOwlVisionConfig`): + The corresponding vision configuration for the `MplugOwlEncoder`. 
+ """ + + def __init__(self, config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([MplugOwlVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = True + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), + hidden_states, + attention_mask, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class MplugOwlVisionModel(PreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config): + super().__init__(config) + self.config = config + self.hidden_size = config.hidden_size + + self.embeddings = MplugOwlVisionEmbeddings(config) + self.encoder = MplugOwlVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps) + + self.post_init() + + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = 
None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings + + +class MplugOwlVisualAbstractorMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + in_features = config.hidden_size + self.act = nn.SiLU() + + self.w1 = nn.Linear(in_features, config.intermediate_size) + self.w2 = nn.Linear(config.intermediate_size, in_features) + self.w3 = nn.Linear(in_features, config.intermediate_size) + self.ffn_ln = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.act(self.w1(hidden_states)) * self.w3(hidden_states) + hidden_states = self.ffn_ln(hidden_states) + hidden_states = self.w2(hidden_states) + return hidden_states + + +class MplugOwlVisualAbstractorMultiHeadAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention heads (%d)" + % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size) + self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.save_attention = False + +# self.q_pos_embed = nn.Parameter( +# torch.from_numpy(get_1d_sincos_pos_embed_from_grid(config.hidden_size, np.arange(config.num_learnable_queries, dtype=np.float32))).float() +# ).requires_grad_(False) +# grids = config.grid_size +# self.k_pos_embed = nn.Parameter( +# torch.from_numpy(get_2d_sincos_pos_embed(config.hidden_size, grids, cls_token=True)).float() +# ).requires_grad_(False) + grids = config.grid_size + self.register_buffer( + 'q_pos_embed', + torch.from_numpy(get_1d_sincos_pos_embed_from_grid(config.hidden_size, np.arange(config.num_learnable_queries, dtype=np.float32))).float() + ) + self.register_buffer( + 'k_pos_embed', + 
torch.from_numpy(get_2d_sincos_pos_embed(config.hidden_size, grids, cls_token=True)).float() + ) + + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + + qk_pos_embed = torch.cat([self.q_pos_embed, self.k_pos_embed], dim = 0).unsqueeze(0).to(dtype=hidden_states.dtype) + + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states + qk_pos_embed)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + + mixed_query_layer = self.query(hidden_states + self.q_pos_embed.unsqueeze(0).to(dtype=hidden_states.dtype)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
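+        # The dropped-out probabilities aggregate the value projections; heads are then merged
+        # back to all_head_size, and the freshly computed (key, value) pair is returned alongside
+        # the context as past_key_value.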
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class MplugOwlVisualAbstractorCrossOutput(nn.Module): + def __init__(self, config): + super().__init__() + dim = config.hidden_size + self.out_proj = nn.Linear(dim, dim, bias=True) + self.norm2 = nn.LayerNorm(dim) + self.mlp = MplugOwlVisualAbstractorMLP(config) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + input_tensor = input_tensor + self.out_proj(hidden_states) + input_tensor = input_tensor + self.mlp(self.norm2(input_tensor)) + return input_tensor + + +class MplugOwlVisualAbstractorAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = MplugOwlVisualAbstractorMultiHeadAttention(config) + self.output = MplugOwlVisualAbstractorCrossOutput(config) + self.pruned_heads = set() + self.norm1 = nn.LayerNorm(config.hidden_size) + self.normk = nn.LayerNorm(config.hidden_size) + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.out_proj, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # HACK we apply norm on q and k + hidden_states = self.norm1(hidden_states) + encoder_hidden_states = self.normk(encoder_hidden_states) + encoder_hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + encoder_attention_mask = torch.cat([attention_mask, encoder_attention_mask], dim=-1) + self_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] + return outputs + + +class MplugOwlVisualAbstractorLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.chunk_size_feed_forward = 
config.chunk_size_feed_forward + self.seq_len_dim = 1 + + self.layer_idx = layer_idx + + self.crossattention = MplugOwlVisualAbstractorAttention(config) + self.has_cross_attention = True + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=False, + ): + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be given for cross-attention layers") + cross_attention_outputs = self.crossattention( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + + outputs = (query_attention_output,) + return outputs + + +class MplugOwlVisualAbstractorEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [MplugOwlVisualAbstractorLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = True + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layers[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + return BaseModelOutput( + last_hidden_state=hidden_states, + ) + + +class MplugOwlVisualAbstractorModel(PreTrainedModel): + def __init__(self, config, language_hidden_size): + super().__init__(config) + self.config = config + + self.encoder = MplugOwlVisualAbstractorEncoder(config) + self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size) + self.query_embeds = torch.nn.Parameter(torch.randn(1, config.num_learnable_queries, config.hidden_size)) + self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size)) + + self.post_init() + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int], + device: torch.device, + ) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. 
+ + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + device: (`torch.device`): + The device of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of: + shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and + value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are + used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape + `(batch_size, sequence_length)`. 
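+        Returns:
+            `BaseModelOutputWithPooling` whose `last_hidden_state` holds the learnable-query features
+            projected to the language model's hidden size, with a learned EOS embedding appended.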
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + query_embeds = self.query_embeds.repeat(encoder_hidden_states.shape[0], 1, 1) + embedding_output = query_embeds + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask is None: + attention_mask = torch.ones( + (query_embeds.shape[0], query_embeds.shape[1]), dtype=torch.long, device=query_embeds.device + ) + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = sequence_output[:, 0, :] + + sequence_output = self.visual_fc(sequence_output) + sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(sequence_output.shape[0], 1, 1)], dim=1) + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +if __name__ == "__main__": + from configuration_mplug_owl2 import MPLUGOwl2Config + config = MPLUGOwl2Config() + visual_model = MplugOwlVisionModel(config.visual_config["visual_model"]) + print(visual_model) + + abstractor_module = MplugOwlVisualAbstractorModel(config.visual_config["visual_abstractor"], config.hidden_size) + print(abstractor_module) \ No newline at end of file diff 
--git a/q_align/train/.ipynb_checkpoints/train-checkpoint.py b/q_align/train/.ipynb_checkpoints/train-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..d109848b1664625fdf099c50e210e1a93a909002 --- /dev/null +++ b/q_align/train/.ipynb_checkpoints/train-checkpoint.py @@ -0,0 +1,844 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +import torch + +import transformers +from transformers.models.clip.image_processing_clip import CLIPImageProcessor + +from torch.utils.data import Dataset +from q_align.train.mplug_owl2_trainer import MPLUGOwl2Trainer +from q_align.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN + +from q_align import conversation as conversation_lib +from q_align.model import * +from q_align.mm_utils import tokenizer_image_token + +from PIL import Image +from icecream import ic + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + image_grid_pinpoints: Optional[str] = field(default=None) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + + tune_visual_abstractor: bool = field(default=True) + freeze_vision_model: bool = field(default=True) + + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. 
Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + visual_abstractor_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ['vision_model', 'visual_abstractor'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + lora_module_names.add(name) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, + output_dir: str): + """Collects the state dict and dump to disk.""" + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = { + key: value.cpu() + for key, value in state_dict.items() + } + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. 
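+    Newly added rows of the input and output embedding matrices are initialized to the mean of the
+    existing embeddings.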
+ + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal( + sources: Sequence[str], + data_args: DataArguments +) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence['value']: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() + sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + + replace_token = DEFAULT_IMAGE_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources + + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" 
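+            # Turns must alternate between the human and assistant roles; each turn is replayed into
+            # the conversation template, and the rendered prompt is collected via conv.get_prompt().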
+ conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_IMAGE_TOKEN in source[0]['value'] + source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
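+    Plain-separator conversations are delegated to `preprocess_plain`, and v1-style conversations to
+    `preprocess_v1`; the fallback below handles the legacy "### "-signal format.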
+ """ + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + + +def load_video(video_file): + from decord import VideoReader + vr = VideoReader(video_file) + + # Get video frame rate + fps = vr.get_avg_fps() + + # Calculate frame indices for 1fps + frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))] + frames = vr.get_batch(frame_indices).asnumpy() + return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))] + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments): + super(LazySupervisedDataset, self).__init__() + list_data_dict = json.load(open(data_path, "r")) + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if 'image' in sample else 0 + length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) + return length_list + + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) + cur_len = cur_len if 'image' in sample else -cur_len + length_list.append(cur_len) + return length_list + +# def __getitem__(self, i) -> Dict[str, torch.Tensor]: +# sources = self.list_data_dict[i] +# if isinstance(i, int): +# sources = [sources] +# assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME +# if 'image' in sources[0]: +# image_file = self.list_data_dict[i]['image'] +# image_folder = 
self.data_args.image_folder +# processor = self.data_args.image_processor +# image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') +# if self.data_args.image_aspect_ratio == 'pad': +# def expand2square(pil_img, background_color): +# width, height = pil_img.size +# if width == height: +# return pil_img +# elif width > height: +# result = Image.new(pil_img.mode, (width, width), background_color) +# result.paste(pil_img, (0, (width - height) // 2)) +# return result +# else: +# result = Image.new(pil_img.mode, (height, height), background_color) +# result.paste(pil_img, ((height - width) // 2, 0)) +# return result +# image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) +# image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] +# else: +# image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] +# sources = preprocess_multimodal( +# copy.deepcopy([e["conversations"] for e in sources]), +# self.data_args) +# else: +# sources = copy.deepcopy([e["conversations"] for e in sources]) +# data_dict = preprocess( +# sources, +# self.tokenizer, +# has_image=('image' in self.list_data_dict[i])) +# if isinstance(i, int): +# data_dict = dict(input_ids=data_dict["input_ids"][0], +# labels=data_dict["labels"][0]) + +# # image exist in the data +# if 'image' in self.list_data_dict[i]: +# data_dict['image'] = image +# elif self.data_args.is_multimodal: +# # image does not exist in the data, but the model is multimodal +# crop_size = self.data_args.image_processor.crop_size +# data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) +# return data_dict + + def next_rand(self): + import random + return random.randint(0,len(self)-1) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + while True: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + if 'image' in sources[0]: + image_file = self.list_data_dict[i]['image'] + + image_folder = self.data_args.image_folder + processor = self.data_args.image_processor + from pathlib import Path + #if not Path(os.path.join(image_folder, image_file)).exists(): + # i = self.next_rand() + # continue + if isinstance(image_file, list): + # Multiple Images as Input + try: + image = [Image.open(os.path.join(image_folder, imfile)).convert('RGB') for imfile in image_file] + except Exception as ex: + print(ex) + i = self.next_rand() + continue + if self.data_args.image_aspect_ratio == 'pad': + image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image] + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] + elif os.path.join(image_folder, image_file).endswith("mp4"): + # Video as Input + image = load_video(os.path.join(image_folder, image_file)) + if self.data_args.image_aspect_ratio == 'pad': + image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image] + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] + else: + try: + image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + except Exception as ex: + print(ex) + i = self.next_rand() + continue + if self.data_args.image_aspect_ratio == 'pad': + image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + 
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + sources = preprocess_multimodal( + copy.deepcopy([e["conversations"] for e in sources]), + self.data_args) + else: + + sources = copy.deepcopy([e["conversations"] for e in sources]) + data_dict = preprocess( + sources, + self.tokenizer, + has_image=('image' in self.list_data_dict[i])) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + if 'image' in self.list_data_dict[i]: + data_dict['image'] = image + elif self.data_args.is_multimodal: + # image does not exist in the data, but the model is multimodal + crop_size = self.data_args.image_processor.crop_size + data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + input_ids = input_ids[:, :self.tokenizer.model_max_length] + labels = labels[:, :self.tokenizer.model_max_length] + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + + bnb_model_from_pretrained_args = {} + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + bnb_model_from_pretrained_args.update(dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + ) + )) + + model = MPLUGOwl2LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + 
cache_dir=training_args.cache_dir, + **bnb_model_from_pretrained_args + ) + model.config.use_cache = False + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + from peft import prepare_model_for_kbit_training + model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + + tokenizer.pad_token = tokenizer.unk_token + if model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] + + if not training_args.freeze_vision_model and training_args.bits in [4, 8]: + model.get_model().vision_model.to(dtype=compute_dtype, device=training_args.device) + else: + vision_tower = model.get_model().vision_model + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + if training_args.tune_visual_abstractor and training_args.bits in [4, 8]: + model.get_model().visual_abstractor.to(dtype=compute_dtype, device=training_args.device) + else: + visual_abstractor = model.get_model().visual_abstractor + visual_abstractor.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + data_args.image_processor = CLIPImageProcessor.from_pretrained(model_args.model_name_or_path) + data_args.is_multimodal = True + + model.config.image_aspect_ratio = data_args.image_aspect_ratio + model.config.image_grid_pinpoints = data_args.image_grid_pinpoints + model.config.tune_visual_abstractor = model_args.tune_visual_abstractor = training_args.tune_visual_abstractor + ic(training_args.tune_visual_abstractor) + model.requires_grad_(True) + if training_args.tune_visual_abstractor: + # model.requires_grad_(False) + for p in model.get_model().visual_abstractor.parameters(): + p.requires_grad = True + + model.config.freeze_vision_model = training_args.freeze_vision_model + ic(training_args.freeze_vision_model) + if training_args.freeze_vision_model: + for p in model.get_model().vision_model.parameters(): + p.requires_grad = False + + model.config.visual_abstractor_lr = training_args.visual_abstractor_lr + + + if training_args.bits in [4, 8]: + from peft.tuners.lora import 
LoraLayer + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if 'norm' in name: + module = module.to(torch.float32) + if 'lm_head' in name or 'embed_tokens' in name: + if hasattr(module, 'weight'): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_args) + trainer = MPLUGOwl2Trainer(model=model, + tokenizer=tokenizer, + args=training_args, + **data_module) + + # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + # trainer.train(resume_from_checkpoint=True) + # else: + # trainer.train() + + # TODO I dont like auto resume << REMOVE IT AND UNCOMMENT THE ABOVE CODE + trainer.train() + + trainer.save_state() + + model.config.use_cache = True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), training_args.lora_bias + ) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( + model.named_parameters() + ) + if training_args.local_rank == 0 or training_args.local_rank == -1: + model.config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) + else: + safe_save_model_for_hf_trainer(trainer=trainer, + output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() \ No newline at end of file diff --git a/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-310.pyc b/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..526270d58a01315bbdac70f4465e54e94d0335c9 Binary files /dev/null and b/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-310.pyc differ diff --git a/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-311.pyc b/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d2b77240124a91ac5f378a6b3997814f71b03ae Binary files /dev/null and b/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-311.pyc differ diff --git a/q_align/train/__pycache__/mplug_owl2_trainer.cpython-310.pyc b/q_align/train/__pycache__/mplug_owl2_trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb851af86edcd4279750449e470b5ed757f54d38 Binary files /dev/null and b/q_align/train/__pycache__/mplug_owl2_trainer.cpython-310.pyc differ diff --git a/q_align/train/__pycache__/mplug_owl2_trainer.cpython-311.pyc b/q_align/train/__pycache__/mplug_owl2_trainer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9d7423b8c38926885b9376550027174522827c0 Binary files /dev/null and b/q_align/train/__pycache__/mplug_owl2_trainer.cpython-311.pyc differ diff --git a/q_align/train/__pycache__/train.cpython-310.pyc b/q_align/train/__pycache__/train.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cc4f452551034733cf2096054fa152b0f490de0 Binary files /dev/null and b/q_align/train/__pycache__/train.cpython-310.pyc differ diff --git a/q_align/train/__pycache__/train.cpython-311.pyc b/q_align/train/__pycache__/train.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5aa2ed03477cf21175212b9c3dbdf46ee438b267 Binary files /dev/null and b/q_align/train/__pycache__/train.cpython-311.pyc differ diff --git a/q_align/train/llama_flash_attn_monkey_patch.py b/q_align/train/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..23bfdd8920b25581956356ca881f83070e2ae111 --- /dev/null +++ b/q_align/train/llama_flash_attn_monkey_patch.py @@ -0,0 +1,117 @@ +from typing import Optional, Tuple +import warnings + +import torch + +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +except ImportError: + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +def forward( + self, + hidden_states: torch.Tensor, + modality_indicators: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: bool = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn( + "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states, modality_indicators) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states, modality_indicators) + .view(bsz, q_len, self.num_key_value_heads, self.head_dim) + .transpose(1, 2) + ) # shape: (b, num_heads, s, head_dim) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + if past_key_value is not None: + # reuse k, v + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) + qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] + key_padding_mask = attention_mask + + if key_padding_mask is None: + qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) + cu_q_lens = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device + ) + max_s = q_len + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = output.view(bsz, q_len, -1) + else: + qkv = qkv.reshape(bsz, q_len, -1) + qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + output_unpad = flash_attn_unpadded_qkvpacked_func( + 
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn( + "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." + "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" + ) + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask + ) + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward \ No newline at end of file diff --git a/q_align/train/mplug_owl2_trainer.py b/q_align/train/mplug_owl2_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..293dcdf21c82c715e663af11b90a19543244af18 --- /dev/null +++ b/q_align/train/mplug_owl2_trainer.py @@ -0,0 +1,243 @@ +import os +import torch + +from torch.utils.data import Sampler + +from transformers import Trainer +from transformers.trainer import ( + is_sagemaker_mp_enabled, + get_parameter_names, + has_length, + ALL_LAYERNORM_LAYERS, + ShardedDDPOption, + logger, +) +from typing import List, Optional +from icecream import ic + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + print(name, 'no ignore status') + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} + return to_return + + +def split_to_even_chunks(indices, lengths, num_chunks): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + + if len(indices) % num_chunks != 0: + return [indices[i::num_chunks] for i in range(num_chunks)] + + num_indices_per_chunk = len(indices) // num_chunks + + chunks = [[] for _ in range(num_chunks)] + chunks_lengths = [0 for _ in range(num_chunks)] + for index in indices: + shortest_chunk = chunks_lengths.index(min(chunks_lengths)) + chunks[shortest_chunk].append(index) + chunks_lengths[shortest_chunk] += lengths[index] + if len(chunks[shortest_chunk]) == num_indices_per_chunk: + chunks_lengths[shortest_chunk] = float("inf") + + return chunks + + +def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + assert all(l != 0 for l in lengths), "Should not have zero length." 
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): + # all samples are in the same modality + return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) + mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) + lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) + + mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] + lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] + megabatch_size = world_size * batch_size + mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] + lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] + + last_mm = mm_megabatches[-1] + last_lang = lang_megabatches[-1] + additional_batch = last_mm + last_lang + megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] + megabatch_indices = torch.randperm(len(megabatches), generator=generator) + megabatches = [megabatches[i] for i in megabatch_indices] + + if len(additional_batch) > 0: + megabatches.append(sorted(additional_batch)) + + return [i for megabatch in megabatches for i in megabatch] + + +def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. + indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = world_size * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. 
+ """ + + def __init__( + self, + batch_size: int, + world_size: int, + lengths: Optional[List[int]] = None, + generator=None, + group_by_modality: bool = False, + ): + if lengths is None: + raise ValueError("Lengths must be provided.") + + self.batch_size = batch_size + self.world_size = world_size + self.lengths = lengths + self.generator = generator + self.group_by_modality = group_by_modality + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + if self.group_by_modality: + indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + else: + indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + return iter(indices) + + +class MPLUGOwl2Trainer(Trainer): + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + if self.args.group_by_modality_length: + lengths = self.train_dataset.modality_lengths + return LengthGroupedSampler( + self.args.train_batch_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, + lengths=lengths, + group_by_modality=True, + ) + else: + return super()._get_train_sampler() + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + if is_sagemaker_mp_enabled(): + return super().create_optimizer() + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + return super().create_optimizer() + + opt_model = self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + if self.args.visual_abstractor_lr is not None: + projector_parameters = [name for name, _ in opt_model.named_parameters() if "visual_abstractor_lr" in name] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + "lr": self.args.visual_abstractor_lr, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + "lr": self.args.visual_abstractor_lr, + }, + ] + else: + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + ic(len(optimizer_grouped_parameters[0]['params']),len(optimizer_grouped_parameters[1]['params'])) + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + 
self.optimizer = OSS( + params=optimizer_grouped_parameters, + optim=optimizer_cls, + **optimizer_kwargs, + ) + else: + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + return self.optimizer + + def _save_checkpoint(self, model, trial, metrics=None): + super(MPLUGOwl2Trainer, self)._save_checkpoint(model, trial, metrics) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + super(MPLUGOwl2Trainer, self)._save(output_dir, state_dict) \ No newline at end of file diff --git a/q_align/train/train.py b/q_align/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d109848b1664625fdf099c50e210e1a93a909002 --- /dev/null +++ b/q_align/train/train.py @@ -0,0 +1,844 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
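The create_optimizer override in q_align/train/mplug_owl2_trainer.py above builds separate optimizer parameter groups so that visual-abstractor parameters can be trained with their own learning rate (args.visual_abstractor_lr) while the remaining parameters keep the default settings. Below is a minimal, self-contained sketch of that per-group learning-rate pattern; the stand-in model, the name filter, and the numbers are hypothetical and only illustrate the torch API, not the project's actual configuration.

import torch
from torch import nn

# Stand-in model: pretend the second Linear plays the role of the visual abstractor.
model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 8))
abstractor_keyword = "1."                      # hypothetical name filter, only valid for this toy model
base_lr, abstractor_lr, weight_decay = 2e-5, 1e-4, 0.0

param_groups = [
    {   # everything except the "abstractor": default learning rate
        "params": [p for n, p in model.named_parameters() if abstractor_keyword not in n],
        "lr": base_lr, "weight_decay": weight_decay,
    },
    {   # "abstractor" parameters: their own learning rate
        "params": [p for n, p in model.named_parameters() if abstractor_keyword in n],
        "lr": abstractor_lr, "weight_decay": weight_decay,
    },
]
optimizer = torch.optim.AdamW(param_groups)
print([group["lr"] for group in optimizer.param_groups])   # [2e-05, 0.0001]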
+ +import os +import copy +from dataclasses import dataclass, field +import json +import logging +import pathlib +from typing import Dict, Optional, Sequence, List + +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +import torch + +import transformers +from transformers.models.clip.image_processing_clip import CLIPImageProcessor + +from torch.utils.data import Dataset +from q_align.train.mplug_owl2_trainer import MPLUGOwl2Trainer +from q_align.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN + +from q_align import conversation as conversation_lib +from q_align.model import * +from q_align.mm_utils import tokenizer_image_token + +from PIL import Image +from icecream import ic + +local_rank = None + + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="facebook/opt-125m") + version: Optional[str] = field(default="v0") + freeze_backbone: bool = field(default=False) + +@dataclass +class DataArguments: + data_path: str = field(default=None, + metadata={"help": "Path to the training data."}) + lazy_preprocess: bool = False + is_multimodal: bool = False + image_folder: Optional[str] = field(default=None) + image_aspect_ratio: str = 'square' + image_grid_pinpoints: Optional[str] = field(default=None) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + remove_unused_columns: bool = field(default=False) + + tune_visual_abstractor: bool = field(default=True) + freeze_vision_model: bool = field(default=True) + + model_max_length: int = field( + default=512, + metadata={ + "help": + "Maximum sequence length. Sequences will be right padded (and possibly truncated)." + }, + ) + double_quant: bool = field( + default=True, + metadata={"help": "Compress the quantization statistics through double quantization."} + ) + quant_type: str = field( + default="nf4", + metadata={"help": "Quantization data type to use. 
Should be one of `fp4` or `nf4`."} + ) + bits: int = field( + default=16, + metadata={"help": "How many bits to use."} + ) + lora_enable: bool = False + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_weight_path: str = "" + lora_bias: str = "none" + visual_abstractor_lr: Optional[float] = None + group_by_modality_length: bool = field(default=False) + + +def maybe_zero_3(param, ignore_status=False, name=None): + from deepspeed import zero + from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus + if hasattr(param, "ds_id"): + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if not ignore_status: + logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} + return to_return + + +def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): + to_return = {k: t for k, t in named_params if "lora_" not in k} + if require_grad_only: + to_return = {k: t for k, t in to_return.items() if t.requires_grad} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): + to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} + to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} + return to_return + + +def find_all_linear_names(model): + cls = torch.nn.Linear + lora_module_names = set() + multimodal_keywords = ['vision_model', 'visual_abstractor'] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls): + lora_module_names.add(name) + + if 'lm_head' in lora_module_names: # needed for 16-bit + lora_module_names.remove('lm_head') + return list(lora_module_names) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, + output_dir: str): + """Collects the state dict and dump to disk.""" + + if trainer.deepspeed: + torch.cuda.synchronize() + trainer.save_model(output_dir) + return + + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = { + key: value.cpu() + for key, value in state_dict.items() + } + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. 
+ + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def _tokenize_fn(strings: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) for text in strings + ] + input_ids = labels = [ + tokenized.input_ids[0] for tokenized in tokenized_list + ] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() + for tokenized in tokenized_list + ] + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def _mask_targets(target, tokenized_lens, speakers): + # cur_idx = 0 + cur_idx = tokenized_lens[0] + tokenized_lens = tokenized_lens[1:] + target[:cur_idx] = IGNORE_INDEX + for tokenized_len, speaker in zip(tokenized_lens, speakers): + if speaker == "human": + target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX + cur_idx += tokenized_len + + +def _add_speaker_and_signal(header, source, get_conversation=True): + """Add speaker and start/end signal on each round.""" + BEGIN_SIGNAL = "### " + END_SIGNAL = "\n" + conversation = header + for sentence in source: + from_str = sentence["from"] + if from_str.lower() == "human": + from_str = conversation_lib.default_conversation.roles[0] + elif from_str.lower() == "gpt": + from_str = conversation_lib.default_conversation.roles[1] + else: + from_str = 'unknown' + sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + + sentence["value"] + END_SIGNAL) + if get_conversation: + conversation += sentence["value"] + conversation += BEGIN_SIGNAL + return conversation + + +def preprocess_multimodal( + sources: Sequence[str], + data_args: DataArguments +) -> Dict: + is_multimodal = data_args.is_multimodal + if not is_multimodal: + return sources + + for source in sources: + for sentence in source: + if DEFAULT_IMAGE_TOKEN in sentence['value']: + sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() + sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] + sentence['value'] = sentence['value'].strip() + + replace_token = DEFAULT_IMAGE_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return sources + + +def preprocess_v1( + sources, + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + conv = conversation_lib.default_conversation.copy() + roles = {"human": conv.roles[0], "gpt": conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]["from"]] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence["from"]] + assert role == conv.roles[j % 2], f"{i}" 
+ conv.append_message(role, sentence["value"]) + conversations.append(conv.get_prompt()) + + # Tokenize conversations + + if has_image: + input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) + else: + input_ids = tokenizer( + conversations, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + + targets = input_ids.clone() + + assert conv.sep_style == conversation_lib.SeparatorStyle.TWO or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS + + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if has_image: + round_len = len(tokenizer_image_token(rou, tokenizer)) + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + else: + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + target[cur_len : cur_len + instruction_len] = IGNORE_INDEX + + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + print( + f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." + f" (ignored)" + ) + + return dict( + input_ids=input_ids, + labels=targets, + ) + + +def preprocess_plain( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + # add end signal and concatenate together + conversations = [] + for source in sources: + assert len(source) == 2 + assert DEFAULT_IMAGE_TOKEN in source[0]['value'] + source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep + conversations.append(conversation) + # tokenize conversations + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) + target[:tokenized_len] = IGNORE_INDEX + + return dict(input_ids=input_ids, labels=targets) + + +def preprocess( + sources: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, + has_image: bool = False +) -> Dict: + """ + Given a list of sources, each is a conversation list. This transform: + 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; + 2. Concatenate conversations together; + 3. Tokenize the concatenated conversation; + 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 
+ """ + if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: + return preprocess_plain(sources, tokenizer) + if conversation_lib.default_conversation.version.startswith("v1"): + return preprocess_v1(sources, tokenizer, has_image=has_image) + # add end signal and concatenate together + conversations = [] + for source in sources: + header = f"{conversation_lib.default_conversation.system}\n\n" + conversation = _add_speaker_and_signal(header, source) + conversations.append(conversation) + # tokenize conversations + def get_tokenize_len(prompts): + return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts] + if has_image: + input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] + else: + conversations_tokenized = _tokenize_fn(conversations, tokenizer) + input_ids = conversations_tokenized["input_ids"] + + targets = copy.deepcopy(input_ids) + for target, source in zip(targets, sources): + if has_image: + tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source]) + else: + tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"] + speakers = [sentence["from"] for sentence in source] + _mask_targets(target, tokenized_lens, speakers) + + return dict(input_ids=input_ids, labels=targets) + + +def load_video(video_file): + from decord import VideoReader + vr = VideoReader(video_file) + + # Get video frame rate + fps = vr.get_avg_fps() + + # Calculate frame indices for 1fps + frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))] + frames = vr.get_batch(frame_indices).asnumpy() + return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))] + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_args: DataArguments): + super(LazySupervisedDataset, self).__init__() + list_data_dict = json.load(open(data_path, "r")) + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.list_data_dict = list_data_dict + self.data_args = data_args + + def __len__(self): + return len(self.list_data_dict) + + @property + def lengths(self): + length_list = [] + for sample in self.list_data_dict: + img_tokens = 128 if 'image' in sample else 0 + length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) + return length_list + + + @property + def modality_lengths(self): + length_list = [] + for sample in self.list_data_dict: + cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) + cur_len = cur_len if 'image' in sample else -cur_len + length_list.append(cur_len) + return length_list + +# def __getitem__(self, i) -> Dict[str, torch.Tensor]: +# sources = self.list_data_dict[i] +# if isinstance(i, int): +# sources = [sources] +# assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME +# if 'image' in sources[0]: +# image_file = self.list_data_dict[i]['image'] +# image_folder = 
self.data_args.image_folder +# processor = self.data_args.image_processor +# image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') +# if self.data_args.image_aspect_ratio == 'pad': +# def expand2square(pil_img, background_color): +# width, height = pil_img.size +# if width == height: +# return pil_img +# elif width > height: +# result = Image.new(pil_img.mode, (width, width), background_color) +# result.paste(pil_img, (0, (width - height) // 2)) +# return result +# else: +# result = Image.new(pil_img.mode, (height, height), background_color) +# result.paste(pil_img, ((height - width) // 2, 0)) +# return result +# image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) +# image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] +# else: +# image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] +# sources = preprocess_multimodal( +# copy.deepcopy([e["conversations"] for e in sources]), +# self.data_args) +# else: +# sources = copy.deepcopy([e["conversations"] for e in sources]) +# data_dict = preprocess( +# sources, +# self.tokenizer, +# has_image=('image' in self.list_data_dict[i])) +# if isinstance(i, int): +# data_dict = dict(input_ids=data_dict["input_ids"][0], +# labels=data_dict["labels"][0]) + +# # image exist in the data +# if 'image' in self.list_data_dict[i]: +# data_dict['image'] = image +# elif self.data_args.is_multimodal: +# # image does not exist in the data, but the model is multimodal +# crop_size = self.data_args.image_processor.crop_size +# data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) +# return data_dict + + def next_rand(self): + import random + return random.randint(0,len(self)-1) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + while True: + sources = self.list_data_dict[i] + if isinstance(i, int): + sources = [sources] + assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME + if 'image' in sources[0]: + image_file = self.list_data_dict[i]['image'] + + image_folder = self.data_args.image_folder + processor = self.data_args.image_processor + from pathlib import Path + #if not Path(os.path.join(image_folder, image_file)).exists(): + # i = self.next_rand() + # continue + if isinstance(image_file, list): + # Multiple Images as Input + try: + image = [Image.open(os.path.join(image_folder, imfile)).convert('RGB') for imfile in image_file] + except Exception as ex: + print(ex) + i = self.next_rand() + continue + if self.data_args.image_aspect_ratio == 'pad': + image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image] + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] + elif os.path.join(image_folder, image_file).endswith("mp4"): + # Video as Input + image = load_video(os.path.join(image_folder, image_file)) + if self.data_args.image_aspect_ratio == 'pad': + image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image] + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'] + else: + try: + image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') + except Exception as ex: + print(ex) + i = self.next_rand() + continue + if self.data_args.image_aspect_ratio == 'pad': + image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) + 
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + else: + image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + sources = preprocess_multimodal( + copy.deepcopy([e["conversations"] for e in sources]), + self.data_args) + else: + + sources = copy.deepcopy([e["conversations"] for e in sources]) + data_dict = preprocess( + sources, + self.tokenizer, + has_image=('image' in self.list_data_dict[i])) + if isinstance(i, int): + data_dict = dict(input_ids=data_dict["input_ids"][0], + labels=data_dict["labels"][0]) + + # image exist in the data + if 'image' in self.list_data_dict[i]: + data_dict['image'] = image + elif self.data_args.is_multimodal: + # image does not exist in the data, but the model is multimodal + crop_size = self.data_args.image_processor.crop_size + data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) + return data_dict + + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] + for key in ("input_ids", "labels")) + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) + labels = torch.nn.utils.rnn.pad_sequence(labels, + batch_first=True, + padding_value=IGNORE_INDEX) + input_ids = input_ids[:, :self.tokenizer.model_max_length] + labels = labels[:, :self.tokenizer.model_max_length] + batch = dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + + if 'image' in instances[0]: + images = [instance['image'] for instance in instances] + if all(x is not None and x.shape == images[0].shape for x in images): + batch['images'] = torch.stack(images) + else: + batch['images'] = images + + return batch + + +def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, + data_args) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + train_dataset = LazySupervisedDataset(tokenizer=tokenizer, + data_path=data_args.data_path, + data_args=data_args) + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + return dict(train_dataset=train_dataset, + eval_dataset=None, + data_collator=data_collator) + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + local_rank = training_args.local_rank + compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + + bnb_model_from_pretrained_args = {} + if training_args.bits in [4, 8]: + from transformers import BitsAndBytesConfig + bnb_model_from_pretrained_args.update(dict( + device_map={"": training_args.device}, + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + quantization_config=BitsAndBytesConfig( + load_in_4bit=training_args.bits == 4, + load_in_8bit=training_args.bits == 8, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=training_args.double_quant, + bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} + ) + )) + + model = MPLUGOwl2LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + 
cache_dir=training_args.cache_dir, + **bnb_model_from_pretrained_args + ) + model.config.use_cache = False + + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + if training_args.bits in [4, 8]: + from peft import prepare_model_for_kbit_training + model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) + model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) + + if training_args.gradient_checkpointing: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if training_args.lora_enable: + from peft import LoraConfig, get_peft_model + lora_config = LoraConfig( + r=training_args.lora_r, + lora_alpha=training_args.lora_alpha, + target_modules=find_all_linear_names(model), + lora_dropout=training_args.lora_dropout, + bias=training_args.lora_bias, + task_type="CAUSAL_LM", + ) + if training_args.bits == 16: + if training_args.bf16: + model.to(torch.bfloat16) + if training_args.fp16: + model.to(torch.float16) + rank0_print("Adding LoRA adapters...") + model = get_peft_model(model, lora_config) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + + tokenizer.pad_token = tokenizer.unk_token + if model_args.version in conversation_lib.conv_templates: + conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] + else: + conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"] + + if not training_args.freeze_vision_model and training_args.bits in [4, 8]: + model.get_model().vision_model.to(dtype=compute_dtype, device=training_args.device) + else: + vision_tower = model.get_model().vision_model + vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + if training_args.tune_visual_abstractor and training_args.bits in [4, 8]: + model.get_model().visual_abstractor.to(dtype=compute_dtype, device=training_args.device) + else: + visual_abstractor = model.get_model().visual_abstractor + visual_abstractor.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) + + data_args.image_processor = CLIPImageProcessor.from_pretrained(model_args.model_name_or_path) + data_args.is_multimodal = True + + model.config.image_aspect_ratio = data_args.image_aspect_ratio + model.config.image_grid_pinpoints = data_args.image_grid_pinpoints + model.config.tune_visual_abstractor = model_args.tune_visual_abstractor = training_args.tune_visual_abstractor + ic(training_args.tune_visual_abstractor) + model.requires_grad_(True) + if training_args.tune_visual_abstractor: + # model.requires_grad_(False) + for p in model.get_model().visual_abstractor.parameters(): + p.requires_grad = True + + model.config.freeze_vision_model = training_args.freeze_vision_model + ic(training_args.freeze_vision_model) + if training_args.freeze_vision_model: + for p in model.get_model().vision_model.parameters(): + p.requires_grad = False + + model.config.visual_abstractor_lr = training_args.visual_abstractor_lr + + + if training_args.bits in [4, 8]: + from peft.tuners.lora import 
LoraLayer + for name, module in model.named_modules(): + if isinstance(module, LoraLayer): + if training_args.bf16: + module = module.to(torch.bfloat16) + if 'norm' in name: + module = module.to(torch.float32) + if 'lm_head' in name or 'embed_tokens' in name: + if hasattr(module, 'weight'): + if training_args.bf16 and module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + data_module = make_supervised_data_module(tokenizer=tokenizer, + data_args=data_args) + trainer = MPLUGOwl2Trainer(model=model, + tokenizer=tokenizer, + args=training_args, + **data_module) + + # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + # trainer.train(resume_from_checkpoint=True) + # else: + # trainer.train() + + # TODO I dont like auto resume << REMOVE IT AND UNCOMMENT THE ABOVE CODE + trainer.train() + + trainer.save_state() + + model.config.use_cache = True + + if training_args.lora_enable: + state_dict = get_peft_state_maybe_zero_3( + model.named_parameters(), training_args.lora_bias + ) + non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( + model.named_parameters() + ) + if training_args.local_rank == 0 or training_args.local_rank == -1: + model.config.save_pretrained(training_args.output_dir) + model.save_pretrained(training_args.output_dir, state_dict=state_dict) + torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) + else: + safe_save_model_for_hf_trainer(trainer=trainer, + output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() \ No newline at end of file diff --git a/q_align/train/train_mem.py b/q_align/train/train_mem.py new file mode 100644 index 0000000000000000000000000000000000000000..4698657df8e14de319a0c117948eb8fd8a4c67b3 --- /dev/null +++ b/q_align/train/train_mem.py @@ -0,0 +1,13 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. + +# Need to call this before importing transformers. +from q_align.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn + +replace_llama_attn_with_flash_attn() + +from q_align.train.train import train + +if __name__ == "__main__": + train() \ No newline at end of file diff --git a/q_align/utils.py b/q_align/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..378dbc81cce765dea2cd138ae50f242fc557c2c9 --- /dev/null +++ b/q_align/utils.py @@ -0,0 +1,128 @@ +import datetime +import logging +import logging.handlers +import os +import sys + +import requests + + + +from q_align.constants import LOGDIR + +server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
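The train() entry point above relies on transformers.HfArgumentParser to split one flat list of command-line flags across the ModelArguments, DataArguments, and TrainingArguments dataclasses. A hedged sketch of that parsing step follows; every path and value is a placeholder rather than a recommended training recipe, and it assumes the q_align package is importable.

import transformers
from q_align.train.train import ModelArguments, DataArguments, TrainingArguments

parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# Passing an explicit list stands in for sys.argv; each flag maps onto a field of one dataclass.
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=[
    "--model_name_or_path", "/path/to/mplug-owl2-base",   # placeholder checkpoint path
    "--version", "v1",
    "--data_path", "/path/to/train.json",                 # placeholder data file
    "--image_folder", "/path/to/images",
    "--image_aspect_ratio", "pad",
    "--output_dir", "./checkpoints/debug",
    "--visual_abstractor_lr", "2e-5",
])
print(model_args.version, data_args.image_aspect_ratio, training_args.visual_abstractor_lr)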
+ +handler = None + + +def build_logger(logger_name, logger_filename): + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + stderr_logger = logging.getLogger("stderr") + stderr_logger.setLevel(logging.ERROR) + sl = StreamToLogger(stderr_logger, logging.ERROR) + sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(LOGDIR, exist_ok=True) + filename = os.path.join(LOGDIR, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when='D', utc=True) + handler.setFormatter(formatter) + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = '' + + def __getattr__(self, attr): + return getattr(self.terminal, attr) + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = '' + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == '\n': + self.logger.log(self.log_level, line.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != '': + self.logger.log(self.log_level, self.linebuf.rstrip()) + self.linebuf = '' + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. + """ + url = "https://api.openai.com/v1/moderations" + headers = {"Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" \ No newline at end of file
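Closing with a minimal usage sketch for the logging helpers defined in q_align/utils.py above; it assumes LOGDIR (imported from q_align.constants) points at a writable directory, and the logger and file names are placeholders.

from q_align.utils import build_logger

# Creates LOGDIR/<filename> with a daily-rotating file handler and redirects stdout/stderr
# through StreamToLogger.
logger = build_logger("demo_logger", "demo_logger.log")   # hypothetical names
logger.info("server started")            # also appended to LOGDIR/demo_logger.log
print("captured by StreamToLogger")      # stdout is redirected, so this line is logged at INFO level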