diff --git a/app.py b/app.py
index c185365f71f3912b4d5ca1f075abc21040c935e5..63eb8ba17e3402f2912ad73c3517cce77c9681a8 100644
--- a/app.py
+++ b/app.py
@@ -1,3 +1,119 @@
import gradio as gr
-gr.load("models/q-future/one-align").launch()
\ No newline at end of file
+import argparse
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import requests
+from PIL import Image
+
+from q_align.model.builder import load_pretrained_model
+
+from q_align.conversation import (default_conversation, conv_templates,
+ SeparatorStyle)
+from q_align.constants import LOGDIR
+from q_align.utils import (build_logger, server_error_msg,
+ violates_moderation, moderation_msg)
+
+from q_align.evaluate.scorer import QAlignScorer, QAlignAestheticScorer, QAlignVideoScorer
+
+
+def load_video(video_file):
+ from decord import VideoReader
+ vr = VideoReader(video_file)
+
+ # Get video frame rate
+ fps = vr.get_avg_fps()
+
+ # Calculate frame indices for 1fps
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+ frames = vr.get_batch(frame_indices).asnumpy()
+ return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+pretrained = "q-future/one-align"
+device = "cuda:0"
+tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+
+iqa_scorer = QAlignScorer(tokenizer=tokenizer, model=model, image_processor=image_processor)
+iaa_scorer = QAlignAestheticScorer(tokenizer=tokenizer, model=model, image_processor=image_processor)
+vqa_scorer = QAlignVideoScorer(tokenizer=tokenizer, model=model, image_processor=image_processor)
+
+scorers = {"Image Aesthetics (IAA)": iaa_scorer, "Image Quality (IQA)": iqa_scorer, "Video Quality (VQA)": vqa_scorer}
+
+LEVELS = ["excellent (5)", "good (4)", "fair (3)", "poor (2)", "bad (1)"]
+scores = [5,4,3,2,1]
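+# The displayed score is the expectation of the level values (5..1) under the
+# probability distribution returned by the scorer, so it always lies in [1, 5].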
+def image_classifier(input_img, input_vid, scorer_type):
+    if scorer_type is None:
+        scorer_type = "Image Quality (IQA)"
+    this_scorer = scorers[scorer_type]
+    if input_vid is not None:
+        input_ = load_video(input_vid)
+    elif input_img is not None:
+        input_ = [input_img]
+    else:
+        raise gr.Error("Please upload an image or a video first.")
+    if "Video" in scorer_type:
+        input_ = [input_]
+    probs = this_scorer(input_).mean(0).tolist()
+    prob_dict = {level: prob for level, prob in zip(LEVELS, probs)}
+    score = sum(prob * s for s, prob in zip(scores, probs))
+    return prob_dict, score
+
+title_markdown = ("""
+
+
Q-Align: Teaching LMMs for Visual Scoring via Discrete Text-Defined Levels
+
+ One Unified Model for Visual scoring.
+
+
+
+
+
+
+
+ 1Nanyang Technological University, 2Shanghai Jiao Tong University, 3Sensetime Research
+
+
+*Equal contribution. +Project Lead. #Corresponding author(s).
+
+
+ If you like the OneScorer, please give us a star ✨ on GitHub for latest update.
+
+
+
+
+
+""")
+
+
+input_img = gr.Image(type='pil', label="Upload an Image")
+input_vid = gr.Video(label="Upload a Video (will INGORE the image if a video is uploaded)", info="If a video is uploaded, the image uploaded will be ignored.")
+
+labels = gr.Label(label="Probabilities of rating levels:")
+number = gr.Number(label="Output score:", info="Range in [1,5]. Higher is better.")
+task_radio = gr.Radio(
+    ["Image Aesthetics (IAA)", "Image Quality (IQA)", "Video Quality (VQA)"],
+    label="Task",
+    info="Which scorer do you need?",
+)
+demo = gr.Interface(
+    fn=image_classifier,
+    inputs=[input_img, input_vid, task_radio],
+    outputs=[labels, number],
+    title="OneScorer",
+    description=title_markdown,
+)
+demo.launch(share=True)
\ No newline at end of file
diff --git a/q_align/.ipynb_checkpoints/utils-checkpoint.py b/q_align/.ipynb_checkpoints/utils-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..378dbc81cce765dea2cd138ae50f242fc557c2c9
--- /dev/null
+++ b/q_align/.ipynb_checkpoints/utils-checkpoint.py
@@ -0,0 +1,128 @@
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+
+
+from q_align.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
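+# Shared by all loggers: a single timed-rotating file handler, created lazily in build_logger().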
+
+
+def build_logger(logger_name, logger_filename):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ filename, when='D', utc=True)
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ''
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ''
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == '\n':
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != '':
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ''
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {"Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+    text = text.replace("\n", "")
+    try:
+        # Let requests serialize the payload so quotes in `text` cannot break the JSON body.
+        ret = requests.post(url, headers=headers, json={"input": text}, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ flagged = False
+ except KeyError as e:
+ flagged = False
+
+ return flagged
+
+
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
\ No newline at end of file
diff --git a/q_align/__init__.py b/q_align/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0e4234303e2ad423b55342b0000d9dfc5b88df3
--- /dev/null
+++ b/q_align/__init__.py
@@ -0,0 +1 @@
+from .model import MPLUGOwl2LlamaForCausalLM
\ No newline at end of file
diff --git a/q_align/__pycache__/__init__.cpython-310.pyc b/q_align/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a7db7c91534d4d9fcba53458c73de8b7b93319c0
Binary files /dev/null and b/q_align/__pycache__/__init__.cpython-310.pyc differ
diff --git a/q_align/__pycache__/__init__.cpython-311.pyc b/q_align/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43eabac5150ca8ac52e0fe20c593ea49d166b30c
Binary files /dev/null and b/q_align/__pycache__/__init__.cpython-311.pyc differ
diff --git a/q_align/__pycache__/__init__.cpython-39.pyc b/q_align/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9711d49034200f623d0a4231eca427b704a13af9
Binary files /dev/null and b/q_align/__pycache__/__init__.cpython-39.pyc differ
diff --git a/q_align/__pycache__/constants.cpython-310.pyc b/q_align/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40d379600f04196fc4637597294b0a48dd6c4edc
Binary files /dev/null and b/q_align/__pycache__/constants.cpython-310.pyc differ
diff --git a/q_align/__pycache__/constants.cpython-311.pyc b/q_align/__pycache__/constants.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ed26139984658aa167829ebb17321ac8eabb7f79
Binary files /dev/null and b/q_align/__pycache__/constants.cpython-311.pyc differ
diff --git a/q_align/__pycache__/constants.cpython-39.pyc b/q_align/__pycache__/constants.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a78884864b830cccdb26eb71c54b6f74944533e5
Binary files /dev/null and b/q_align/__pycache__/constants.cpython-39.pyc differ
diff --git a/q_align/__pycache__/conversation.cpython-310.pyc b/q_align/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e3817f60cf760318094a82981eb13d9f39b975d
Binary files /dev/null and b/q_align/__pycache__/conversation.cpython-310.pyc differ
diff --git a/q_align/__pycache__/conversation.cpython-311.pyc b/q_align/__pycache__/conversation.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..570ae579548d6bcbe60e21df289d19c68aa03644
Binary files /dev/null and b/q_align/__pycache__/conversation.cpython-311.pyc differ
diff --git a/q_align/__pycache__/conversation.cpython-39.pyc b/q_align/__pycache__/conversation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..29aa67455004b8251e5a63a880ff144831a37a87
Binary files /dev/null and b/q_align/__pycache__/conversation.cpython-39.pyc differ
diff --git a/q_align/__pycache__/mm_utils.cpython-310.pyc b/q_align/__pycache__/mm_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..519d0457d58e50de78d109b3004e2464a6d67d2c
Binary files /dev/null and b/q_align/__pycache__/mm_utils.cpython-310.pyc differ
diff --git a/q_align/__pycache__/mm_utils.cpython-311.pyc b/q_align/__pycache__/mm_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47702f8c81d7079d18936516385cc5a4ddad53bb
Binary files /dev/null and b/q_align/__pycache__/mm_utils.cpython-311.pyc differ
diff --git a/q_align/__pycache__/mm_utils.cpython-39.pyc b/q_align/__pycache__/mm_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a14b149ddd57f0c91444b14f6719c28b8bac6b9f
Binary files /dev/null and b/q_align/__pycache__/mm_utils.cpython-39.pyc differ
diff --git a/q_align/__pycache__/utils.cpython-311.pyc b/q_align/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abd682a6a83a001bf72d71cd273b0272972fb6c7
Binary files /dev/null and b/q_align/__pycache__/utils.cpython-311.pyc differ
diff --git a/q_align/constants.py b/q_align/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..b632a10f2c053c72fae8d61a1fa10fa52aa0f4e8
--- /dev/null
+++ b/q_align/constants.py
@@ -0,0 +1,9 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "./demo_logs"
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<|image|>"
diff --git a/q_align/conversation.py b/q_align/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9f6f3f518f8b4426493ff7dacb0f6b70ae56c09
--- /dev/null
+++ b/q_align/conversation.py
@@ -0,0 +1,301 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+from q_align.constants import DEFAULT_IMAGE_TOKEN
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ TWO_NO_SYS = auto()
+ MPT = auto()
+ PLAIN = auto()
+ LLAMA_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+            # init_msg = init_msg[0].replace("<image>", "").strip()
+            # if 'mmtag' in self.version:
+            #     messages[0] = (init_role, init_msg)
+            #     messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+            #     messages.insert(1, (self.roles[1], "Received."))
+            # else:
+            #     messages[0] = (init_role, "<image>\n" + init_msg)
+ init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
+ messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)
+
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
+ seps = [self.sep, self.sep2]
+ ret = ""
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+ ret = ""
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0: message = wrap_sys(self.system) + message
+ if i % 2 == 0:
+ message = wrap_inst(message)
+ ret += self.sep + message
+ else:
+ ret += " " + message + " " + self.sep2
+ else:
+ ret += ""
+ ret = ret.lstrip(self.sep)
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ msg, image, image_process_mode = msg
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image)
+ elif image_process_mode in ["Default", "Crop"]:
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((336, 336))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if longest_edge != max(image.size):
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ images.append(image)
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format="PNG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ images.append(img_b64_str)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ msg, image, image_process_mode = msg
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+ msg = img_str + msg.replace('<|image|>', '').strip()
+ ret.append([msg, None])
+ else:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ version=self.version)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_vicuna_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ ("Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
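+# mPLUG-Owl2 uses the TWO_NO_SYS style: identical to TWO, except the system message
+# is not prepended to the serialized prompt.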
+conv_mplug_owl2 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO_NO_SYS,
+ sep=" ",
+ sep2="",
+)
+
+# default_conversation = conv_vicuna_v1
+default_conversation = conv_mplug_owl2
+conv_templates = {
+ "default": conv_vicuna_v0,
+ "v0": conv_vicuna_v0,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "mplug_owl2": conv_mplug_owl2,
+}
+
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
\ No newline at end of file
diff --git a/q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..64f17e29dacae2c145c0b10d1fe0dcdc27b8f6b9
--- /dev/null
+++ b/q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py
@@ -0,0 +1,164 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+from scipy.stats import spearmanr, pearsonr
+
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+def wa5(logits):
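+    # Weighted average over the five rating levels: softmax the collected logits for
+    # "excellent".."bad", then take the inner product with [1, 0.75, 0.5, 0.25, 0].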
+ import numpy as np
+ logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]])
+ probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
+ return np.inner(probs, np.array([1,0.75,0.5,0.25,0.]))
+
+
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_image(image_file):
+ if image_file.startswith('http://') or image_file.startswith('https://'):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert('RGB')
+ else:
+ image = Image.open(image_file).convert('RGB')
+ return image
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+ import json
+
+
+ image_path = "playground/data/"
+
+
+ json_prefix = "playground/data/test_jsons/"
+ jsons = [
+ json_prefix + "test_ava.json",
+ ]
+
+ os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+ conv_mode = "mplug_owl2"
+
+ inp = "How would you rate the aesthetics of this image?"
+
+ conv = conv_templates[conv_mode].copy()
+ inp = DEFAULT_IMAGE_TOKEN + inp
+ conv.append_message(conv.roles[0], inp)
+ image = None
+
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt() + " The aesthetics of the image is"
+
+ toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+ print(toks)
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+ print(ids_)
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+ for json_ in jsons:
+ with open(json_) as f:
+ iqadata = json.load(f)
+
+ image_tensors = []
+ batch_data = []
+ prs, gts = [], []
+ for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+ filename = llddata["image"]
+ llddata["logits"] = defaultdict(float)
+
+
+
+ image = load_image(image_path + filename)
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+ image_tensors.append(image_tensor)
+ batch_data.append(llddata)
+
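+            # Run the model once every 8 accumulated images (and flush at the end of the list).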
+ if i % 8 == 7 or i == len(iqadata) - 1:
+ with torch.inference_mode():
+ output_logits = model(input_ids.repeat(len(image_tensors), 1),
+ images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+ for j, xllddata in enumerate(batch_data):
+ for tok, id_ in zip(toks, ids_):
+ xllddata["logits"][tok] += output_logits[j,id_].item()
+ xllddata["score"] = wa5(xllddata["logits"])
+ # print(llddata)
+ prs.append(xllddata["score"])
+ gts.append(xllddata["gt_score"])
+ json_ = json_.replace("combined/", "combined-")
+ with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+ json.dump(xllddata, wf)
+
+ image_tensors = []
+ batch_data = []
+
+ #if i > 0 and i % 200 == 0:
+ # print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0])
+ print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/one-align")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aaf1a7ac9ec4ee21a48ee96bfc92fe742559c24
--- /dev/null
+++ b/q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py
@@ -0,0 +1,150 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+from decord import VideoReader
+
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_video(video_file):
+ vr = VideoReader(video_file)
+
+ # Get video frame rate
+ fps = vr.get_avg_fps()
+
+ # Calculate frame indices for 1fps
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+ frames = vr.get_batch(frame_indices).asnumpy()
+ return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+ import json
+
+
+ image_paths = [
+ "playground/data/",
+ "playground/data/",
+ "playground/data/KoNViD_1k_videos/",
+ "playground/data/maxwell/",
+ ]
+
+ json_prefix = "playground/data/test_jsons/"
+ jsons = [
+ json_prefix + "test_lsvq.json",
+ json_prefix + "test_lsvq_1080p.json",
+ json_prefix + "konvid.json",
+ json_prefix + "maxwell_test.json",
+ ]
+
+ os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+ conv_mode = "mplug_owl2"
+
+ inp = "How would you rate the quality of this image?"
+
+ conv = conv_templates[conv_mode].copy()
+ inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+ conv.append_message(conv.roles[0], inp)
+ image = None
+
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt() + " The quality of the image is"
+
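+    # Note: videos are scored frame by frame with the *image* quality prompt; the
+    # per-frame logits are then averaged (see output_logits.mean(0) below).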
+ toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+ print(toks)
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+ print(ids_)
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+ for image_path, json_ in zip(image_paths, jsons):
+ with open(json_) as f:
+ iqadata = json.load(f)
+ try:
+ for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+ filename = llddata["img_path"]
+ llddata["logits"] = defaultdict(float)
+
+ image = load_video(image_path + filename)
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image]
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+
+ if True:
+ with torch.inference_mode():
+ output_logits = model(input_ids.repeat(image_tensor.shape[0], 1),
+ images=image_tensor)["logits"][:,-1]
+
+ for tok, id_ in zip(toks, ids_):
+ llddata["logits"][tok] += output_logits.mean(0)[id_].item()
+ # print(llddata)
+ json_ = json_.replace("combined/", "combined-")
+ with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+ json.dump(llddata, wf)
+        except Exception:
+            continue
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/q-align-image")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e0695d696cdad4b277aab3a753dc4b9dce89af1
--- /dev/null
+++ b/q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py
@@ -0,0 +1,156 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_image(image_file):
+ if image_file.startswith('http://') or image_file.startswith('https://'):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert('RGB')
+ else:
+ image = Image.open(image_file).convert('RGB')
+ return image
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+ import json
+
+
+ image_path = "playground/data/"
+
+
+ json_prefix = "playground/data/test_jsons/"
+ jsons = [
+ json_prefix + "test_imagerewarddb.json",
+ json_prefix + "test_koniq.json",
+ json_prefix + "test_spaq.json",
+ json_prefix + "test_kadid.json",
+ json_prefix + "livec.json",
+ json_prefix + "agi.json",
+ json_prefix + "live.json",
+ json_prefix + "csiq.json",
+ ]
+
+ os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+ conv_mode = "mplug_owl2"
+
+ inp = "Evaluate the image quality of the following image."#"How would you rate the quality of this image?"
+
+ conv = conv_templates[conv_mode].copy()
+ inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+ conv.append_message(conv.roles[0], inp)
+ image = None
+
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt() + " The quality of the image is"
+
+ toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+ print(toks)
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+ print(ids_)
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+ for json_ in jsons:
+ with open(json_) as f:
+ iqadata = json.load(f)
+
+ image_tensors = []
+ batch_data = []
+
+ for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+ if True:
+ try:
+ filename = llddata["image"]
+                except KeyError:
+ filename = llddata["img_path"]
+ llddata["logits"] = defaultdict(float)
+
+ image = load_image(image_path + filename)
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+ image_tensors.append(image_tensor)
+ batch_data.append(llddata)
+
+ if i % 8 == 7 or i == len(iqadata) - 1:
+ with torch.inference_mode():
+ output_logits = model(input_ids.repeat(len(image_tensors), 1),
+ images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+ for j, xllddata in enumerate(batch_data):
+ for tok, id_ in zip(toks, ids_):
+ xllddata["logits"][tok] += output_logits[j,id_].item()
+ # print(llddata)
+ json_ = json_.replace("combined/", "combined-")
+ with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf:
+ json.dump(xllddata, wf)
+
+ image_tensors = []
+ batch_data = []
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/one-align")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d68f83df09a0f8696a49bede9844d7e7d6465dd
--- /dev/null
+++ b/q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py
@@ -0,0 +1,155 @@
+from PIL import Image
+
+import torch.nn as nn
+import torch
+
+from typing import List
+
+from q_align.model.builder import load_pretrained_model
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+def load_video(video_file):
+ from decord import VideoReader
+ vr = VideoReader(video_file)
+
+ # Get video frame rate
+ fps = vr.get_avg_fps()
+
+ # Calculate frame indices for 1fps
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+ frames = vr.get_batch(frame_indices).asnumpy()
+ return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+class QAlignScorer(nn.Module):
+ def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+ super().__init__()
+ if model is None:
+ tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+ prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"
+
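+        # Each rating word tokenizes to [BOS, level_token] with the LLaMA-style tokenizer;
+        # index 1 keeps the level token itself.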
+ self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+ self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
+
+ self.tokenizer = tokenizer
+ self.model = model
+ self.image_processor = image_processor
+ self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
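+    # Pad to a square canvas filled with the given background color (callers pass the
+    # processor's per-channel mean), so the subsequent resize keeps the aspect ratio.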
+ def expand2square(self, pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ def forward(self, image: List[Image.Image]):
+ image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
+ with torch.inference_mode():
+ image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
+ output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
+ images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+
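+        # Return the softmax distribution over the five levels; uncommenting
+        # "@ self.weight_tensor" would instead collapse it to a single weighted score.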
+ return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+
+class QAlignAestheticScorer(nn.Module):
+ def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+ super().__init__()
+ if model is None:
+ tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+ prompt = "USER: How would you rate the aesthetics of this image?\n<|image|>\nASSISTANT: The aesthetics of the image is"
+
+ self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+ self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
+
+ self.tokenizer = tokenizer
+ self.model = model
+ self.image_processor = image_processor
+ self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+ def expand2square(self, pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ def forward(self, image: List[Image.Image]):
+ image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
+ with torch.inference_mode():
+ image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
+ output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
+ images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+
+ return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+class QAlignVideoScorer(nn.Module):
+ def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+ super().__init__()
+ if model is None:
+ tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+ prompt = "USER: How would you rate the quality of this video?\n<|image|>\nASSISTANT: The quality of the video is"
+
+ self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+ self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
+
+ self.tokenizer = tokenizer
+ self.model = model
+ self.image_processor = image_processor
+ self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+ def expand2square(self, pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
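+    # `video` is a batch of clips, each given as a list of PIL frames; every clip is
+    # scored with a single prompt over all of its frames.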
+ def forward(self, video: List[List[Image.Image]]):
+ video = [[self.expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in video]
+ with torch.inference_mode():
+ video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video]
+ output_logits = self.model(self.input_ids.repeat(len(video_tensors), 1),
+ images=video_tensors)["logits"][:,-1, self.preferential_ids_]
+ return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/one-align")
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
+ parser.add_argument("--aesthetic", action="store_true")
+ parser.add_argument("--video", action="store_true")
+ args = parser.parse_args()
+
+ if args.video:
+ scorer = QAlignVideoScorer(pretrained=args.model_path, device=args.device)
+ print(scorer([load_video(args.img_path)]).tolist())
+ else:
+ scorer = QAlignScorer(pretrained=args.model_path, device=args.device) if not args.aesthetic else QAlignAestheticScorer(pretrained=args.model_path, device=args.device)
+ print(scorer([Image.open(args.img_path)]).tolist())
+
diff --git a/q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py b/q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cb69fea96cf97ed9fd41554460026f0caa34614
--- /dev/null
+++ b/q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py
@@ -0,0 +1,167 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+
+from scipy.stats import spearmanr, pearsonr
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+def wa5(logits):
+ import numpy as np
+ logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]])
+ probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
+ return np.inner(probs, np.array([1,0.75,0.5,0.25,0.]))
+
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_video(video_file):
+ from decord import VideoReader
+ vr = VideoReader(video_file)
+
+ # Get video frame rate
+ fps = vr.get_avg_fps()
+
+ # Calculate frame indices for 1fps
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+ frames = vr.get_batch(frame_indices).asnumpy()
+ return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+ import json
+
+
+ image_paths = [
+ #"playground/data/",
+ #"playground/data/",
+ "playground/data/KoNViD_1k_videos/",
+ "playground/data/maxwell/",
+
+ ]
+
+ json_prefix = "playground/data/test_jsons/"
+ jsons = [
+ #json_prefix + "test_lsvq.json",
+ #json_prefix + "test_lsvq_1080p.json",
+ json_prefix + "konvid.json",
+ json_prefix + "maxwell_test.json",
+ ]
+
+ os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+ conv_mode = "mplug_owl2"
+
+ inp = "How would you rate the quality of this video?"
+
+ conv = conv_templates[conv_mode].copy()
+ inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+ conv.append_message(conv.roles[0], inp)
+ image = None
+
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt() + " The quality of the video is"
+
+ toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+ print(toks)
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+ print(ids_)
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+ for image_path, json_ in zip(image_paths, jsons):
+ with open(json_) as f:
+ iqadata = json.load(f)
+ prs, gts = [], []
+ for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+ try:
+ try:
+ filename = llddata["img_path"]
+                except KeyError:
+ filename = llddata["image"]
+ llddata["logits"] = defaultdict(float)
+
+ image = load_video(image_path + filename)
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image]
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+ if True:
+ with torch.inference_mode():
+ output_logits = model(input_ids,
+ images=[image_tensor])["logits"][:,-1]
+ for tok, id_ in zip(toks, ids_):
+ llddata["logits"][tok] += output_logits.mean(0)[id_].item()
+ llddata["score"] = wa5(llddata["logits"])
+ # print(llddata)
+ prs.append(llddata["score"])
+ gts.append(llddata["gt_score"])
+ # print(llddata)
+ json_ = json_.replace("combined/", "combined-")
+ with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf:
+ json.dump(llddata, wf)
+
+ if i > 0 and i % 200 == 0:
+ print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0])
+            except Exception:
+                continue
+ print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/one-align")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/q_align/evaluate/__pycache__/scorer.cpython-311.pyc b/q_align/evaluate/__pycache__/scorer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b421b86aed039452e511142014200bb403f9e659
Binary files /dev/null and b/q_align/evaluate/__pycache__/scorer.cpython-311.pyc differ
diff --git a/q_align/evaluate/eval.py b/q_align/evaluate/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..008b7b07abfec6c41a1162a582dae2ccdae5f597
--- /dev/null
+++ b/q_align/evaluate/eval.py
@@ -0,0 +1,138 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_image(image_file):
+ if image_file.startswith('http://') or image_file.startswith('https://'):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert('RGB')
+ else:
+ image = Image.open(image_file).convert('RGB')
+ return image
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+ import json
+
+
+ image_path = "playground/data/"
+
+ json_prefix = "playground/data/labels/mos_simple/"
+ jsons = [
+ json_prefix + "test_flive.json",
+ json_prefix + "combined/kadid_ref.json",
+ json_prefix + "combined/livec.json",
+ json_prefix + "test_koniq.json",
+ json_prefix + "test_spaq.json",
+ json_prefix + "combined/agi.json",
+ json_prefix + "combined/kadid.json",
+ ]
+
+ os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+ conv_mode = "mplug_owl2"
+
+ inp = "How would you rate the quality of this image?"
+
+ conv = conv_templates[conv_mode].copy()
+ inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+ conv.append_message(conv.roles[0], inp)
+ image = None
+
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt() + " The quality of the image is"
+
+ toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+ print(toks)
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+ print(ids_)
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+ for json_ in jsons:
+ with open(json_) as f:
+ iqadata = json.load(f)
+
+ image_tensors = []
+ batch_data = []
+
+ for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+ #print(f"Evaluating image {i}")
+ #print(prompt)
+ filename = llddata["image"]
+ llddata["logits"] = defaultdict(float)
+
+ image = load_image(image_path + filename)
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
+
+ image_tensors.append(image_tensor)
+ batch_data.append(llddata)
+
+ if i % 8 == 7 or i == len(iqadata) - 1:
+ with torch.inference_mode():
+ output_logits = model(input_ids.repeat(len(image_tensors), 1),
+ images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+ for j, xllddata in enumerate(batch_data):
+ for tok, id_ in zip(toks, ids_):
+ xllddata["logits"][tok] += output_logits[j,id_].item()
+ # print(llddata)
+ json_ = json_.replace("combined/", "combined-")
+ # print(f"results/mix-mplug-owl-2-boost_iqa_wu_v2/{json_.split('/')[-1]}")
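+                    # Results are appended as back-to-back JSON objects (one per image),
+                    # not as a single JSON array.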
+ with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+ json.dump(xllddata, wf)
+
+ image_tensors = []
+ batch_data = []
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/q-align-koniq-spaq-v0")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/q_align/evaluate/iaa_eval.py b/q_align/evaluate/iaa_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..64f17e29dacae2c145c0b10d1fe0dcdc27b8f6b9
--- /dev/null
+++ b/q_align/evaluate/iaa_eval.py
@@ -0,0 +1,164 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+from scipy.stats import spearmanr, pearsonr
+
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+def wa5(logits):
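+ """Weighted five-level average: softmax the logits of "excellent"/"good"/"fair"/"poor"/"bad"
+ and take the inner product with [1, 0.75, 0.5, 0.25, 0], giving a score in [0, 1]."""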
+ import numpy as np
+ logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]])
+ probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
+ return np.inner(probs, np.array([1,0.75,0.5,0.25,0.]))
+
+
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_image(image_file):
+ if image_file.startswith('http://') or image_file.startswith('https://'):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert('RGB')
+ else:
+ image = Image.open(image_file).convert('RGB')
+ return image
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+ import json
+
+
+ image_path = "playground/data/"
+
+
+ json_prefix = "playground/data/test_jsons/"
+ jsons = [
+ json_prefix + "test_ava.json",
+ ]
+
+ os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+ conv_mode = "mplug_owl2"
+
+ inp = "How would you rate the aesthetics of this image?"
+
+ conv = conv_templates[conv_mode].copy()
+ inp = DEFAULT_IMAGE_TOKEN + inp
+ conv.append_message(conv.roles[0], inp)
+ image = None
+
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt() + " The aesthetics of the image is"
+
+ toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+ print(toks)
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+ print(ids_)
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+ for json_ in jsons:
+ with open(json_) as f:
+ iqadata = json.load(f)
+
+ image_tensors = []
+ batch_data = []
+ prs, gts = [], []
+ for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+ filename = llddata["image"]
+ llddata["logits"] = defaultdict(float)
+
+
+
+ image = load_image(image_path + filename)
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+ image_tensors.append(image_tensor)
+ batch_data.append(llddata)
+
+ if i % 8 == 7 or i == len(iqadata) - 1:
+ with torch.inference_mode():
+ output_logits = model(input_ids.repeat(len(image_tensors), 1),
+ images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+ for j, xllddata in enumerate(batch_data):
+ for tok, id_ in zip(toks, ids_):
+ xllddata["logits"][tok] += output_logits[j,id_].item()
+ xllddata["score"] = wa5(xllddata["logits"])
+ # print(llddata)
+ prs.append(xllddata["score"])
+ gts.append(xllddata["gt_score"])
+ json_ = json_.replace("combined/", "combined-")
+ with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+ json.dump(xllddata, wf)
+
+ image_tensors = []
+ batch_data = []
+
+ #if i > 0 and i % 200 == 0:
+ # print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0])
+ print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/one-align")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/q_align/evaluate/iqa4vqa_eval.py b/q_align/evaluate/iqa4vqa_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aaf1a7ac9ec4ee21a48ee96bfc92fe742559c24
--- /dev/null
+++ b/q_align/evaluate/iqa4vqa_eval.py
@@ -0,0 +1,150 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+from decord import VideoReader
+
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_video(video_file):
+ vr = VideoReader(video_file)
+
+ # Get video frame rate
+ fps = vr.get_avg_fps()
+
+ # Calculate frame indices for 1fps
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+ frames = vr.get_batch(frame_indices).asnumpy()
+ return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+ import json
+
+
+ image_paths = [
+ "playground/data/",
+ "playground/data/",
+ "playground/data/KoNViD_1k_videos/",
+ "playground/data/maxwell/",
+ ]
+
+ json_prefix = "playground/data/test_jsons/"
+ jsons = [
+ json_prefix + "test_lsvq.json",
+ json_prefix + "test_lsvq_1080p.json",
+ json_prefix + "konvid.json",
+ json_prefix + "maxwell_test.json",
+ ]
+
+ os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+ conv_mode = "mplug_owl2"
+
+ inp = "How would you rate the quality of this image?"
+
+ conv = conv_templates[conv_mode].copy()
+ inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+ conv.append_message(conv.roles[0], inp)
+ image = None
+
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt() + " The quality of the image is"
+
+ toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+ print(toks)
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+ print(ids_)
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+ for image_path, json_ in zip(image_paths, jsons):
+ with open(json_) as f:
+ iqadata = json.load(f)
+ try:
+ for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+ filename = llddata["img_path"]
+ llddata["logits"] = defaultdict(float)
+
+ image = load_video(image_path + filename)
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image]
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+
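+ # Score every decoded frame with the image-quality prompt and average the
+ # per-frame logits to obtain a video-level logit for each candidate word.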
+ if True:
+ with torch.inference_mode():
+ output_logits = model(input_ids.repeat(image_tensor.shape[0], 1),
+ images=image_tensor)["logits"][:,-1]
+
+ for tok, id_ in zip(toks, ids_):
+ llddata["logits"][tok] += output_logits.mean(0)[id_].item()
+ # print(llddata)
+ json_ = json_.replace("combined/", "combined-")
+ with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+ json.dump(llddata, wf)
+ except Exception as e:
+ # Skip the rest of this file if a sample fails to decode or preprocess, rather than aborting the whole run.
+ print(f"Skipping {json_}: {e}")
+ continue
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/q-align-image")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/q_align/evaluate/iqa_eval.py b/q_align/evaluate/iqa_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e0695d696cdad4b277aab3a753dc4b9dce89af1
--- /dev/null
+++ b/q_align/evaluate/iqa_eval.py
@@ -0,0 +1,156 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_image(image_file):
+ if image_file.startswith('http://') or image_file.startswith('https://'):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert('RGB')
+ else:
+ image = Image.open(image_file).convert('RGB')
+ return image
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+ import json
+
+
+ image_path = "playground/data/"
+
+
+ json_prefix = "playground/data/test_jsons/"
+ jsons = [
+ json_prefix + "test_imagerewarddb.json",
+ json_prefix + "test_koniq.json",
+ json_prefix + "test_spaq.json",
+ json_prefix + "test_kadid.json",
+ json_prefix + "livec.json",
+ json_prefix + "agi.json",
+ json_prefix + "live.json",
+ json_prefix + "csiq.json",
+ ]
+
+ os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+ conv_mode = "mplug_owl2"
+
+ inp = "Evaluate the image quality of the following image."#"How would you rate the quality of this image?"
+
+ conv = conv_templates[conv_mode].copy()
+ inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+ conv.append_message(conv.roles[0], inp)
+ image = None
+
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt() + " The quality of the image is"
+
+ toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+ print(toks)
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+ print(ids_)
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+ for json_ in jsons:
+ with open(json_) as f:
+ iqadata = json.load(f)
+
+ image_tensors = []
+ batch_data = []
+
+ for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+ if True:
+ try:
+ filename = llddata["image"]
+ except KeyError:
+ filename = llddata["img_path"]
+ llddata["logits"] = defaultdict(float)
+
+ image = load_image(image_path + filename)
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+ image_tensors.append(image_tensor)
+ batch_data.append(llddata)
+
+ if i % 8 == 7 or i == len(iqadata) - 1:
+ with torch.inference_mode():
+ output_logits = model(input_ids.repeat(len(image_tensors), 1),
+ images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+ for j, xllddata in enumerate(batch_data):
+ for tok, id_ in zip(toks, ids_):
+ xllddata["logits"][tok] += output_logits[j,id_].item()
+ # print(llddata)
+ json_ = json_.replace("combined/", "combined-")
+ with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf:
+ json.dump(xllddata, wf)
+
+ image_tensors = []
+ batch_data = []
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/one-align")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/q_align/evaluate/scorer.py b/q_align/evaluate/scorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d68f83df09a0f8696a49bede9844d7e7d6465dd
--- /dev/null
+++ b/q_align/evaluate/scorer.py
@@ -0,0 +1,155 @@
+from PIL import Image
+
+import torch.nn as nn
+import torch
+
+from typing import List
+
+from q_align.model.builder import load_pretrained_model
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+def load_video(video_file):
+ from decord import VideoReader
+ vr = VideoReader(video_file)
+
+ # Get video frame rate
+ fps = vr.get_avg_fps()
+
+ # Calculate frame indices for 1fps
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+ frames = vr.get_batch(frame_indices).asnumpy()
+ return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
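+# Each scorer below builds a fixed conversation prompt, reads the last-position
+# logits of the five level tokens ("excellent", "good", "fair", "poor", "bad"),
+# and softmaxes them into a probability distribution. Multiplying by `weight_tensor`
+# ([1, 0.75, 0.5, 0.25, 0]) would collapse this into a single score in [0, 1];
+# that matmul is left commented out, so callers receive the full distribution.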
+class QAlignScorer(nn.Module):
+ def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+ super().__init__()
+ if model is None:
+ tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+ prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"
+
+ self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+ self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
+
+ self.tokenizer = tokenizer
+ self.model = model
+ self.image_processor = image_processor
+ self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+ def expand2square(self, pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ def forward(self, image: List[Image.Image]):
+ image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
+ with torch.inference_mode():
+ image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
+ output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
+ images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+
+ return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+
+class QAlignAestheticScorer(nn.Module):
+ def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+ super().__init__()
+ if model is None:
+ tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+ prompt = "USER: How would you rate the aesthetics of this image?\n<|image|>\nASSISTANT: The aesthetics of the image is"
+
+ self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+ self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
+
+ self.tokenizer = tokenizer
+ self.model = model
+ self.image_processor = image_processor
+ self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+ def expand2square(self, pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ def forward(self, image: List[Image.Image]):
+ image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
+ with torch.inference_mode():
+ image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
+ output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
+ images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+
+ return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+class QAlignVideoScorer(nn.Module):
+ def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+ super().__init__()
+ if model is None:
+ tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+ prompt = "USER: How would you rate the quality of this video?\n<|image|>\nASSISTANT: The quality of the video is"
+
+ self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+ self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
+
+ self.tokenizer = tokenizer
+ self.model = model
+ self.image_processor = image_processor
+ self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+ def expand2square(self, pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ def forward(self, video: List[List[Image.Image]]):
+ video = [[self.expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in video]
+ with torch.inference_mode():
+ video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video]
+ output_logits = self.model(self.input_ids.repeat(len(video_tensors), 1),
+ images=video_tensors)["logits"][:,-1, self.preferential_ids_]
+ return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/one-align")
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
+ parser.add_argument("--aesthetic", action="store_true")
+ parser.add_argument("--video", action="store_true")
+ args = parser.parse_args()
+
+ if args.video:
+ scorer = QAlignVideoScorer(pretrained=args.model_path, device=args.device)
+ print(scorer([load_video(args.img_path)]).tolist())
+ else:
+ scorer = QAlignScorer(pretrained=args.model_path, device=args.device) if not args.aesthetic else QAlignAestheticScorer(pretrained=args.model_path, device=args.device)
+ print(scorer([Image.open(args.img_path)]).tolist())
+
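+# A minimal usage sketch (separate from the CLI above), assuming a CUDA device and
+# that the default "q-future/one-align" weights can be downloaded:
+#
+#   from q_align.evaluate.scorer import QAlignScorer
+#   from PIL import Image
+#   scorer = QAlignScorer(device="cuda:0")
+#   probs = scorer([Image.open("fig/singapore_flyer.jpg")])   # (N, 5) level probabilities
+#   scores = (probs @ scorer.weight_tensor).tolist()          # weighted scores in [0, 1]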
diff --git a/q_align/evaluate/vqa_eval.py b/q_align/evaluate/vqa_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cb69fea96cf97ed9fd41554460026f0caa34614
--- /dev/null
+++ b/q_align/evaluate/vqa_eval.py
@@ -0,0 +1,167 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+
+from scipy.stats import spearmanr, pearsonr
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+def wa5(logits):
+ import numpy as np
+ logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]])
+ probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
+ return np.inner(probs, np.array([1,0.75,0.5,0.25,0.]))
+
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_video(video_file):
+ from decord import VideoReader
+ vr = VideoReader(video_file)
+
+ # Get video frame rate
+ fps = vr.get_avg_fps()
+
+ # Calculate frame indices for 1fps
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+ frames = vr.get_batch(frame_indices).asnumpy()
+ return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+def main(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+ import json
+
+
+ image_paths = [
+ #"playground/data/",
+ #"playground/data/",
+ "playground/data/KoNViD_1k_videos/",
+ "playground/data/maxwell/",
+
+ ]
+
+ json_prefix = "playground/data/test_jsons/"
+ jsons = [
+ #json_prefix + "test_lsvq.json",
+ #json_prefix + "test_lsvq_1080p.json",
+ json_prefix + "konvid.json",
+ json_prefix + "maxwell_test.json",
+ ]
+
+ os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+ conv_mode = "mplug_owl2"
+
+ inp = "How would you rate the quality of this video?"
+
+ conv = conv_templates[conv_mode].copy()
+ inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+ conv.append_message(conv.roles[0], inp)
+ image = None
+
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt() + " The quality of the video is"
+
+ toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+ print(toks)
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+ print(ids_)
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+ for image_path, json_ in zip(image_paths, jsons):
+ with open(json_) as f:
+ iqadata = json.load(f)
+ prs, gts = [], []
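+ # prs collects predicted scores, gts the ground-truth scores, for SRCC/PLCC reporting.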
+ for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+ try:
+ try:
+ filename = llddata["img_path"]
+ except KeyError:
+ filename = llddata["image"]
+ llddata["logits"] = defaultdict(float)
+
+ image = load_video(image_path + filename)
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image]
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+ if True:
+ with torch.inference_mode():
+ output_logits = model(input_ids,
+ images=[image_tensor])["logits"][:,-1]
+ for tok, id_ in zip(toks, ids_):
+ llddata["logits"][tok] += output_logits.mean(0)[id_].item()
+ llddata["score"] = wa5(llddata["logits"])
+ # print(llddata)
+ prs.append(llddata["score"])
+ gts.append(llddata["gt_score"])
+ # print(llddata)
+ json_ = json_.replace("combined/", "combined-")
+ with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf:
+ json.dump(llddata, wf)
+
+ if i > 0 and i % 200 == 0:
+ print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0])
+ except Exception as e:
+ # Skip samples that fail to decode or score rather than aborting the whole file.
+ print(f"Skipping sample {i} of {json_}: {e}")
+ continue
+ print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="q-future/one-align")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--device", type=str, default="cuda:0")
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--max-new-tokens", type=int, default=512)
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--debug", action="store_true")
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/q_align/mm_utils.py b/q_align/mm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee1ee281029576bf0f23384ebb79a34e2f34d578
--- /dev/null
+++ b/q_align/mm_utils.py
@@ -0,0 +1,112 @@
+from PIL import Image
+from io import BytesIO
+import base64
+
+import torch
+from transformers import StoppingCriteria
+from q_align.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
+from icecream import ic
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
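+ """Pad `pil_img` to a square canvas filled with `background_color`, keeping the original image centered."""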
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg=None):
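+ """Preprocess a list of PIL images according to the model's aspect-ratio setting.
+ 'pad' pads each image to a square with the processor's mean color, 'resize'
+ stretches it to a square; otherwise the processor's default pipeline is used.
+ Returns a stacked tensor when all preprocessed images share the same shape."""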
+ if model_cfg is not None:
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+ else:
+ image_aspect_ratio = 'resize'
+ new_images = []
+ if image_aspect_ratio == 'pad':
+ for image in images:
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ new_images.append(image)
+ elif image_aspect_ratio == 'resize':
+ for image in images:
+ max_edge = max(image.size)
+ image = image.resize((max_edge, max_edge))
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ new_images.append(image)
+ else:
+ return image_processor(images, return_tensors='pt')['pixel_values']
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = torch.stack(new_images, dim=0)
+ return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
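+ """Tokenize `prompt`, splicing `image_token_index` where DEFAULT_IMAGE_TOKEN appears.
+ The prompt is split on the image token, each chunk is tokenized separately, and
+ the chunks are re-joined with the placeholder index so visual embeddings can be
+ inserted at those positions later."""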
+ prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
+
+ def insert_separator(X, sep):
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+ input_ids = []
+ offset = 0
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+ offset = 1
+ input_ids.append(prompt_chunks[0][0])
+
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+ input_ids.extend(x[offset:])
+
+ if return_tensors is not None:
+ if return_tensors == 'pt':
+ return torch.tensor(input_ids, dtype=torch.long)
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
+ return input_ids
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1]
+
+
+
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ self.max_keyword_len = 0
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ if len(cur_keyword_ids) > self.max_keyword_len:
+ self.max_keyword_len = len(cur_keyword_ids)
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+ for keyword_id in self.keyword_ids:
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
\ No newline at end of file
diff --git a/q_align/model/__init__.py b/q_align/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d6f0775a0abb2c3e220343a4feb05c70c2c7779
--- /dev/null
+++ b/q_align/model/__init__.py
@@ -0,0 +1,2 @@
+from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
+from .configuration_mplug_owl2 import MPLUGOwl2Config
\ No newline at end of file
diff --git a/q_align/model/builder.py b/q_align/model/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5da53f5ff3d39265636f154923173c49186e7fa4
--- /dev/null
+++ b/q_align/model/builder.py
@@ -0,0 +1,118 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+import shutil
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+from transformers.models.clip.image_processing_clip import CLIPImageProcessor
+import torch
+from q_align.model import *
+from icecream import ic
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
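+ """Load an mPLUG-Owl2 checkpoint (full, LoRA, or projector-only) or a plain language model.
+ Returns (tokenizer, model, image_processor, context_len)."""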
+ kwargs = {"device_map": device_map}
+
+ if device != "cuda":
+ kwargs['device_map'] = {"": device}
+
+ if load_8bit:
+ kwargs['load_in_8bit'] = True
+ elif load_4bit:
+ kwargs['load_in_4bit'] = True
+ kwargs['quantization_config'] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4'
+ )
+ else:
+ kwargs['torch_dtype'] = torch.float16
+ if 'mplug_owl2' in model_name.lower():
+ # Load LLaVA model
+ if 'lora' in model_name.lower() and model_base is None:
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
+ if 'lora' in model_name.lower() and model_base is not None:
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ print('Loading mPLUG-Owl2 from base model...')
+ model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+ token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
+ if model.lm_head.weight.shape[0] != token_num:
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
+
+ print('Loading additional mPLUG-Owl2 weights...')
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
+ else:
+ # this is probably from HF Hub
+ from huggingface_hub import hf_hub_download
+ def load_from_hf(repo_id, filename, subfolder=None):
+ cache_file = hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ subfolder=subfolder)
+ return torch.load(cache_file, map_location='cpu')
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ from peft import PeftModel
+ print('Loading LoRA weights...')
+ model = PeftModel.from_pretrained(model, model_path)
+ print('Merging LoRA weights...')
+ model = model.merge_and_unload()
+ print('Model is loaded...')
+ elif model_base is not None:
+ # this may be mm projector only
+ print('Loading mPLUG-Owl2 from base model...')
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = MPLUGOwl2LlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+ else:
+ # Load language model
+ if model_base is not None:
+ # PEFT model
+ from peft import PeftModel
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
+ print(f"Loading LoRA weights from {model_path}")
+ model = PeftModel.from_pretrained(model, model_path)
+ print(f"Merging weights")
+ model = model.merge_and_unload()
+ print('Convert to FP16...')
+ model.to(torch.float16)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+
+ vision_tower = model.get_model().vision_model
+ vision_tower.to(device=device, dtype=torch.float16)
+ image_processor = CLIPImageProcessor.from_pretrained(model_path)
+
+ if hasattr(model.config, "max_sequence_length"):
+ context_len = model.config.max_sequence_length
+ else:
+ context_len = 2048
+
+ return tokenizer, model, image_processor, context_len
\ No newline at end of file
diff --git a/q_align/model/configuration_mplug_owl2.py b/q_align/model/configuration_mplug_owl2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2e31a61cc919a694ef62929c705d91143be63d1
--- /dev/null
+++ b/q_align/model/configuration_mplug_owl2.py
@@ -0,0 +1,332 @@
+# Copyright (c) Alibaba.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import copy
+import os
+from typing import Union
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from transformers.utils import logging
+from transformers.models.auto import CONFIG_MAPPING
+
+# logger is used by the from_pretrained warnings in the config classes below
+logger = logging.get_logger(__name__)
+
+class LlamaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the LLaMA-7B.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`LlamaModel`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details check out [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
+ Llama 2 up to 4096, CodeLlama up to 16384.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ pad_token_id (`int`, *optional*):
+ Padding token id.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ Beginning of stream token id.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ End of stream token id.
+ pretraining_tp (`int`, *optional*, defaults to 1):
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+ issue](https://github.com/pytorch/pytorch/issues/76232).
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`Dict`, *optional*):
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+ these scaling strategies behave:
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+ experimental feature, subject to breaking API changes in future versions.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
+
+
+ ```python
+ >>> from transformers import LlamaModel, LlamaConfig
+
+ >>> # Initializing a LLaMA llama-7b style configuration
+ >>> configuration = LlamaConfig()
+
+ >>> # Initializing a model from the llama-7b style configuration
+ >>> model = LlamaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+ model_type = "llama"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act="silu",
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=None,
+ bos_token_id=1,
+ eos_token_id=2,
+ pretraining_tp=1,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ attention_bias=False,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ # for backward compatibility
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.pretraining_tp = pretraining_tp
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+ self.attention_bias = attention_bias
+
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
+ f"got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+
+
+class MplugOwlVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate
+ an mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a configuration similar to that of the mPLUG-Owl
+ [x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 1024):
+ Dimensionality of the encoder layers and the pooler layer.
+ intermediate_size (`int`, *optional*, defaults to 4096):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ num_hidden_layers (`int`, *optional*, defaults to 24):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ image_size (`int`, *optional*, defaults to 448):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the layer normalization layers.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 1):
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+ testing).
+
+
+ ```"""
+
+ model_type = "mplug_owl_vision_model"
+
+ def __init__(
+ self,
+ hidden_size=1024,
+ intermediate_size=4096,
+ projection_dim=768,
+ num_hidden_layers=24,
+ num_attention_heads=16,
+ num_channels=3,
+ image_size=448,
+ patch_size=14,
+ hidden_act="quick_gelu",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ use_flash_attn=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.projection_dim = projection_dim
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.use_flash_attn = use_flash_attn
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the vision config dict if we are loading from MplugOwlConfig
+ if config_dict.get("model_type") == "mplug-owl":
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+class MplugOwlVisualAbstractorConfig(PretrainedConfig):
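+ """
+ Configuration for the mPLUG-Owl visual abstractor: the cross-attention resampler that compresses the vision
+ encoder's output into `num_learnable_queries` learnable query tokens before they are passed to the language model.
+ """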
+ model_type = "mplug_owl_visual_abstract"
+
+ def __init__(
+ self,
+ num_learnable_queries=64,
+ hidden_size=1024,
+ num_hidden_layers=6,
+ num_attention_heads=16,
+ intermediate_size=2816,
+ attention_probs_dropout_prob=0.,
+ initializer_range=0.02,
+ layer_norm_eps=1e-6,
+ encoder_hidden_size=1024,
+ grid_size=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.num_learnable_queries = num_learnable_queries
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.encoder_hidden_size = encoder_hidden_size
+ self.grid_size = grid_size if grid_size else 32
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the visual_abstractor config dict if we are loading from MplugOwlConfig
+ if config_dict.get("model_type") == "mplug-owl":
+ config_dict = config_dict["abstractor_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+
+
+DEFAULT_VISUAL_CONFIG = {
+ "visual_model": MplugOwlVisionConfig().to_dict(),
+ "visual_abstractor": MplugOwlVisualAbstractorConfig().to_dict()
+}
+
+class MPLUGOwl2Config(LlamaConfig):
+ model_type = "mplug_owl2"
+ def __init__(self, visual_config=None, **kwargs):
+ if visual_config is None:
+ self.visual_config = DEFAULT_VISUAL_CONFIG
+ else:
+ self.visual_config = visual_config
+
+ super().__init__(
+ **kwargs,
+ )
+
+if __name__ == "__main__":
+ print(MplugOwlVisionConfig().to_dict())
\ No newline at end of file
diff --git a/q_align/model/convert_mplug_owl2_weight_to_hf.py b/q_align/model/convert_mplug_owl2_weight_to_hf.py
new file mode 100644
index 0000000000000000000000000000000000000000..8288a9a6a9b5d7a1a4ec58af6a53b14d4b266580
--- /dev/null
+++ b/q_align/model/convert_mplug_owl2_weight_to_hf.py
@@ -0,0 +1,395 @@
+# Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import gc
+import json
+import math
+import os
+import shutil
+import warnings
+
+import torch
+
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig
+from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
+
+try:
+ from transformers import LlamaTokenizerFast
+except ImportError as e:
+ warnings.warn(e)
+ warnings.warn(
+ "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
+ )
+ LlamaTokenizerFast = None
+
+"""
+Sample usage:
+
+```
+python3 -m q_align.model.convert_mplug_owl2_weight_to_hf \
+ --input_dir /path/to/mplug_owl2_megatron_checkpoint --model_size 7 --output_dir /path/to/converted_hf_model
+```
+
+Thereafter, models can be loaded via:
+
+```py
+from transformers import LlamaTokenizer
+
+from q_align.model.modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
+
+model = MPLUGOwl2LlamaForCausalLM.from_pretrained("/output/path")
+tokenizer = LlamaTokenizer.from_pretrained("/output/path")
+```
+
+Important note: you need to be able to host the whole model in RAM to execute this script (even though the biggest
+versions come in several checkpoints, each checkpoint contains a part of every weight of the model, so all of them
+have to be loaded in RAM).
+"""
+
+llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
+llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
+llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
+ 70: 28672} # should be (2/3)*4*d, but it isn't exactly that
+llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}
+
+
+def compute_intermediate_size(n):
+ return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
+
+
+def read_json(path):
+ with open(path, "r") as f:
+ return json.load(f)
+
+
+def write_json(text, path):
+ with open(path, "w") as f:
+ json.dump(text, f)
+
+
+def write_model(model_path,
+ input_base_path,
+ model_size,
+ num_input_shards=1,
+ num_output_shards=2,
+ skip_permute=True,
+ norm_eps=1e-05):
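+ """Convert an mPLUG-Owl2 Megatron-style checkpoint under `input_base_path` into HuggingFace-format
+ `pytorch_model-*.bin` shards (one per decoder layer, plus one for embeddings, final norm, LM head, vision
+ encoder and visual abstractor weights), along with the weight-map index and an `MPLUGOwl2Config`, all
+ written to `model_path`."""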
+ # if os.path.exists(model_path):
+ # shutil.rmtree(model_path)
+ os.makedirs(model_path, exist_ok=True)
+ # tmp_model_path = os.path.join(model_path, "tmp")
+ tmp_model_path = model_path
+ os.makedirs(tmp_model_path, exist_ok=True)
+
+ num_shards = num_input_shards
+ n_layers = llama_s2layer[model_size]
+ n_heads = llama_s2heads[model_size]
+ n_heads_per_shard = n_heads // num_shards
+ n_dense = llama_s2dense[model_size]
+ n_hidden = llama_s2hidden[model_size]
+ hidden_per_head = n_hidden // n_heads
+ base = 10000.0
+ inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
+
+ # permute for sliced rotary
+ def permute(w, skip_permute=skip_permute):
+ if skip_permute:
+ return w
+ return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
+
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
+ # Load weights
+ if num_shards==1:
+ # Not sharded
+ # (The sharded implementation would also work, but this is simpler.)
+ # /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
+ if os.path.exists(os.path.join(input_base_path, 'release')):
+ filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
+ elif input_base_path.split('/')[-1].startswith('iter_'):
+ iteration = int(input_base_path.split('/')[-1].replace('iter_', ''))
+ load_dir = '/'.join(input_base_path.split('/')[:-1])
+ filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
+ if not os.path.exists(filename):
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
+ else:
+ tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
+ with open(tracker_filename, 'r') as f:
+ metastring = f.read().strip()
+ iteration = 'iter_{:07d}'.format(int(metastring))
+ filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
+ if not os.path.exists(filename):
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
+ original_filename = filename
+ loaded = torch.load(filename, map_location="cpu")['model']['language_model']
+
+ else:
+ # Sharded
+ filenames = []
+ for i in range(num_shards):
+ if os.path.exists(os.path.join(input_base_path, 'release')):
+ filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
+ else:
+ tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
+ with open(tracker_filename, 'r') as f:
+ metastring = f.read().strip()
+ iteration = 'iter_{:07d}'.format(int(metastring))
+ filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
+ if not os.path.exists(filename):
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
+ filenames.append(filename)
+ loaded = [
+ torch.load(filenames[i], map_location="cpu")['model']['language_model']
+ for i in range(num_shards)
+ ]
+
+ print('Llama-Megatron Loaded!')
+ param_count = 0
+ index_dict = {"weight_map": {}}
+
+ print(f'Converting weights for {n_layers} layers...')
+ for layer_i in range(n_layers):
+ print(layer_i)
+ filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
+ if num_shards == 1:
+ # Unsharded
+ state_dict = {
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
+ f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
+ f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
+ f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
+ f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
+ f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
+ f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
+ f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
+ f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
+ f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
+ }
+ else:
+ raise NotImplementedError("Converting sharded (num_input_shards > 1) checkpoints is not implemented.")
+# else:
+# # Sharded
+# # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
+# # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
+# # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
+
+# state_dict = {
+# f"model.layers.{layer_i}.input_layernorm.weight": loaded[0]['encoder'][
+# f"layers.{layer_i}.input_layernorm.multiway.0.weight"
+# ].clone(),
+# f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0]['encoder'][
+# f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"
+# ].clone(),
+# }
+
+# wqs, wks, wvs, ffn_w1s, ffn_w3s = [], [], [], [], []
+# for shard_idx in range(num_shards):
+# wqs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"])
+# wks.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"])
+# wvs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"])
+
+# state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
+# torch.cat(
+# [
+# wq.view(n_heads_per_shard, hidden_per_head, n_hidden)
+# for wq in range(wqs)
+# ],
+# dim=0,
+# ).reshape(n_hidden, n_hidden)
+# )
+# state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
+# torch.cat(
+# [
+# wk.view(n_heads_per_shard, hidden_per_head, n_hidden)
+# for wk in range(wks)
+# ],
+# dim=0,
+# ).reshape(n_hidden, n_hidden)
+# )
+# state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
+# [
+# wv.view(n_heads_per_shard, hidden_per_head, n_hidden)
+# for wv in range(wvs)
+# ],
+# dim=0,
+# ).reshape(n_hidden, n_hidden)
+
+# state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
+# [loaded[i]['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"] for i in range(num_shards)], dim=1
+# )
+# state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
+# [loaded[i]['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"] for i in range(num_shards)], dim=0
+# )
+# state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
+# [loaded[i]['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"] for i in range(num_shards)], dim=1
+# )
+# state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
+# [loaded[i]['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"] for i in range(num_shards)], dim=0
+# )
+
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
+ for k, v in state_dict.items():
+ index_dict["weight_map"][k] = filename
+ param_count += v.numel()
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
+ print(f'Sharded file saved to {filename}')
+
+ filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
+ if num_shards==1:
+ # Unsharded
+ state_dict = {
+ "model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
+ "model.norm.weight": loaded['encoder']['norm.weight'],
+ "lm_head.weight": loaded['encoder']['lm_head.weight'],
+ }
+ else:
+ state_dict = {
+ "model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
+ "model.norm.weight": loaded[0]['encoder']['norm.weight'],
+ "lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
+ }
+
+
+ loaded_all = torch.load(original_filename, map_location="cpu")['model']
+ # Vision Part
+ state_dict.update({
+ "model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
+ "model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
+ "model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
+ "model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
+ "model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
+ "model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
+ "model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
+ })
+ for v_layer_idx in range(24):
+ state_dict.update({
+ f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
+ })
+
+ # Abstractor Part
+ state_dict.update({
+ "model.visual_abstractor.query_embeds": loaded_all['vision_abstractor']['learnable_queries'],
+ "model.visual_abstractor.visual_fc.bias": loaded_all['vision_abstractor']['visual_fc']['bias'],
+ "model.visual_abstractor.visual_fc.weight": loaded_all['vision_abstractor']['visual_fc']['weight'],
+ "model.visual_abstractor.vit_eos": loaded_all['vision_abstractor']['vit_eos'],
+ })
+ for v_layer_idx in range(6):
+ state_dict.update({
+ # f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.k_pos_embed":
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.weight"],
+ # f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.q_pos_embed": "pytorch_model-00004-of-00004.bin",
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.weight"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.weight"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.weight"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.weight"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.weight"],
+
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.weight"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.weight"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.weight"],
+
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.weight"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.bias"],
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.weight"],
+ })
+
+ for k, v in state_dict.items():
+ index_dict["weight_map"][k] = filename
+ param_count += v.numel()
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
+
+ # Write configs
+ index_dict["metadata"] = {"total_size": param_count * 2}
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
+
+ config = MPLUGOwl2Config()
+ config.save_pretrained(tmp_model_path)
+
+ # Make space so we can load the model properly now.
+ del state_dict
+ del loaded
+ del loaded_all
+ gc.collect()
+
+def write_tokenizer(tokenizer_path, input_tokenizer_path):
+ # Initialize the tokenizer based on the `spm` model
+ tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
+ print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
+ tokenizer = tokenizer_class(input_tokenizer_path)
+ tokenizer.save_pretrained(tokenizer_path)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input_dir",
+ help="Location of LLaMA_Megatron weights",
+ )
+ parser.add_argument(
+ "--model_size",
+ type=int,
+ default=7,
+ choices=[7, 13, 30, 65, 70],
+ )
+ parser.add_argument(
+ "--num_input_shards",
+ type=int,
+ default=1,
+ )
+ parser.add_argument(
+ "--num_output_shards",
+ type=int,
+ default=1,
+ )
+ parser.add_argument('--skip_permute', action='store_true')
+
+ parser.add_argument(
+ "--output_dir",
+ help="Location to write HF model and tokenizer",
+ )
+
+ args = parser.parse_args()
+ write_model(
+ model_path=args.output_dir,
+ input_base_path=args.input_dir,
+ model_size=args.model_size,
+ num_input_shards=args.num_input_shards,
+ num_output_shards=args.num_output_shards,
+ skip_permute=args.skip_permute
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/q_align/model/modeling_attn_mask_utils.py b/q_align/model/modeling_attn_mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2583a2dd5a09b1119c849ca00f954198d078799
--- /dev/null
+++ b/q_align/model/modeling_attn_mask_utils.py
@@ -0,0 +1,247 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+
+class AttentionMaskConverter:
+ """
+ A utility attention mask class that allows one to:
+ - Create a causal 4d mask
+ - Create a causal 4d mask with sliding window
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
+ key_value_length) that can be multiplied with attention scores
+
+ Parameters:
+ is_causal (`bool`):
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
+
+ sliding_window (`int`, *optional*):
+ Optionally, sliding window masks can be created if `sliding_window` is set to a positive integer.
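+
+ Example (a minimal sketch of the 2D-to-4D conversion this class is used for in this repo):
+
+ ```python
+ >>> import torch
+
+ >>> converter = AttentionMaskConverter(is_causal=True)
+ >>> mask_2d = torch.ones(1, 5, dtype=torch.long)
+ >>> converter.to_4d(mask_2d, query_length=5, key_value_length=5, dtype=torch.float32).shape
+ torch.Size([1, 1, 5, 5])
+ ```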
+ """
+
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
+ self.is_causal = is_causal
+ self.sliding_window = sliding_window
+
+ if self.sliding_window is not None and self.sliding_window <= 0:
+ raise ValueError(
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
+ )
+
+ def to_causal_4d(
+ self,
+ batch_size: int,
+ query_length: int,
+ key_value_length: int,
+ dtype: torch.dtype = torch.float32,
+ device: Union[torch.device, "str"] = "cpu",
+ ) -> torch.Tensor:
+ """
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
+ bias to upper right hand triangular matrix (causal mask).
+ """
+ if not self.is_causal:
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
+
+ # If shape is not cached, create a new causal mask and cache it
+ input_shape = (batch_size, query_length)
+ past_key_values_length = key_value_length - query_length
+
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ causal_4d_mask = None
+ if input_shape[-1] > 1 or self.sliding_window is not None:
+ causal_4d_mask = self._make_causal_mask(
+ input_shape,
+ dtype,
+ device=device,
+ past_key_values_length=past_key_values_length,
+ sliding_window=self.sliding_window,
+ )
+
+ return causal_4d_mask
+
+ def to_4d(
+ self,
+ attention_mask_2d: torch.Tensor,
+ query_length: int,
+ key_value_length: Optional[int] = None,
+ dtype: torch.dtype = torch.float32,
+ ) -> torch.Tensor:
+ """
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
+ causal, a causal mask will be added.
+ """
+ input_shape = (attention_mask_2d.shape[0], query_length)
+
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ causal_4d_mask = None
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
+ if key_value_length is None:
+ raise ValueError(
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+ )
+
+ past_key_values_length = key_value_length - query_length
+ causal_4d_mask = self._make_causal_mask(
+ input_shape,
+ dtype,
+ device=attention_mask_2d.device,
+ past_key_values_length=past_key_values_length,
+ sliding_window=self.sliding_window,
+ )
+ elif self.sliding_window is not None:
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
+
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
+ attention_mask_2d.device
+ )
+ expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
+
+ return expanded_4d_mask
+
+ @staticmethod
+ def _make_causal_mask(
+ input_ids_shape: torch.Size,
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+ sliding_window: Optional[int] = None,
+ ):
+ """
+ Make a causal mask used for uni-directional (decoder) self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+
+ # add lower triangular sliding window mask if necessary
+ if sliding_window is not None:
+ diagonal = past_key_values_length - sliding_window + 1
+
+ context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
+ mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
+
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+ @staticmethod
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+def _prepare_4d_causal_attention_mask(
+ attention_mask: Optional[torch.Tensor],
+ input_shape: Union[torch.Size, Tuple, List],
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ sliding_window: Optional[int] = None,
+):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ attention_mask (`torch.Tensor` or `None`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
+ inputs_embeds (`torch.Tensor`):
+ The embedded inputs as a torch Tensor.
+ past_key_values_length (`int`):
+ The length of the key value cache.
+ sliding_window (`int`, *optional*):
+ If the model uses windowed attention, a sliding window should be passed.
+ """
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+ key_value_length = input_shape[-1] + past_key_values_length
+
+ # 4d mask is passed through the layers
+ if attention_mask is not None:
+ attention_mask = attn_mask_converter.to_4d(
+ attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
+ )
+ else:
+ attention_mask = attn_mask_converter.to_causal_4d(
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+
+ return attention_mask
+
+
+def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`
+
+ Args:
+ mask (`torch.Tensor` or `None`):
+ A 2D attention mask of shape `(batch_size, key_value_length)`
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ tgt_len (`int`):
+ The target length or query length the created mask shall have.
+ """
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _create_4d_causal_attention_mask(
+ input_shape: Union[torch.Size, Tuple, List],
+ dtype: torch.dtype,
+ device: torch.device,
+ past_key_values_length: int = 0,
+ sliding_window: Optional[int] = None,
+):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
+
+ Args:
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
+ dtype (`torch.dtype`):
+ The torch dtype the created mask shall have.
+ device (`torch.device`):
+ The torch device the created mask shall have.
+ sliding_window (`int`, *optional*):
+ If the model uses windowed attention, a sliding window should be passed.
+ """
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+ key_value_length = past_key_values_length + input_shape[-1]
+ attention_mask = attn_mask_converter.to_causal_4d(
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
+ )
+
+ return attention_mask
\ No newline at end of file
diff --git a/q_align/model/modeling_llama2.py b/q_align/model/modeling_llama2.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dbc5738e4b456bfee4035c3cc1bba7576118e91
--- /dev/null
+++ b/q_align/model/modeling_llama2.py
@@ -0,0 +1,486 @@
+import math
+import warnings
+from functools import partial
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+
+import transformers
+from transformers.models.llama.modeling_llama import *
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from .configuration_mplug_owl2 import LlamaConfig
+
+class MultiwayNetwork(nn.Module):
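+ """Routes each token through one of `num_multiway` independently parameterized instances of the wrapped module.
+
+ In mPLUG-Owl2 this implements modality-adaptive modules: `multiway_indices` marks every position as
+ text (0) or visual (1), and each position is processed only by the instance matching its modality.
+ """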
+
+ def __init__(self, module_provider, num_multiway=2):
+ super(MultiwayNetwork, self).__init__()
+
+ self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
+
+ def forward(self, hidden_states, multiway_indices):
+
+ if len(self.multiway) == 1:
+ return self.multiway[0](hidden_states)
+
+ output_hidden_states = torch.empty_like(hidden_states)
+
+ for idx, subway in enumerate(self.multiway):
+ local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
+ hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
+ if hidden.numel():
+ output = subway(hidden)
+ if isinstance(output, tuple):
+ output = output[0]
+ output = output.squeeze(1)
+ output_hidden_states[local_indices] = output
+
+ return output_hidden_states.contiguous()
+
+
+class LlamaAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+ self.k_proj = MultiwayNetwork(module_provider=partial(
+ nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ )
+ self.v_proj = MultiwayNetwork(module_provider=partial(
+ nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+ )
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = LlamaRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling["type"]
+ scaling_factor = self.config.rope_scaling["factor"]
+ if scaling_type == "linear":
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ elif scaling_type == "dynamic":
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ scaling_factor=scaling_factor,
+ base=self.rope_theta,
+ )
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ modality_indicators: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
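+ # The query projection is shared across modalities; key/value projections are multiway modules
+ # routed per token by `modality_indicators` (text vs. visual) in the modality-adaptive design.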
+ query_states = self.q_proj(hidden_states, )
+ key_states = self.k_proj(hidden_states, modality_indicators)
+ value_states = self.v_proj(hidden_states, modality_indicators)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+
+class LlamaDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = LlamaAttention(config=config)
+ self.mlp = LlamaMLP(config)
+ self.input_layernorm = MultiwayNetwork(module_provider=partial(
+ LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
+ ))
+ self.post_attention_layernorm = MultiwayNetwork(module_provider=partial(
+ LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
+ ))
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ modality_indicators: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states, modality_indicators)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ modality_indicators=modality_indicators,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+def model_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ modality_indicators: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ modality_indicators,
+ attention_mask,
+ position_ids,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ modality_indicators=modality_indicators,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+def causal_model_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ modality_indicators: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ modality_indicators=modality_indicators,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ if self.config.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+ logits = torch.cat(logits, dim=-1)
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+def replace_llama_modality_adaptive():
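+ """Monkey-patch the HuggingFace LLaMA classes with the modality-adaptive variants defined in this module,
+ so that `modality_indicators` is threaded through `LlamaModel.forward` and `LlamaForCausalLM.forward`."""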
+ transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
+ transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
+ transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
+ transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
+
+
+if __name__ == "__main__":
+ replace_llama_modality_adaptive()
+ config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
+ model = transformers.LlamaForCausalLM(config)
+ print(model)
\ No newline at end of file
diff --git a/q_align/model/modeling_mplug_owl2.py b/q_align/model/modeling_mplug_owl2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7154a67c056cfa4cdd5033b5ce46c5b62bf3cbb
--- /dev/null
+++ b/q_align/model/modeling_mplug_owl2.py
@@ -0,0 +1,329 @@
+# Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig
+from .visual_encoder import MplugOwlVisionModel, MplugOwlVisualAbstractorModel
+from .modeling_llama2 import replace_llama_modality_adaptive
+from q_align.constants import IMAGE_TOKEN_INDEX, IGNORE_INDEX
+from icecream import ic
+
+class MPLUGOwl2MetaModel:
+ def __init__(self, config):
+ super(MPLUGOwl2MetaModel, self).__init__(config)
+ self.vision_model = MplugOwlVisionModel(
+ MplugOwlVisionConfig(**config.visual_config["visual_model"])
+ )
+ self.visual_abstractor = MplugOwlVisualAbstractorModel(
+ MplugOwlVisualAbstractorConfig(**config.visual_config["visual_abstractor"]), config.hidden_size
+ )
+
+ def get_vision_tower(self):
+ vision_model = getattr(self, 'vision_model', None)
+ if type(vision_model) is list:
+ vision_model = vision_model[0]
+ return vision_model
+
+ def get_visual_abstractor(self):
+ visual_abstractor = getattr(self, 'visual_abstractor', None)
+ if type(visual_abstractor) is list:
+ visual_abstractor = visual_abstractor[0]
+ return visual_abstractor
+
+
+class MPLUGOwl2MetaForCausalLM(ABC):
+ @abstractmethod
+ def get_model(self):
+ pass
+
+ def encode_images(self, images):
+ image_features = self.get_model().vision_model(images).last_hidden_state
+ image_features = self.get_model().visual_abstractor(encoder_hidden_states=image_features).last_hidden_state
+ return image_features
+
+ def prepare_inputs_labels_for_multimodal(
+ self, input_ids, attention_mask, past_key_values, labels, images
+ ):
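+ # Splice image features into the token-embedding sequence at every IMAGE_TOKEN_INDEX placeholder,
+ # build matching modality indicators (0 = text, 1 = visual), and pad embeddings, labels and
+ # attention masks so all samples in the batch share the same (extended) sequence length.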
+ if images is None or input_ids.shape[1] == 1:
+ if past_key_values is not None and images is not None and input_ids.shape[1] == 1:
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
+ multiway_indices = torch.zeros_like(input_ids).long().to(self.device)
+ return input_ids, multiway_indices, attention_mask, past_key_values, None, labels
+
+ if type(images) is list or images.ndim == 5:
+ concat_images = torch.cat([image for image in images], dim=0)
+ image_features = self.encode_images(concat_images)
+ split_sizes = [image.shape[0] for image in images]
+ image_features = torch.split(image_features, split_sizes, dim=0)
+ image_features = [x.flatten(0, 1) for x in image_features]
+ else:
+ image_features = self.encode_images(images)
+
+ new_input_embeds = []
+ new_modality_indicators = []
+ new_labels = [] if labels is not None else None
+ cur_image_idx = 0
+ for batch_idx, cur_input_ids in enumerate(input_ids):
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
+ # multimodal LLM, but the current sample is not multimodal
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
+ half_len = cur_input_ids.shape[0] // 2
+ cur_image_features = image_features[cur_image_idx]
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
+ cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
+ new_input_embeds.append(cur_input_embeds)
+
+ cur_modality_indicators = torch.zeros(len(cur_input_embeds)).long().to(self.device)
+ new_modality_indicators.append(cur_modality_indicators)
+ if labels is not None:
+ new_labels.append(labels[batch_idx])
+ cur_image_idx += 1
+ continue
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
+ cur_new_input_embeds = []
+ cur_modality_indicators = []
+ if labels is not None:
+ cur_labels = labels[batch_idx]
+ cur_new_labels = []
+ assert cur_labels.shape == cur_input_ids.shape
+ while image_token_indices.numel() > 0:
+ cur_image_features = image_features[cur_image_idx]
+ image_token_start = image_token_indices[0]
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
+ cur_new_input_embeds.append(cur_image_features)
+
+ # Add modality indicator
+ assert image_token_start == len(cur_input_ids[:image_token_start])
+ cur_modality_indicators.append(torch.zeros(len(cur_input_ids[:image_token_start])).long())
+ cur_modality_indicators.append(torch.ones(len(cur_image_features)).long())
+
+ if labels is not None:
+ cur_new_labels.append(cur_labels[:image_token_start])
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
+ cur_labels = cur_labels[image_token_start+1:]
+ cur_image_idx += 1
+ cur_input_ids = cur_input_ids[image_token_start+1:]
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
+ if cur_input_ids.numel() > 0:
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
+ cur_modality_indicators.append(torch.zeros(len(cur_input_ids)).long())
+ if labels is not None:
+ cur_new_labels.append(cur_labels)
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
+ new_input_embeds.append(cur_new_input_embeds)
+
+ # Modality
+ cur_modality_indicators = [x.to(device=self.device) for x in cur_modality_indicators]
+ cur_modality_indicators = torch.cat(cur_modality_indicators, dim=0)
+ new_modality_indicators.append(cur_modality_indicators)
+
+
+ if labels is not None:
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
+ new_labels.append(cur_new_labels)
+
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
+ max_len = max(x.shape[0] for x in new_input_embeds)
+
+ # Embedding
+ new_input_embeds_align = []
+ for cur_new_embed in new_input_embeds:
+ cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
+ new_input_embeds_align.append(cur_new_embed)
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
+
+ # Modality
+ new_modality_indicators_align = []
+ for cur_modality_indicator in new_modality_indicators:
+ cur_new_embed = torch.cat((cur_modality_indicator, torch.zeros(max_len - cur_modality_indicator.shape[0], dtype=cur_modality_indicator.dtype, device=cur_modality_indicator.device)), dim=0)
+ new_modality_indicators_align.append(cur_new_embed)
+ new_modality_indicators = torch.stack(new_modality_indicators_align, dim=0)
+
+ # Label
+ if labels is not None:
+ new_labels_align = []
+ _new_labels = new_labels
+ for cur_new_label in new_labels:
+ cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
+ new_labels_align.append(cur_new_label)
+ new_labels = torch.stack(new_labels_align, dim=0)
+
+ # Attention Mask
+ if attention_mask is not None:
+ new_attention_mask = []
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
+ cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
+ new_attention_mask.append(cur_new_attention_mask)
+ attention_mask = torch.stack(new_attention_mask, dim=0)
+ assert attention_mask.shape == new_labels.shape
+ else:
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
+ new_modality_indicators = torch.stack(new_modality_indicators, dim=0)
+ if labels is not None:
+ new_labels = torch.stack(new_labels, dim=0)
+
+ if attention_mask is not None:
+ new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
+ assert attention_mask.shape == new_input_embeds.shape[:2]
+ return None, new_modality_indicators, attention_mask, past_key_values, new_input_embeds, new_labels
+
+
+
+class MPLUGOwl2LlamaModel(MPLUGOwl2MetaModel, LlamaModel):
+ config_class = MPLUGOwl2Config
+
+ def __init__(self, config: MPLUGOwl2Config):
+ super(MPLUGOwl2LlamaModel, self).__init__(config)
+
+
+class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
+ config_class = MPLUGOwl2Config
+
+ def __init__(self, config):
+ super(LlamaForCausalLM, self).__init__(config)
+ self.model = MPLUGOwl2LlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ # modality_indicators: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ input_ids, modality_indicators, attention_mask, past_key_values, inputs_embeds, labels = \
+ self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ modality_indicators=modality_indicators,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model/pipeline parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "images": kwargs.get("images", None),
+ }
+ )
+ return model_inputs
+
+AutoConfig.register("mplug_owl2", MPLUGOwl2Config)
+AutoModelForCausalLM.register(MPLUGOwl2Config, MPLUGOwl2LlamaForCausalLM)
+
+replace_llama_modality_adaptive()
+
+if __name__ == "__main__":
+ config = MPLUGOwl2Config.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
+ from icecream import ic
+ # config = MPLUGOwl2Config()
+ model = MPLUGOwl2LlamaForCausalLM(config)
+
+ images = torch.randn(2, 3, 448, 448)
+ input_ids = torch.cat([
+ torch.ones(8).long(), torch.tensor([-1]*1).long(), torch.ones(8).long(), torch.tensor([-1]*1).long(), torch.ones(8).long()
+ ], dim=0).unsqueeze(0)
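+    # -1 serves as the image-token placeholder in this smoke test; it must match IMAGE_TOKEN_INDEX
+    # from q_align.constants for prepare_inputs_labels_for_multimodal to splice in the two images.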
+ labels = input_ids.clone()
+ labels[labels < 0] = -100
+
+ # image_feature = model.encode_images(images)
+ # ic(image_feature.shape)
+
+ output = model(images=images, input_ids=input_ids, labels=labels)
+ ic(output.loss)
+ ic(output.logits.shape)
+
+ model.save_pretrained('/cpfs01/shared/public/test/tmp_owl')
\ No newline at end of file
diff --git a/q_align/model/utils.py b/q_align/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a47edb191f0f6587827cfab6a731ec106a523b0
--- /dev/null
+++ b/q_align/model/utils.py
@@ -0,0 +1,20 @@
+from transformers import AutoConfig
+
+
+def auto_upgrade(config):
+ cfg = AutoConfig.from_pretrained(config)
+ if 'mplug_owl2' in config and 'mplug_owl2' not in cfg.model_type:
+        # The config predates the mplug_owl2 model type, so it should still report the base LLaMA type.
+        assert cfg.model_type == 'llama'
+        print("You are using the newer LLaVA code base, while the checkpoint of v0 is from the older code base.")
+        print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
+ if confirm.lower() in ["y", "yes"]:
+ print("Upgrading checkpoint...")
+ assert len(cfg.architectures) == 1
+ setattr(cfg.__class__, "model_type", "mplug_owl2")
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
+ cfg.save_pretrained(config)
+ print("Checkpoint upgraded.")
+ else:
+ print("Checkpoint upgrade aborted.")
+ exit(1)
\ No newline at end of file
diff --git a/q_align/model/visual_encoder.py b/q_align/model/visual_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..39f76207dd9a25f870c1cf35f7fbcba2bf342e0c
--- /dev/null
+++ b/q_align/model/visual_encoder.py
@@ -0,0 +1,928 @@
+import math
+from typing import Any, Optional, Tuple, Union
+
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPastAndCrossAttentions
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F  # used by get_abs_pos below
+import torch.utils.checkpoint
+from icecream import ic
+
+def get_abs_pos(abs_pos, tgt_size):
+ # abs_pos: L, C
+ # tgt_size: M
+ # return: M, C
+ src_size = int(math.sqrt(abs_pos.size(0)))
+ tgt_size = int(math.sqrt(tgt_size))
+ dtype = abs_pos.dtype
+
+ if src_size != tgt_size:
+ return F.interpolate(
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
+ size=(tgt_size, tgt_size),
+ mode="bicubic",
+ align_corners=False,
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
+ else:
+ return abs_pos
+
+# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
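+# Shape sketch (illustrative sizes, not taken from a config): get_1d_sincos_pos_embed_from_grid(64, np.arange(10))
+# yields a (10, 64) table, and get_2d_sincos_pos_embed(64, 14, cls_token=True) yields a (1 + 14 * 14, 64)
+# table; the visual abstractor below registers such tables as its q/k position embeddings.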
+
+
+
+class MplugOwlVisionEmbeddings(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
+
+ self.patch_embed = nn.Conv2d(
+ in_channels=3,
+ out_channels=self.hidden_size,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ bias=False,
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size))
+
+ self.pre_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ batch_size = pixel_values.size(0)
+ image_embeds = self.patch_embed(pixel_values)
+ image_embeds = image_embeds.flatten(2).transpose(1, 2)
+
+ class_embeds = self.cls_token.expand(batch_size, 1, -1).to(image_embeds.dtype)
+ embeddings = torch.cat([class_embeds, image_embeds], dim=1)
+ embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype)
+ embeddings = self.pre_layernorm(embeddings)
+ return embeddings
+
+
+
+class MplugOwlVisionAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ if self.head_dim * self.num_heads != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = nn.Dropout(config.attention_dropout)
+
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size)
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, seq_len, embed_dim = hidden_states.size()
+
+ mixed_qkv = self.query_key_value(hidden_states)
+
+ mixed_qkv = mixed_qkv.reshape(bsz, seq_len, self.num_heads, 3, embed_dim // self.num_heads).permute(
+ 3, 0, 2, 1, 4
+ ) # [3, b, np, sq, hn]
+ query_states, key_states, value_states = (
+ mixed_qkv[0],
+ mixed_qkv[1],
+ mixed_qkv[2],
+ )
+ # if self.config.use_flash_attn and flash_attn_func is not None:
+ if False:
+ # [b*sq, np, hn]
+ query_states = query_states.permute(0, 2, 1, 3).contiguous()
+ query_states = query_states.view(query_states.size(0) * query_states.size(1), query_states.size(2), -1)
+
+ key_states = key_states.permute(0, 2, 1, 3).contiguous()
+ key_states = key_states.view(key_states.size(0) * key_states.size(1), key_states.size(2), -1)
+
+ value_states = value_states.permute(0, 2, 1, 3).contiguous()
+ value_states = value_states.view(value_states.size(0) * value_states.size(1), value_states.size(2), -1)
+
+ cu_seqlens = torch.arange(
+ 0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=query_states.device
+ )
+
+ context_layer = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens,
+ cu_seqlens,
+ seq_len,
+ seq_len,
+ self.dropout if self.training else 0.0,
+ softmax_scale=self.scale,
+ causal=False,
+ return_attn_probs=False,
+ )
+ # [b*sq, np, hn] => [b, sq, np, hn]
+ context_layer = context_layer.view(bsz, seq_len, context_layer.size(1), context_layer.size(2))
+ else:
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
+
+ attention_scores = attention_scores * self.scale
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = torch.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ output = self.dense(context_layer)
+
+ outputs = (output, attention_probs) if output_attentions else (output, None)
+
+ return outputs
+
+
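+# QuickGELU is the sigmoid-based GELU approximation used by CLIP: gelu(x) ≈ x * torch.sigmoid(1.702 * x).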
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class MplugOwlMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = QuickGELU()
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class MplugOwlVisionEncoderLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = MplugOwlVisionAttention(config)
+ self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = MplugOwlMLP(config)
+ self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ head_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = hidden_states + residual
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+
+ hidden_states = hidden_states + residual
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class MplugOwlVisionEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`MplugOwlVisionEncoderLayer`].
+
+ Args:
+ config (`MplugOwlVisionConfig`):
+ The corresponding vision configuration for the `MplugOwlEncoder`.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([MplugOwlVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = True
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Embedded representation of the inputs. Should be float, not int tokens.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class MplugOwlVisionModel(PreTrainedModel):
+ main_input_name = "pixel_values"
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+ self.hidden_size = config.hidden_size
+
+ self.embeddings = MplugOwlVisionEmbeddings(config)
+ self.encoder = MplugOwlVisionEncoder(config)
+ self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
+
+ self.post_init()
+
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ hidden_states = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooled_output = last_hidden_state[:, 0, :]
+ pooled_output = self.post_layernorm(pooled_output)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+
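+# A SwiGLU-style gated feed-forward block: the hidden activation is gated as act(w1(x)) * w3(x),
+# layer-normalised by ffn_ln, and projected back to the input width by w2.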
+class MplugOwlVisualAbstractorMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ in_features = config.hidden_size
+ self.act = nn.SiLU()
+
+ self.w1 = nn.Linear(in_features, config.intermediate_size)
+ self.w2 = nn.Linear(config.intermediate_size, in_features)
+ self.w3 = nn.Linear(in_features, config.intermediate_size)
+ self.ffn_ln = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.act(self.w1(hidden_states)) * self.w3(hidden_states)
+ hidden_states = self.ffn_ln(hidden_states)
+ hidden_states = self.w2(hidden_states)
+ return hidden_states
+
+
+class MplugOwlVisualAbstractorMultiHeadAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
+ % (config.hidden_size, config.num_attention_heads)
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.save_attention = False
+
+# self.q_pos_embed = nn.Parameter(
+# torch.from_numpy(get_1d_sincos_pos_embed_from_grid(config.hidden_size, np.arange(config.num_learnable_queries, dtype=np.float32))).float()
+# ).requires_grad_(False)
+# grids = config.grid_size
+# self.k_pos_embed = nn.Parameter(
+# torch.from_numpy(get_2d_sincos_pos_embed(config.hidden_size, grids, cls_token=True)).float()
+# ).requires_grad_(False)
+ grids = config.grid_size
+ self.register_buffer(
+ 'q_pos_embed',
+ torch.from_numpy(get_1d_sincos_pos_embed_from_grid(config.hidden_size, np.arange(config.num_learnable_queries, dtype=np.float32))).float()
+ )
+ self.register_buffer(
+ 'k_pos_embed',
+ torch.from_numpy(get_2d_sincos_pos_embed(config.hidden_size, grids, cls_token=True)).float()
+ )
+
+
+ def save_attn_gradients(self, attn_gradients):
+ self.attn_gradients = attn_gradients
+
+ def get_attn_gradients(self):
+ return self.attn_gradients
+
+ def save_attention_map(self, attention_map):
+ self.attention_map = attention_map
+
+ def get_attention_map(self):
+ return self.attention_map
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_value=None,
+ output_attentions=False,
+ ):
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+
+ qk_pos_embed = torch.cat([self.q_pos_embed, self.k_pos_embed], dim = 0).unsqueeze(0).to(dtype=hidden_states.dtype)
+
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states + qk_pos_embed))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+
+ mixed_query_layer = self.query(hidden_states + self.q_pos_embed.unsqueeze(0).to(dtype=hidden_states.dtype))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ if attention_mask is not None:
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ if self.save_attention:
+ self.save_attention_map(attention_probs)
+ attention_probs.register_hook(self.save_attn_gradients)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs_dropped = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs_dropped = attention_probs_dropped * head_mask
+
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class MplugOwlVisualAbstractorCrossOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ dim = config.hidden_size
+ self.out_proj = nn.Linear(dim, dim, bias=True)
+ self.norm2 = nn.LayerNorm(dim)
+ self.mlp = MplugOwlVisualAbstractorMLP(config)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ input_tensor = input_tensor + self.out_proj(hidden_states)
+ input_tensor = input_tensor + self.mlp(self.norm2(input_tensor))
+ return input_tensor
+
+
+class MplugOwlVisualAbstractorAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = MplugOwlVisualAbstractorMultiHeadAttention(config)
+ self.output = MplugOwlVisualAbstractorCrossOutput(config)
+ self.pruned_heads = set()
+ self.norm1 = nn.LayerNorm(config.hidden_size)
+ self.normk = nn.LayerNorm(config.hidden_size)
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+        self.output.out_proj = prune_linear_layer(self.output.out_proj, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ # HACK we apply norm on q and k
+ hidden_states = self.norm1(hidden_states)
+ encoder_hidden_states = self.normk(encoder_hidden_states)
+ encoder_hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+ encoder_attention_mask = torch.cat([attention_mask, encoder_attention_mask], dim=-1)
+ self_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ # add attentions if we output them
+ outputs = (attention_output,) + self_outputs[1:]
+ return outputs
+
+
+class MplugOwlVisualAbstractorLayer(nn.Module):
+ def __init__(self, config, layer_idx):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+
+ self.layer_idx = layer_idx
+
+ self.crossattention = MplugOwlVisualAbstractorAttention(config)
+ self.has_cross_attention = True
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ output_attentions=False,
+ ):
+ if encoder_hidden_states is None:
+ raise ValueError("encoder_hidden_states must be given for cross-attention layers")
+ cross_attention_outputs = self.crossattention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions=output_attentions,
+ )
+ query_attention_output = cross_attention_outputs[0]
+
+ outputs = (query_attention_output,)
+ return outputs
+
+
+class MplugOwlVisualAbstractorEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList(
+ [MplugOwlVisualAbstractorLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.gradient_checkpointing = True
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+
+ for i in range(self.config.num_hidden_layers):
+ layer_module = self.layers[i]
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ )
+
+
+class MplugOwlVisualAbstractorModel(PreTrainedModel):
+ def __init__(self, config, language_hidden_size):
+ super().__init__(config)
+ self.config = config
+
+ self.encoder = MplugOwlVisualAbstractorEncoder(config)
+ self.visual_fc = torch.nn.Linear(config.hidden_size, language_hidden_size)
+ self.query_embeds = torch.nn.Parameter(torch.randn(1, config.num_learnable_queries, config.hidden_size))
+ self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
+
+ self.post_init()
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ def get_extended_attention_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_shape: Tuple[int],
+ device: torch.device,
+ ) -> torch.Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (`torch.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (`Tuple[int]`):
+ The shape of the input to the model.
+ device: (`torch.device`):
+ The device of the input to the model.
+
+ Returns:
+            `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
+ """
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.dim() == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.dim() == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+ input_shape, attention_mask.shape
+ )
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
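+        # Example: a padding mask [[1, 1, 0]] is broadcast to shape (1, 1, 1, 3) and becomes
+        # [[[[0., 0., -10000.]]]], which is later added to the raw attention scores before softmax.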
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+ return extended_attention_mask
+
+ def forward(
+ self,
+ attention_mask=None,
+ head_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ past_key_values=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
+ shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
+ value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
+ used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
+ value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
+ `(batch_size, sequence_length)`.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ query_embeds = self.query_embeds.repeat(encoder_hidden_states.shape[0], 1, 1)
+ embedding_output = query_embeds
+ input_shape = embedding_output.size()[:-1]
+ batch_size, seq_length = input_shape
+ device = embedding_output.device
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (query_embeds.shape[0], query_embeds.shape[1]), dtype=torch.long, device=query_embeds.device
+ )
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if encoder_hidden_states is not None:
+ if type(encoder_hidden_states) == list:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
+ else:
+ (
+ encoder_batch_size,
+ encoder_sequence_length,
+ _,
+ ) = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+ if type(encoder_attention_mask) == list:
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
+ elif encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = sequence_output[:, 0, :]
+
+ sequence_output = self.visual_fc(sequence_output)
+ sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(sequence_output.shape[0], 1, 1)], dim=1)
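+        # The compressed query tokens are projected to the language model's hidden size and a learned
+        # visual EOS token is appended, so each image contributes (num_learnable_queries + 1) tokens.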
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ )
+
+
+if __name__ == "__main__":
+ from configuration_mplug_owl2 import MPLUGOwl2Config
+ config = MPLUGOwl2Config()
+ visual_model = MplugOwlVisionModel(config.visual_config["visual_model"])
+ print(visual_model)
+
+ abstractor_module = MplugOwlVisualAbstractorModel(config.visual_config["visual_abstractor"], config.hidden_size)
+ print(abstractor_module)
\ No newline at end of file
diff --git a/q_align/train/.ipynb_checkpoints/train-checkpoint.py b/q_align/train/.ipynb_checkpoints/train-checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..d109848b1664625fdf099c50e210e1a93a909002
--- /dev/null
+++ b/q_align/train/.ipynb_checkpoints/train-checkpoint.py
@@ -0,0 +1,844 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+from dataclasses import dataclass, field
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence, List
+
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+import torch
+
+import transformers
+from transformers.models.clip.image_processing_clip import CLIPImageProcessor
+
+from torch.utils.data import Dataset
+from q_align.train.mplug_owl2_trainer import MPLUGOwl2Trainer
+from q_align.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+
+from q_align import conversation as conversation_lib
+from q_align.model import *
+from q_align.mm_utils import tokenizer_image_token
+
+from PIL import Image
+from icecream import ic
+
+local_rank = None
+
+
+def rank0_print(*args):
+ if local_rank == 0:
+ print(*args)
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ version: Optional[str] = field(default="v0")
+ freeze_backbone: bool = field(default=False)
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None,
+ metadata={"help": "Path to the training data."})
+ lazy_preprocess: bool = False
+ is_multimodal: bool = False
+ image_folder: Optional[str] = field(default=None)
+ image_aspect_ratio: str = 'square'
+ image_grid_pinpoints: Optional[str] = field(default=None)
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+
+ tune_visual_abstractor: bool = field(default=True)
+ freeze_vision_model: bool = field(default=True)
+
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help":
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+ double_quant: bool = field(
+ default=True,
+ metadata={"help": "Compress the quantization statistics through double quantization."}
+ )
+ quant_type: str = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+ )
+ bits: int = field(
+ default=16,
+ metadata={"help": "How many bits to use."}
+ )
+ lora_enable: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ visual_abstractor_lr: Optional[float] = None
+ group_by_modality_length: bool = field(default=False)
+
+
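+# Under DeepSpeed ZeRO-3, parameters are sharded across ranks; zero.GatheredParameters temporarily
+# materialises the full tensor so it can be detached and copied to the CPU.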
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+        for k, t in maybe_lora_bias.items():
+            # keep only the bias terms that belong to a LoRA-adapted module
+            if k in lora_bias_names:
+                to_return[k] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ['vision_model', 'visual_abstractor']
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ lora_module_names.add(name)
+
+ if 'lm_head' in lora_module_names: # needed for 16-bit
+ lora_module_names.remove('lm_head')
+ return list(lora_module_names)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+ output_dir: str):
+ """Collects the state dict and dump to disk."""
+
+ if trainer.deepspeed:
+ torch.cuda.synchronize()
+ trainer.save_model(output_dir)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {
+ key: value.cpu()
+ for key, value in state_dict.items()
+ }
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
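+# Usage sketch (hypothetical special token): calling
+#     smart_tokenizer_and_embedding_resize(dict(pad_token="<pad>"), tokenizer, model)
+# grows both embedding matrices and initialises the new rows with the mean of the existing ones.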
+
+
+def _tokenize_fn(strings: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ) for text in strings
+ ]
+ input_ids = labels = [
+ tokenized.input_ids[0] for tokenized in tokenized_list
+ ]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+ sentence["value"] + END_SIGNAL)
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_multimodal(
+ sources: Sequence[str],
+ data_args: DataArguments
+) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ for sentence in source:
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
+ sentence['value'] = sentence['value'].strip()
+
+ replace_token = DEFAULT_IMAGE_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return sources
+
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
+
+ # Mask targets
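+    # Everything up to each "{assistant role}: " separator (system prompt and user turns) is set to
+    # IGNORE_INDEX, so the loss is computed only on the assistant replies.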
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ """
+    Given a list of sources, each of which is a conversation list, this transform will:
+    1. Add the signal '### ' at the beginning of each sentence, with the end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ return preprocess_plain(sources, tokenizer)
+ if conversation_lib.default_conversation.version.startswith("v1"):
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{conversation_lib.default_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+ # tokenize conversations
+ def get_tokenize_len(prompts):
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
+ if has_image:
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ else:
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ if has_image:
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
+ else:
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
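+# Illustrative `sources` entry (keys follow the conversation JSON consumed above; the wording of the
+# values is hypothetical):
+# [[{"from": "human", "value": "<image>\nRate the quality of this image."},
+#   {"from": "gpt", "value": "The quality of the image is good."}]]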
+
+
+def load_video(video_file):
+ from decord import VideoReader
+ vr = VideoReader(video_file)
+
+ # Get video frame rate
+ fps = vr.get_avg_fps()
+
+ # Calculate frame indices for 1fps
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+ frames = vr.get_batch(frame_indices).asnumpy()
+ return [Image.fromarray(frame) for frame in frames]
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments):
+ super(LazySupervisedDataset, self).__init__()
+ list_data_dict = json.load(open(data_path, "r"))
+
+ rank0_print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ img_tokens = 128 if 'image' in sample else 0
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
+ return length_list
+
+
+ @property
+ def modality_lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
+ cur_len = cur_len if 'image' in sample else -cur_len
+ length_list.append(cur_len)
+ return length_list
+
+# def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+# sources = self.list_data_dict[i]
+# if isinstance(i, int):
+# sources = [sources]
+# assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+# if 'image' in sources[0]:
+# image_file = self.list_data_dict[i]['image']
+# image_folder = self.data_args.image_folder
+# processor = self.data_args.image_processor
+# image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+# if self.data_args.image_aspect_ratio == 'pad':
+# def expand2square(pil_img, background_color):
+# width, height = pil_img.size
+# if width == height:
+# return pil_img
+# elif width > height:
+# result = Image.new(pil_img.mode, (width, width), background_color)
+# result.paste(pil_img, (0, (width - height) // 2))
+# return result
+# else:
+# result = Image.new(pil_img.mode, (height, height), background_color)
+# result.paste(pil_img, ((height - width) // 2, 0))
+# return result
+# image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+# image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+# else:
+# image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+# sources = preprocess_multimodal(
+# copy.deepcopy([e["conversations"] for e in sources]),
+# self.data_args)
+# else:
+# sources = copy.deepcopy([e["conversations"] for e in sources])
+# data_dict = preprocess(
+# sources,
+# self.tokenizer,
+# has_image=('image' in self.list_data_dict[i]))
+# if isinstance(i, int):
+# data_dict = dict(input_ids=data_dict["input_ids"][0],
+# labels=data_dict["labels"][0])
+
+# # image exist in the data
+# if 'image' in self.list_data_dict[i]:
+# data_dict['image'] = image
+# elif self.data_args.is_multimodal:
+# # image does not exist in the data, but the model is multimodal
+# crop_size = self.data_args.image_processor.crop_size
+# data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+# return data_dict
+
+ def next_rand(self):
+ import random
+ return random.randint(0,len(self)-1)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ while True:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+ if 'image' in sources[0]:
+ image_file = self.list_data_dict[i]['image']
+
+ image_folder = self.data_args.image_folder
+ processor = self.data_args.image_processor
+ from pathlib import Path
+ #if not Path(os.path.join(image_folder, image_file)).exists():
+ # i = self.next_rand()
+ # continue
+ if isinstance(image_file, list):
+ # Multiple Images as Input
+ try:
+ image = [Image.open(os.path.join(image_folder, imfile)).convert('RGB') for imfile in image_file]
+ except Exception as ex:
+ print(ex)
+ i = self.next_rand()
+ continue
+ if self.data_args.image_aspect_ratio == 'pad':
+ image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+ else:
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+ elif os.path.join(image_folder, image_file).endswith(".mp4"):
+ # Video as Input
+ image = load_video(os.path.join(image_folder, image_file))
+ if self.data_args.image_aspect_ratio == 'pad':
+ image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+ else:
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+ else:
+ try:
+ image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+ except Exception as ex:
+ print(ex)
+ i = self.next_rand()
+ continue
+ if self.data_args.image_aspect_ratio == 'pad':
+ image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ else:
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ sources = preprocess_multimodal(
+ copy.deepcopy([e["conversations"] for e in sources]),
+ self.data_args)
+ else:
+
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+ data_dict = preprocess(
+ sources,
+ self.tokenizer,
+ has_image=('image' in self.list_data_dict[i]))
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+ # image exist in the data
+ if 'image' in self.list_data_dict[i]:
+ data_dict['image'] = image
+ elif self.data_args.is_multimodal:
+ # image does not exist in the data, but the model is multimodal
+ crop_size = self.data_args.image_processor.crop_size
+ data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+ return data_dict
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
+ labels = labels[:, :self.tokenizer.model_max_length]
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'image' in instances[0]:
+ images = [instance['image'] for instance in instances]
+ if all(x is not None and x.shape == images[0].shape for x in images):
+ batch['images'] = torch.stack(images)
+ else:
+ batch['images'] = images
+
+ return batch
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
+ data_args) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+ data_path=data_args.data_path,
+ data_args=data_args)
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ return dict(train_dataset=train_dataset,
+ eval_dataset=None,
+ data_collator=data_collator)
+
+
+def train():
+ global local_rank
+
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ local_rank = training_args.local_rank
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+
+ bnb_model_from_pretrained_args = {}
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+ bnb_model_from_pretrained_args.update(dict(
+ device_map={"": training_args.device},
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
+ )
+ ))
+
+ model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+ model.config.use_cache = False
+
+ if model_args.freeze_backbone:
+ model.model.requires_grad_(False)
+
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if training_args.lora_enable:
+ from peft import LoraConfig, get_peft_model
+ lora_config = LoraConfig(
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ rank0_print("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+
+ tokenizer.pad_token = tokenizer.unk_token
+ if model_args.version in conversation_lib.conv_templates:
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+ else:
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
+ if not training_args.freeze_vision_model and training_args.bits in [4, 8]:
+ model.get_model().vision_model.to(dtype=compute_dtype, device=training_args.device)
+ else:
+ vision_tower = model.get_model().vision_model
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+ if training_args.tune_visual_abstractor and training_args.bits in [4, 8]:
+ model.get_model().visual_abstractor.to(dtype=compute_dtype, device=training_args.device)
+ else:
+ visual_abstractor = model.get_model().visual_abstractor
+ visual_abstractor.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+ data_args.image_processor = CLIPImageProcessor.from_pretrained(model_args.model_name_or_path)
+ data_args.is_multimodal = True
+
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
+ model.config.tune_visual_abstractor = model_args.tune_visual_abstractor = training_args.tune_visual_abstractor
+ ic(training_args.tune_visual_abstractor)
+ model.requires_grad_(True)
+ if training_args.tune_visual_abstractor:
+ # model.requires_grad_(False)
+ for p in model.get_model().visual_abstractor.parameters():
+ p.requires_grad = True
+
+ model.config.freeze_vision_model = training_args.freeze_vision_model
+ ic(training_args.freeze_vision_model)
+ if training_args.freeze_vision_model:
+ for p in model.get_model().vision_model.parameters():
+ p.requires_grad = False
+
+ model.config.visual_abstractor_lr = training_args.visual_abstractor_lr
+
+
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if 'norm' in name:
+ module = module.to(torch.float32)
+ if 'lm_head' in name or 'embed_tokens' in name:
+ if hasattr(module, 'weight'):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
+ data_args=data_args)
+ trainer = MPLUGOwl2Trainer(model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module)
+
+ # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ # trainer.train(resume_from_checkpoint=True)
+ # else:
+ # trainer.train()
+
+ # TODO: auto-resume is intentionally disabled; remove this direct call and uncomment the block above to restore checkpoint resumption
+ trainer.train()
+
+ trainer.save_state()
+
+ model.config.use_cache = True
+
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), training_args.lora_bias
+ )
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
+ model.named_parameters()
+ )
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ model.config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer,
+ output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
\ No newline at end of file
diff --git a/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-310.pyc b/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..526270d58a01315bbdac70f4465e54e94d0335c9
Binary files /dev/null and b/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-310.pyc differ
diff --git a/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-311.pyc b/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d2b77240124a91ac5f378a6b3997814f71b03ae
Binary files /dev/null and b/q_align/train/__pycache__/llama_flash_attn_monkey_patch.cpython-311.pyc differ
diff --git a/q_align/train/__pycache__/mplug_owl2_trainer.cpython-310.pyc b/q_align/train/__pycache__/mplug_owl2_trainer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cb851af86edcd4279750449e470b5ed757f54d38
Binary files /dev/null and b/q_align/train/__pycache__/mplug_owl2_trainer.cpython-310.pyc differ
diff --git a/q_align/train/__pycache__/mplug_owl2_trainer.cpython-311.pyc b/q_align/train/__pycache__/mplug_owl2_trainer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9d7423b8c38926885b9376550027174522827c0
Binary files /dev/null and b/q_align/train/__pycache__/mplug_owl2_trainer.cpython-311.pyc differ
diff --git a/q_align/train/__pycache__/train.cpython-310.pyc b/q_align/train/__pycache__/train.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1cc4f452551034733cf2096054fa152b0f490de0
Binary files /dev/null and b/q_align/train/__pycache__/train.cpython-310.pyc differ
diff --git a/q_align/train/__pycache__/train.cpython-311.pyc b/q_align/train/__pycache__/train.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5aa2ed03477cf21175212b9c3dbdf46ee438b267
Binary files /dev/null and b/q_align/train/__pycache__/train.cpython-311.pyc differ
diff --git a/q_align/train/llama_flash_attn_monkey_patch.py b/q_align/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..23bfdd8920b25581956356ca881f83070e2ae111
--- /dev/null
+++ b/q_align/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,117 @@
+from typing import Optional, Tuple
+import warnings
+
+import torch
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except ImportError:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ modality_indicators: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: bool = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ warnings.warn(
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states, modality_indicators)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states, modality_indicators)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ ) # shape: (b, num_heads, s, head_dim)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+
+ if past_key_value is not None:
+ # reuse k, v
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Transform the data into the format required by flash attention
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
+ max_s = q_len
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = output.view(bsz, q_len, -1)
+ else:
+ qkv = qkv.reshape(bsz, q_len, -1)
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+ output = pad_input(output_unpad, indices, bsz, q_len)
+
+ return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ warnings.warn(
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+ )
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ _prepare_decoder_attention_mask
+ )
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
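+
+# Editor's note: this patch is applied in q_align/train/train_mem.py, which calls
+# replace_llama_attn_with_flash_attn() before importing q_align.train.train, so the patched
+# forward and _prepare_decoder_attention_mask are already in place when the model is constructed.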
\ No newline at end of file
diff --git a/q_align/train/mplug_owl2_trainer.py b/q_align/train/mplug_owl2_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..293dcdf21c82c715e663af11b90a19543244af18
--- /dev/null
+++ b/q_align/train/mplug_owl2_trainer.py
@@ -0,0 +1,243 @@
+import os
+import torch
+from torch import nn  # used for the nn.Embedding check in create_optimizer
+
+from torch.utils.data import Sampler
+
+from transformers import Trainer
+from transformers.trainer import (
+ is_sagemaker_mp_enabled,
+ get_parameter_names,
+ has_length,
+ ALL_LAYERNORM_LAYERS,
+ ShardedDDPOption,
+ logger,
+)
+from typing import List, Optional
+from icecream import ic
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ print(name, 'no ignore status')
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def split_to_even_chunks(indices, lengths, num_chunks):
+ """
+ Split a list of indices into `num_chunks` chunks of roughly equal total length.
+ """
+
+ if len(indices) % num_chunks != 0:
+ return [indices[i::num_chunks] for i in range(num_chunks)]
+
+ num_indices_per_chunk = len(indices) // num_chunks
+
+ chunks = [[] for _ in range(num_chunks)]
+ chunks_lengths = [0 for _ in range(num_chunks)]
+ for index in indices:
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
+ chunks[shortest_chunk].append(index)
+ chunks_lengths[shortest_chunk] += lengths[index]
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
+ chunks_lengths[shortest_chunk] = float("inf")
+
+ return chunks
+
+
+def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ assert all(l != 0 for l in lengths), "Should not have zero length."
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
+ # all samples are in the same modality
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
+
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
+ megabatch_size = world_size * batch_size
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+ last_mm = mm_megabatches[-1]
+ last_lang = lang_megabatches[-1]
+ additional_batch = last_mm + last_lang
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in megabatch_indices]
+
+ if len(additional_batch) > 0:
+ megabatches.append(sorted(additional_batch))
+
+ return [i for megabatch in megabatches for i in megabatch]
+
+
+def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatch_size = world_size * batch_size
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+class LengthGroupedSampler(Sampler):
+ r"""
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+ keeping a bit of randomness.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ world_size: int,
+ lengths: Optional[List[int]] = None,
+ generator=None,
+ group_by_modality: bool = False,
+ ):
+ if lengths is None:
+ raise ValueError("Lengths must be provided.")
+
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.lengths = lengths
+ self.generator = generator
+ self.group_by_modality = group_by_modality
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def __iter__(self):
+ if self.group_by_modality:
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ else:
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ return iter(indices)
+
+
+class MPLUGOwl2Trainer(Trainer):
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ if self.args.group_by_modality_length:
+ lengths = self.train_dataset.modality_lengths
+ return LengthGroupedSampler(
+ self.args.train_batch_size,
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
+ lengths=lengths,
+ group_by_modality=True,
+ )
+ else:
+ return super()._get_train_sampler()
+
+ def create_optimizer(self):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ if is_sagemaker_mp_enabled():
+ return super().create_optimizer()
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ return super().create_optimizer()
+
+ opt_model = self.model
+
+ if self.optimizer is None:
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ if self.args.visual_abstractor_lr is not None:
+ # note: match on the module name "visual_abstractor"; the original keyword "visual_abstractor_lr" does not appear in any parameter name
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "visual_abstractor" in name]
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ "lr": self.args.visual_abstractor_lr,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ "lr": self.args.visual_abstractor_lr,
+ },
+ ]
+ else:
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+ ic(len(optimizer_grouped_parameters[0]['params']),len(optimizer_grouped_parameters[1]['params']))
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ self.optimizer = OSS(
+ params=optimizer_grouped_parameters,
+ optim=optimizer_cls,
+ **optimizer_kwargs,
+ )
+ else:
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == "Adam8bit":
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+ logger.info(f"skipped: {skipped/2**20}M params")
+
+ return self.optimizer
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ super(MPLUGOwl2Trainer, self)._save_checkpoint(model, trial, metrics)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ super(MPLUGOwl2Trainer, self)._save(output_dir, state_dict)
\ No newline at end of file
diff --git a/q_align/train/train.py b/q_align/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..d109848b1664625fdf099c50e210e1a93a909002
--- /dev/null
+++ b/q_align/train/train.py
@@ -0,0 +1,844 @@
+# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+from dataclasses import dataclass, field
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence, List
+
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+import torch
+
+import transformers
+from transformers.models.clip.image_processing_clip import CLIPImageProcessor
+
+from torch.utils.data import Dataset
+from q_align.train.mplug_owl2_trainer import MPLUGOwl2Trainer
+from q_align.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+
+from q_align import conversation as conversation_lib
+from q_align.model import *
+from q_align.mm_utils import tokenizer_image_token
+
+from PIL import Image
+from icecream import ic
+
+local_rank = None
+
+
+def rank0_print(*args):
+ if local_rank == 0:
+ print(*args)
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ version: Optional[str] = field(default="v0")
+ freeze_backbone: bool = field(default=False)
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None,
+ metadata={"help": "Path to the training data."})
+ lazy_preprocess: bool = False
+ is_multimodal: bool = False
+ image_folder: Optional[str] = field(default=None)
+ image_aspect_ratio: str = 'square'
+ image_grid_pinpoints: Optional[str] = field(default=None)
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+
+ tune_visual_abstractor: bool = field(default=True)
+ freeze_vision_model: bool = field(default=True)
+
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help":
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+ double_quant: bool = field(
+ default=True,
+ metadata={"help": "Compress the quantization statistics through double quantization."}
+ )
+ quant_type: str = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+ )
+ bits: int = field(
+ default=16,
+ metadata={"help": "How many bits to use."}
+ )
+ lora_enable: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ visual_abstractor_lr: Optional[float] = None
+ group_by_modality_length: bool = field(default=False)
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias.items():
+ if k in lora_bias_names:
+ to_return[k] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ['vision_model', 'visual_abstractor']
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ lora_module_names.add(name)
+
+ if 'lm_head' in lora_module_names: # needed for 16-bit
+ lora_module_names.remove('lm_head')
+ return list(lora_module_names)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+ output_dir: str):
+ """Collects the state dict and dump to disk."""
+
+ if trainer.deepspeed:
+ torch.cuda.synchronize()
+ trainer.save_model(output_dir)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {
+ key: value.cpu()
+ for key, value in state_dict.items()
+ }
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(strings: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ) for text in strings
+ ]
+ input_ids = labels = [
+ tokenized.input_ids[0] for tokenized in tokenized_list
+ ]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+ sentence["value"] + END_SIGNAL)
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_multimodal(
+ sources: Sequence[str],
+ data_args: DataArguments
+) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ for sentence in source:
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
+ sentence['value'] = sentence['value'].strip()
+
+ replace_token = DEFAULT_IMAGE_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return sources
+
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ return preprocess_plain(sources, tokenizer)
+ if conversation_lib.default_conversation.version.startswith("v1"):
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{conversation_lib.default_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+ # tokenize conversations
+ def get_tokenize_len(prompts):
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
+ if has_image:
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ else:
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ if has_image:
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
+ else:
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def load_video(video_file):
+ from decord import VideoReader
+ vr = VideoReader(video_file)
+
+ # Get video frame rate
+ fps = vr.get_avg_fps()
+
+ # Calculate frame indices for 1fps
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+ frames = vr.get_batch(frame_indices).asnumpy()
+ return [Image.fromarray(frame) for frame in frames]
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
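+
+# Editor's note (illustrative): expand2square(img, (122, 116, 104)) pads a 640x480 RGB image to
+# 640x640 using e.g. CLIP's default image_mean scaled to 0-255, keeping the original content
+# centered along the padded axis.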
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments):
+ super(LazySupervisedDataset, self).__init__()
+ list_data_dict = json.load(open(data_path, "r"))
+
+ rank0_print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ img_tokens = 128 if 'image' in sample else 0
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
+ return length_list
+
+
+ @property
+ def modality_lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
+ cur_len = cur_len if 'image' in sample else -cur_len
+ length_list.append(cur_len)
+ return length_list
+
+# def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+# sources = self.list_data_dict[i]
+# if isinstance(i, int):
+# sources = [sources]
+# assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+# if 'image' in sources[0]:
+# image_file = self.list_data_dict[i]['image']
+# image_folder = self.data_args.image_folder
+# processor = self.data_args.image_processor
+# image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+# if self.data_args.image_aspect_ratio == 'pad':
+# def expand2square(pil_img, background_color):
+# width, height = pil_img.size
+# if width == height:
+# return pil_img
+# elif width > height:
+# result = Image.new(pil_img.mode, (width, width), background_color)
+# result.paste(pil_img, (0, (width - height) // 2))
+# return result
+# else:
+# result = Image.new(pil_img.mode, (height, height), background_color)
+# result.paste(pil_img, ((height - width) // 2, 0))
+# return result
+# image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+# image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+# else:
+# image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+# sources = preprocess_multimodal(
+# copy.deepcopy([e["conversations"] for e in sources]),
+# self.data_args)
+# else:
+# sources = copy.deepcopy([e["conversations"] for e in sources])
+# data_dict = preprocess(
+# sources,
+# self.tokenizer,
+# has_image=('image' in self.list_data_dict[i]))
+# if isinstance(i, int):
+# data_dict = dict(input_ids=data_dict["input_ids"][0],
+# labels=data_dict["labels"][0])
+
+# # image exist in the data
+# if 'image' in self.list_data_dict[i]:
+# data_dict['image'] = image
+# elif self.data_args.is_multimodal:
+# # image does not exist in the data, but the model is multimodal
+# crop_size = self.data_args.image_processor.crop_size
+# data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+# return data_dict
+
+ def next_rand(self):
+ import random
+ return random.randint(0,len(self)-1)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ while True:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+ if 'image' in sources[0]:
+ image_file = self.list_data_dict[i]['image']
+
+ image_folder = self.data_args.image_folder
+ processor = self.data_args.image_processor
+ from pathlib import Path
+ #if not Path(os.path.join(image_folder, image_file)).exists():
+ # i = self.next_rand()
+ # continue
+ if isinstance(image_file, list):
+ # Multiple Images as Input
+ try:
+ image = [Image.open(os.path.join(image_folder, imfile)).convert('RGB') for imfile in image_file]
+ except Exception as ex:
+ print(ex)
+ i = self.next_rand()
+ continue
+ if self.data_args.image_aspect_ratio == 'pad':
+ image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+ else:
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+ elif os.path.join(image_folder, image_file).endswith(".mp4"):
+ # Video as Input
+ image = load_video(os.path.join(image_folder, image_file))
+ if self.data_args.image_aspect_ratio == 'pad':
+ image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+ else:
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+ else:
+ try:
+ image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+ except Exception as ex:
+ print(ex)
+ i = self.next_rand()
+ continue
+ if self.data_args.image_aspect_ratio == 'pad':
+ image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ else:
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ sources = preprocess_multimodal(
+ copy.deepcopy([e["conversations"] for e in sources]),
+ self.data_args)
+ else:
+
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+ data_dict = preprocess(
+ sources,
+ self.tokenizer,
+ has_image=('image' in self.list_data_dict[i]))
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+ # image exist in the data
+ if 'image' in self.list_data_dict[i]:
+ data_dict['image'] = image
+ elif self.data_args.is_multimodal:
+ # image does not exist in the data, but the model is multimodal
+ crop_size = self.data_args.image_processor.crop_size
+ data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+ return data_dict
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
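+ # Editor's note: the collator pads input_ids with pad_token_id and labels with IGNORE_INDEX,
+ # truncates both to tokenizer.model_max_length, and stacks the per-sample image tensors only when
+ # they all share one shape; otherwise they are passed through as a list.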
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
+ labels = labels[:, :self.tokenizer.model_max_length]
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'image' in instances[0]:
+ images = [instance['image'] for instance in instances]
+ if all(x is not None and x.shape == images[0].shape for x in images):
+ batch['images'] = torch.stack(images)
+ else:
+ batch['images'] = images
+
+ return batch
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
+ data_args) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+ data_path=data_args.data_path,
+ data_args=data_args)
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ return dict(train_dataset=train_dataset,
+ eval_dataset=None,
+ data_collator=data_collator)
+
+
+def train():
+ global local_rank
+
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ local_rank = training_args.local_rank
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+
+ bnb_model_from_pretrained_args = {}
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+ bnb_model_from_pretrained_args.update(dict(
+ device_map={"": training_args.device},
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
+ )
+ ))
+
+ model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+ model.config.use_cache = False
+
+ if model_args.freeze_backbone:
+ model.model.requires_grad_(False)
+
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+ model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if training_args.lora_enable:
+ from peft import LoraConfig, get_peft_model
+ lora_config = LoraConfig(
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ rank0_print("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+
+ tokenizer.pad_token = tokenizer.unk_token
+ if model_args.version in conversation_lib.conv_templates:
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+ else:
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
+ if not training_args.freeze_vision_model and training_args.bits in [4, 8]:
+ model.get_model().vision_model.to(dtype=compute_dtype, device=training_args.device)
+ else:
+ vision_tower = model.get_model().vision_model
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+ if training_args.tune_visual_abstractor and training_args.bits in [4, 8]:
+ model.get_model().visual_abstractor.to(dtype=compute_dtype, device=training_args.device)
+ else:
+ visual_abstractor = model.get_model().visual_abstractor
+ visual_abstractor.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+ data_args.image_processor = CLIPImageProcessor.from_pretrained(model_args.model_name_or_path)
+ data_args.is_multimodal = True
+
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
+ model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
+ model.config.tune_visual_abstractor = model_args.tune_visual_abstractor = training_args.tune_visual_abstractor
+ ic(training_args.tune_visual_abstractor)
+ model.requires_grad_(True)
+ if training_args.tune_visual_abstractor:
+ # model.requires_grad_(False)
+ for p in model.get_model().visual_abstractor.parameters():
+ p.requires_grad = True
+
+ model.config.freeze_vision_model = training_args.freeze_vision_model
+ ic(training_args.freeze_vision_model)
+ if training_args.freeze_vision_model:
+ for p in model.get_model().vision_model.parameters():
+ p.requires_grad = False
+
+ model.config.visual_abstractor_lr = training_args.visual_abstractor_lr
+
+
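+ # For 4/8-bit runs, align module dtypes: LoRA layers to bf16 when enabled, norm layers to fp32 for stability, and fp32 lm_head/embedding weights to bf16.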
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if 'norm' in name:
+ module = module.to(torch.float32)
+ if 'lm_head' in name or 'embed_tokens' in name:
+ if hasattr(module, 'weight'):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
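+ # Build the supervised dataset/collator and hand everything to the mPLUG-Owl2 trainer.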
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
+ data_args=data_args)
+ trainer = MPLUGOwl2Trainer(model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module)
+
+ # Auto-resume from the latest checkpoint is currently disabled; to restore it,
+ # remove the unconditional call below and uncomment this block:
+ # if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ #     trainer.train(resume_from_checkpoint=True)
+ # else:
+ #     trainer.train()
+ trainer.train()
+
+ trainer.save_state()
+
+ model.config.use_cache = True
+
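+ # Save weights: with LoRA, write the adapter and the remaining trainable (non-LoRA) tensors separately; otherwise use the standard HF trainer save path.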
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), training_args.lora_bias
+ )
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
+ model.named_parameters()
+ )
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ model.config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer,
+ output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
\ No newline at end of file
diff --git a/q_align/train/train_mem.py b/q_align/train/train_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..4698657df8e14de319a0c117948eb8fd8a4c67b3
--- /dev/null
+++ b/q_align/train/train_mem.py
@@ -0,0 +1,13 @@
+# Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
+
+# Need to call this before importing transformers.
+from q_align.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+
+replace_llama_attn_with_flash_attn()
+
+from q_align.train.train import train
+
+if __name__ == "__main__":
+ train()
\ No newline at end of file
diff --git a/q_align/utils.py b/q_align/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..378dbc81cce765dea2cd138ae50f242fc557c2c9
--- /dev/null
+++ b/q_align/utils.py
@@ -0,0 +1,128 @@
+import datetime
+import json
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+from q_align.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def build_logger(logger_name, logger_filename):
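+ """Set up a named logger that mirrors stdout/stderr and appends to a shared, daily-rotating file under LOGDIR."""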
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ filename, when='D', utc=True)
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ''
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ''
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == '\n':
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != '':
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ''
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {"Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+ text = text.replace("\n", "")
+ # Serialize with json.dumps so quotes and special characters in the text are escaped properly.
+ data = json.dumps({"input": text}).encode("utf-8")
+ try:
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ flagged = False
+ except KeyError as e:
+ flagged = False
+
+ return flagged
+
+
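+ # Render a semaphore's value/locked state for debug logging (relies on the private _value attribute).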
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
\ No newline at end of file