haoning.wu committed on
Commit
e63f3e2
1 Parent(s): a23f4af

Scorer Starts

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. app.py +117 -1
  2. q_align/.ipynb_checkpoints/utils-checkpoint.py +128 -0
  3. q_align/__init__.py +1 -0
  4. q_align/__pycache__/__init__.cpython-310.pyc +0 -0
  5. q_align/__pycache__/__init__.cpython-311.pyc +0 -0
  6. q_align/__pycache__/__init__.cpython-39.pyc +0 -0
  7. q_align/__pycache__/constants.cpython-310.pyc +0 -0
  8. q_align/__pycache__/constants.cpython-311.pyc +0 -0
  9. q_align/__pycache__/constants.cpython-39.pyc +0 -0
  10. q_align/__pycache__/conversation.cpython-310.pyc +0 -0
  11. q_align/__pycache__/conversation.cpython-311.pyc +0 -0
  12. q_align/__pycache__/conversation.cpython-39.pyc +0 -0
  13. q_align/__pycache__/mm_utils.cpython-310.pyc +0 -0
  14. q_align/__pycache__/mm_utils.cpython-311.pyc +0 -0
  15. q_align/__pycache__/mm_utils.cpython-39.pyc +0 -0
  16. q_align/__pycache__/utils.cpython-311.pyc +0 -0
  17. q_align/constants.py +9 -0
  18. q_align/conversation.py +301 -0
  19. q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py +164 -0
  20. q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py +150 -0
  21. q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py +156 -0
  22. q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py +155 -0
  23. q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py +167 -0
  24. q_align/evaluate/__pycache__/scorer.cpython-311.pyc +0 -0
  25. q_align/evaluate/eval.py +138 -0
  26. q_align/evaluate/iaa_eval.py +164 -0
  27. q_align/evaluate/iqa4vqa_eval.py +150 -0
  28. q_align/evaluate/iqa_eval.py +156 -0
  29. q_align/evaluate/scorer.py +155 -0
  30. q_align/evaluate/vqa_eval.py +167 -0
  31. q_align/mm_utils.py +112 -0
  32. q_align/model/__init__.py +2 -0
  33. q_align/model/__pycache__/__init__.cpython-310.pyc +0 -0
  34. q_align/model/__pycache__/__init__.cpython-311.pyc +0 -0
  35. q_align/model/__pycache__/__init__.cpython-39.pyc +0 -0
  36. q_align/model/__pycache__/builder.cpython-310.pyc +0 -0
  37. q_align/model/__pycache__/builder.cpython-311.pyc +0 -0
  38. q_align/model/__pycache__/builder.cpython-39.pyc +0 -0
  39. q_align/model/__pycache__/configuration_mplug_owl2.cpython-310.pyc +0 -0
  40. q_align/model/__pycache__/configuration_mplug_owl2.cpython-311.pyc +0 -0
  41. q_align/model/__pycache__/configuration_mplug_owl2.cpython-39.pyc +0 -0
  42. q_align/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc +0 -0
  43. q_align/model/__pycache__/modeling_attn_mask_utils.cpython-311.pyc +0 -0
  44. q_align/model/__pycache__/modeling_attn_mask_utils.cpython-39.pyc +0 -0
  45. q_align/model/__pycache__/modeling_llama2.cpython-310.pyc +0 -0
  46. q_align/model/__pycache__/modeling_llama2.cpython-311.pyc +0 -0
  47. q_align/model/__pycache__/modeling_llama2.cpython-39.pyc +0 -0
  48. q_align/model/__pycache__/modeling_mplug_owl2.cpython-310.pyc +0 -0
  49. q_align/model/__pycache__/modeling_mplug_owl2.cpython-311.pyc +0 -0
  50. q_align/model/__pycache__/modeling_mplug_owl2.cpython-39.pyc +0 -0
app.py CHANGED
@@ -1,3 +1,119 @@
 import gradio as gr
 
-gr.load("models/q-future/one-align").launch()
+import argparse
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import requests
+from PIL import Image
+
+from q_align.model.builder import load_pretrained_model
+
+from q_align.conversation import (default_conversation, conv_templates,
+                                  SeparatorStyle)
+from q_align.constants import LOGDIR
+from q_align.utils import (build_logger, server_error_msg,
+                           violates_moderation, moderation_msg)
+
+from q_align.evaluate.scorer import QAlignScorer, QAlignAestheticScorer, QAlignVideoScorer
+
+import gradio as gr
+
+def load_video(video_file):
+    from decord import VideoReader
+    vr = VideoReader(video_file)
+
+    # Get video frame rate
+    fps = vr.get_avg_fps()
+
+    # Calculate frame indices for 1fps
+    frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+    frames = vr.get_batch(frame_indices).asnumpy()
+    return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+pretrained = "q-future/one-align"
+device = "cuda:0"
+tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+
+iqa_scorer = QAlignScorer(tokenizer=tokenizer, model=model, image_processor=image_processor)
+iaa_scorer = QAlignAestheticScorer(tokenizer=tokenizer, model=model, image_processor=image_processor)
+vqa_scorer = QAlignVideoScorer(tokenizer=tokenizer, model=model, image_processor=image_processor)
+
+scorers = {"Image Aesthetics (IAA)": iaa_scorer, "Image Quality (IQA)": iqa_scorer, "Video Quality (VQA)": vqa_scorer}
+
+LEVELS = ["excellent (5)", "good (4)", "fair (3)", "poor (2)", "bad (1)"]
+scores = [5, 4, 3, 2, 1]
+def image_classifier(input_img, input_vid, scorer_type):
+    if scorer_type is None:
+        scorer_type = "Image Quality (IQA)"
+    this_scorer = scorers[scorer_type]
+    if input_vid is not None:
+        input_ = load_video(input_vid)
+    elif input_img is not None:
+        input_ = [input_img]
+    if "Video" in scorer_type:
+        input_ = [input_]
+    probs = this_scorer(input_).mean(0).tolist()
+    prob_dict = {LEVEL: prob for LEVEL, prob in zip(LEVELS, probs)}
+    score = sum([prob * score for score, prob in zip(scores, probs)])
+    return prob_dict, score
+
+title_markdown = ("""
+
+<h1 align="center">Q-Align: Teaching LMMs for Visual Scoring via Discrete Text-Defined Levels</h1>
+
+<h3 align="center"> One Unified Model for Visual Scoring. </h3>
+
+<h5 align="center">
+    <a href="https://teowu.github.io/" target="_blank">Haoning Wu</a><sup>1</sup><sup>*</sup><sup>+</sup>,
+    <a href="https://github.com/zzc-1998" target="_blank">Zicheng Zhang</a><sup>2</sup><sup>*</sup>,
+    <a href="https://sites.google.com/view/r-panda" target="_blank">Weixia Zhang</a><sup>2</sup>,
+    <a href="https://chaofengc.github.io" target="_blank">Chaofeng Chen</a><sup>1</sup>,
+    <a href="https://liaoliang92.github.io" target="_blank">Liang Liao</a><sup>1</sup>,
+    <a href="https://github.com/lcysyzxdxc" target="_blank">Chunyi Li</a><sup>2</sup>,
+</h5>
+
+
+<h5 align="center">
+    <a href="https://github.com/YixuanGao98" target="_blank">Yixuan Gao</a><sup>2</sup>,
+    <a href="https://github.com/AnnanWangDaniel" target="_blank">Annan Wang</a><sup>1</sup>,
+    <a href="https://github.com/ZhangErliCarl/" target="_blank">Erli Zhang</a><sup>1</sup>,
+    <a href="https://wenxiusun.com" target="_blank">Wenxiu Sun</a><sup>3</sup>,
+    <a href="https://scholar.google.com/citations?user=uT9CtPYAAAAJ&hl=en" target="_blank">Qiong Yan</a><sup>3</sup>,
+    <a href="https://sites.google.com/site/minxiongkuo/" target="_blank">Xiongkuo Min</a><sup>2</sup>,
+    <a href="https://ee.sjtu.edu.cn/en/FacultyDetail.aspx?id=24&infoid=153&flag=153" target="_blank">Guangtao Zhai</a><sup>2</sup><sup>#</sup>,
+    <a href="https://personal.ntu.edu.sg/wslin/Home.html" target="_blank">Weisi Lin</a><sup>1</sup><sup>#</sup>
+</h5>
+
+<h5 align="center">
+    <sup>1</sup>Nanyang Technological University, <sup>2</sup>Shanghai Jiao Tong University, <sup>3</sup>Sensetime Research
+</h5>
+<h5 align="center">
+    <sup>*</sup>Equal contribution. <sup>+</sup>Project Lead. <sup>#</sup>Corresponding author(s).
+</h5>
+
+<h4 align="center"> If you like the OneScorer, please give us a star ✨ on <a href='https://github.com/Q-Future/Q-Align'>GitHub</a> for the latest updates. </h4>
+
+<h5 align="center">
+    <div style="display:flex; gap: 0.25rem;" align="center">
+        <a href='https://q-align.github.io'><img src='https://img.shields.io/badge/Homepage-green'></a>
+        <a href='https://github.com/Q-Future/Q-Align'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
+        <a href="https://Q-Future.github.io/Q-Align/fig/Q_Align_v0_1_preview.pdf"><img src="https://img.shields.io/badge/Technical-Report-red"></a>
+        <a href='https://github.com/Q-Future/Q-Align/stargazers'><img src='https://img.shields.io/github/stars/Q-Future/Q-Align.svg?style=social'></a>
+    </div>
+</h5>
+
+""")
+
+
+input_img = gr.Image(type='pil', label="Upload an Image")
+input_vid = gr.Video(label="Upload a Video (will IGNORE the image if a video is uploaded)", info="If a video is uploaded, the image uploaded will be ignored.")
+
+labels = gr.Label(label="Probabilities of rating levels:")
+number = gr.Number(label="Output score:", info="Range in [1,5]. Higher is better.")
+demo = gr.Interface(fn=image_classifier, inputs=[input_img, input_vid, gr.Radio(["Image Aesthetics (IAA)", "Image Quality (IQA)", "Video Quality (VQA)"], label="Task", info="Which Scorer will you need?"),], outputs=[labels, number], title="OneScorer", description=title_markdown)
+demo.launch(share=True)
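A minimal sketch (not part of the commit) of how image_classifier above turns the five level probabilities into the displayed score; the probability vector below is made up for illustration:

LEVELS = ["excellent (5)", "good (4)", "fair (3)", "poor (2)", "bad (1)"]
SCORES = [5, 4, 3, 2, 1]

def probs_to_score(probs):
    # probs: softmax output over the five rating levels, in LEVELS order
    return sum(p * s for p, s in zip(probs, SCORES))

# A mostly-"good" prediction lands near 4:
print(probs_to_score([0.10, 0.70, 0.15, 0.04, 0.01]))  # 3.84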
q_align/.ipynb_checkpoints/utils-checkpoint.py ADDED
@@ -0,0 +1,128 @@
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+
+
+from q_align.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def build_logger(logger_name, logger_filename):
+    global handler
+
+    formatter = logging.Formatter(
+        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    # Set the format of root handlers
+    if not logging.getLogger().handlers:
+        logging.basicConfig(level=logging.INFO)
+    logging.getLogger().handlers[0].setFormatter(formatter)
+
+    # Redirect stdout and stderr to loggers
+    stdout_logger = logging.getLogger("stdout")
+    stdout_logger.setLevel(logging.INFO)
+    sl = StreamToLogger(stdout_logger, logging.INFO)
+    sys.stdout = sl
+
+    stderr_logger = logging.getLogger("stderr")
+    stderr_logger.setLevel(logging.ERROR)
+    sl = StreamToLogger(stderr_logger, logging.ERROR)
+    sys.stderr = sl
+
+    # Get logger
+    logger = logging.getLogger(logger_name)
+    logger.setLevel(logging.INFO)
+
+    # Add a file handler for all loggers
+    if handler is None:
+        os.makedirs(LOGDIR, exist_ok=True)
+        filename = os.path.join(LOGDIR, logger_filename)
+        handler = logging.handlers.TimedRotatingFileHandler(
+            filename, when='D', utc=True)
+        handler.setFormatter(formatter)
+
+        for name, item in logging.root.manager.loggerDict.items():
+            if isinstance(item, logging.Logger):
+                item.addHandler(handler)
+
+    return logger
+
+
+class StreamToLogger(object):
+    """
+    Fake file-like stream object that redirects writes to a logger instance.
+    """
+    def __init__(self, logger, log_level=logging.INFO):
+        self.terminal = sys.stdout
+        self.logger = logger
+        self.log_level = log_level
+        self.linebuf = ''
+
+    def __getattr__(self, attr):
+        return getattr(self.terminal, attr)
+
+    def write(self, buf):
+        temp_linebuf = self.linebuf + buf
+        self.linebuf = ''
+        for line in temp_linebuf.splitlines(True):
+            # From the io.TextIOWrapper docs:
+            #   On output, if newline is None, any '\n' characters written
+            #   are translated to the system default line separator.
+            # By default sys.stdout.write() expects '\n' newlines and then
+            # translates them so this is still cross platform.
+            if line[-1] == '\n':
+                self.logger.log(self.log_level, line.rstrip())
+            else:
+                self.linebuf += line
+
+    def flush(self):
+        if self.linebuf != '':
+            self.logger.log(self.log_level, self.linebuf.rstrip())
+        self.linebuf = ''
+
+
+def disable_torch_init():
+    """
+    Disable the redundant torch default initialization to accelerate model creation.
+    """
+    import torch
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+    """
+    Check whether the text violates OpenAI moderation API.
+    """
+    url = "https://api.openai.com/v1/moderations"
+    headers = {"Content-Type": "application/json",
+               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+    text = text.replace("\n", "")
+    data = "{" + '"input": ' + f'"{text}"' + "}"
+    data = data.encode("utf-8")
+    try:
+        ret = requests.post(url, headers=headers, data=data, timeout=5)
+        flagged = ret.json()["results"][0]["flagged"]
+    except requests.exceptions.RequestException as e:
+        flagged = False
+    except KeyError as e:
+        flagged = False
+
+    return flagged
+
+
+def pretty_print_semaphore(semaphore):
+    if semaphore is None:
+        return "None"
+    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
q_align/__init__.py ADDED
@@ -0,0 +1 @@
+from .model import MPLUGOwl2LlamaForCausalLM
q_align/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (213 Bytes)
q_align/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (218 Bytes)
q_align/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (184 Bytes)
q_align/__pycache__/constants.cpython-310.pyc ADDED
Binary file (358 Bytes)
q_align/__pycache__/constants.cpython-311.pyc ADDED
Binary file (371 Bytes)
q_align/__pycache__/constants.cpython-39.pyc ADDED
Binary file (329 Bytes)
q_align/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (8.55 kB)
q_align/__pycache__/conversation.cpython-311.pyc ADDED
Binary file (14.8 kB)
q_align/__pycache__/conversation.cpython-39.pyc ADDED
Binary file (8.53 kB)
q_align/__pycache__/mm_utils.cpython-310.pyc ADDED
Binary file (4.57 kB)
q_align/__pycache__/mm_utils.cpython-311.pyc ADDED
Binary file (8.4 kB)
q_align/__pycache__/mm_utils.cpython-39.pyc ADDED
Binary file (4.51 kB)
q_align/__pycache__/utils.cpython-311.pyc ADDED
Binary file (6.92 kB)
q_align/constants.py ADDED
@@ -0,0 +1,9 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "./demo_logs"
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_IMAGE_TOKEN = "<|image|>"
q_align/conversation.py ADDED
@@ -0,0 +1,301 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+from q_align.constants import DEFAULT_IMAGE_TOKEN
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+    TWO_NO_SYS = auto()
+    MPT = auto()
+    PLAIN = auto()
+    LLAMA_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        messages = self.messages
+        if len(messages) > 0 and type(messages[0][1]) is tuple:
+            messages = self.messages.copy()
+            init_role, init_msg = messages[0].copy()
+            # init_msg = init_msg[0].replace("<image>", "").strip()
+            # if 'mmtag' in self.version:
+            #     messages[0] = (init_role, init_msg)
+            #     messages.insert(0, (self.roles[0], "<Image><image></Image>"))
+            #     messages.insert(1, (self.roles[1], "Received."))
+            # else:
+            #     messages[0] = (init_role, "<image>\n" + init_msg)
+            init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
+            messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)
+
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
+            seps = [self.sep, self.sep2]
+            ret = ""
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+        elif self.sep_style == SeparatorStyle.MPT:
+            ret = self.system + self.sep
+            for role, message in messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+        elif self.sep_style == SeparatorStyle.LLAMA_2:
+            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
+            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+            ret = ""
+
+            for i, (role, message) in enumerate(messages):
+                if i == 0:
+                    assert message, "first message should not be none"
+                    assert role == self.roles[0], "first message should come from user"
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    if i == 0: message = wrap_sys(self.system) + message
+                    if i % 2 == 0:
+                        message = wrap_inst(message)
+                        ret += self.sep + message
+                    else:
+                        ret += " " + message + " " + self.sep2
+                else:
+                    ret += ""
+            ret = ret.lstrip(self.sep)
+        elif self.sep_style == SeparatorStyle.PLAIN:
+            seps = [self.sep, self.sep2]
+            ret = self.system
+            for i, (role, message) in enumerate(messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += message + seps[i % 2]
+                else:
+                    ret += ""
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+        return ret
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def get_images(self, return_pil=False):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    from PIL import Image
+                    msg, image, image_process_mode = msg
+                    if image_process_mode == "Pad":
+                        def expand2square(pil_img, background_color=(122, 116, 104)):
+                            width, height = pil_img.size
+                            if width == height:
+                                return pil_img
+                            elif width > height:
+                                result = Image.new(pil_img.mode, (width, width), background_color)
+                                result.paste(pil_img, (0, (width - height) // 2))
+                                return result
+                            else:
+                                result = Image.new(pil_img.mode, (height, height), background_color)
+                                result.paste(pil_img, ((height - width) // 2, 0))
+                                return result
+                        image = expand2square(image)
+                    elif image_process_mode in ["Default", "Crop"]:
+                        pass
+                    elif image_process_mode == "Resize":
+                        image = image.resize((336, 336))
+                    else:
+                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if longest_edge != max(image.size):
+                        if H > W:
+                            H, W = longest_edge, shortest_edge
+                        else:
+                            H, W = shortest_edge, longest_edge
+                        image = image.resize((W, H))
+                    if return_pil:
+                        images.append(image)
+                    else:
+                        buffered = BytesIO()
+                        image.save(buffered, format="PNG")
+                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                        images.append(img_b64_str)
+        return images
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    msg, image, image_process_mode = msg
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    buffered = BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = img_str + msg.replace('<|image|>', '').strip()
+                    ret.append([msg, None])
+                else:
+                    ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2,
+            version=self.version)
+
+    def dict(self):
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_vicuna_v0 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+        ("Assistant",
+            "Renewable energy sources are those that can be replenished naturally in a relatively "
+            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+            "renewable and non-renewable energy sources:\n"
+            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+            "energy sources are finite and will eventually run out.\n"
+            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+            "and other negative effects.\n"
+            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+            "have lower operational costs than non-renewable sources.\n"
+            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+            "locations than non-renewable sources.\n"
+            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_mplug_owl2 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO_NO_SYS,
+    sep=" ",
+    sep2="</s>",
+)
+
+# default_conversation = conv_vicuna_v1
+default_conversation = conv_mplug_owl2
+conv_templates = {
+    "default": conv_vicuna_v0,
+    "v0": conv_vicuna_v0,
+    "v1": conv_vicuna_v1,
+    "vicuna_v1": conv_vicuna_v1,
+    "mplug_owl2": conv_mplug_owl2,
+}
+
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
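A short usage sketch (not part of the commit) of the "mplug_owl2" template defined above; with SeparatorStyle.TWO_NO_SYS the system string is omitted and turns are joined with sep/sep2:

from q_align.conversation import conv_templates

conv = conv_templates["mplug_owl2"].copy()
conv.append_message(conv.roles[0], "<|image|>How would you rate the quality of this image?")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
# -> USER: <|image|>How would you rate the quality of this image? ASSISTANT: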
q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py ADDED
@@ -0,0 +1,164 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+from scipy.stats import spearmanr, pearsonr
+
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+def wa5(logits):
+    import numpy as np
+    logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]])
+    probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
+    return np.inner(probs, np.array([1, 0.75, 0.5, 0.25, 0.]))
+
+
+
+
+def disable_torch_init():
+    """
+    Disable the redundant torch default initialization to accelerate model creation.
+    """
+    import torch
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_image(image_file):
+    if image_file.startswith('http://') or image_file.startswith('https://'):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert('RGB')
+    else:
+        image = Image.open(image_file).convert('RGB')
+    return image
+
+
+def main(args):
+    # Model
+    disable_torch_init()
+
+    model_name = get_model_name_from_path(args.model_path)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+    import json
+
+
+    image_path = "playground/data/"
+
+
+    json_prefix = "playground/data/test_jsons/"
+    jsons = [
+        json_prefix + "test_ava.json",
+    ]
+
+    os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+    conv_mode = "mplug_owl2"
+
+    inp = "How would you rate the aesthetics of this image?"
+
+    conv = conv_templates[conv_mode].copy()
+    inp = DEFAULT_IMAGE_TOKEN + inp
+    conv.append_message(conv.roles[0], inp)
+    image = None
+
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt() + " The aesthetics of the image is"
+
+    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+    print(toks)
+    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+    print(ids_)
+
+    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+    for json_ in jsons:
+        with open(json_) as f:
+            iqadata = json.load(f)
+
+            image_tensors = []
+            batch_data = []
+            prs, gts = [], []
+            for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+                filename = llddata["image"]
+                llddata["logits"] = defaultdict(float)
+
+
+
+                image = load_image(image_path + filename)
+                def expand2square(pil_img, background_color):
+                    width, height = pil_img.size
+                    if width == height:
+                        return pil_img
+                    elif width > height:
+                        result = Image.new(pil_img.mode, (width, width), background_color)
+                        result.paste(pil_img, (0, (width - height) // 2))
+                        return result
+                    else:
+                        result = Image.new(pil_img.mode, (height, height), background_color)
+                        result.paste(pil_img, ((height - width) // 2, 0))
+                        return result
+                image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+                image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+                image_tensors.append(image_tensor)
+                batch_data.append(llddata)
+
+                if i % 8 == 7 or i == len(iqadata) - 1:
+                    with torch.inference_mode():
+                        output_logits = model(input_ids.repeat(len(image_tensors), 1),
+                                              images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+                    for j, xllddata in enumerate(batch_data):
+                        for tok, id_ in zip(toks, ids_):
+                            xllddata["logits"][tok] += output_logits[j, id_].item()
+                        xllddata["score"] = wa5(xllddata["logits"])
+                        # print(llddata)
+                        prs.append(xllddata["score"])
+                        gts.append(xllddata["gt_score"])
+                        json_ = json_.replace("combined/", "combined-")
+                        with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+                            json.dump(xllddata, wf)
+
+                    image_tensors = []
+                    batch_data = []
+
+            #if i > 0 and i % 200 == 0:
+            #    print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0])
+            print("Spearmanr", spearmanr(prs, gts)[0], "Pearson", pearsonr(prs, gts)[0])
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, default="q-future/one-align")
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--device", type=str, default="cuda:0")
+    parser.add_argument("--conv-mode", type=str, default=None)
+    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
+    parser.add_argument("--load-8bit", action="store_true")
+    parser.add_argument("--load-4bit", action="store_true")
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+    args = parser.parse_args()
+    main(args)
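A worked example (not part of the commit) of the wa5() mapping used above: softmax over the five level logits, then a weighted average with weights [1, 0.75, 0.5, 0.25, 0], giving a score in [0, 1]. The logit values below are made up for illustration:

import numpy as np

logits = {"excellent": 2.0, "good": 1.0, "fair": 0.0, "poor": -1.0, "bad": -2.0}
logprobs = np.array([logits[k] for k in ["excellent", "good", "fair", "poor", "bad"]])
probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
print(np.inner(probs, np.array([1, 0.75, 0.5, 0.25, 0.])))  # ~0.86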
q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py ADDED
@@ -0,0 +1,150 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+from decord import VideoReader
+
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+
+
+
+def disable_torch_init():
+    """
+    Disable the redundant torch default initialization to accelerate model creation.
+    """
+    import torch
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_video(video_file):
+    vr = VideoReader(video_file)
+
+    # Get video frame rate
+    fps = vr.get_avg_fps()
+
+    # Calculate frame indices for 1fps
+    frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+    frames = vr.get_batch(frame_indices).asnumpy()
+    return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+def main(args):
+    # Model
+    disable_torch_init()
+
+    model_name = get_model_name_from_path(args.model_path)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+    import json
+
+
+    image_paths = [
+        "playground/data/",
+        "playground/data/",
+        "playground/data/KoNViD_1k_videos/",
+        "playground/data/maxwell/",
+    ]
+
+    json_prefix = "playground/data/test_jsons/"
+    jsons = [
+        json_prefix + "test_lsvq.json",
+        json_prefix + "test_lsvq_1080p.json",
+        json_prefix + "konvid.json",
+        json_prefix + "maxwell_test.json",
+    ]
+
+    os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+    conv_mode = "mplug_owl2"
+
+    inp = "How would you rate the quality of this image?"
+
+    conv = conv_templates[conv_mode].copy()
+    inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+    conv.append_message(conv.roles[0], inp)
+    image = None
+
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt() + " The quality of the image is"
+
+    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+    print(toks)
+    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+    print(ids_)
+
+    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+    for image_path, json_ in zip(image_paths, jsons):
+        with open(json_) as f:
+            iqadata = json.load(f)
+            try:
+                for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+                    filename = llddata["img_path"]
+                    llddata["logits"] = defaultdict(float)
+
+                    image = load_video(image_path + filename)
+                    def expand2square(pil_img, background_color):
+                        width, height = pil_img.size
+                        if width == height:
+                            return pil_img
+                        elif width > height:
+                            result = Image.new(pil_img.mode, (width, width), background_color)
+                            result.paste(pil_img, (0, (width - height) // 2))
+                            return result
+                        else:
+                            result = Image.new(pil_img.mode, (height, height), background_color)
+                            result.paste(pil_img, ((height - width) // 2, 0))
+                            return result
+                    image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image]
+                    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+
+                    if True:
+                        with torch.inference_mode():
+                            output_logits = model(input_ids.repeat(image_tensor.shape[0], 1),
+                                                  images=image_tensor)["logits"][:,-1]
+
+                        for tok, id_ in zip(toks, ids_):
+                            llddata["logits"][tok] += output_logits.mean(0)[id_].item()
+                        # print(llddata)
+                        json_ = json_.replace("combined/", "combined-")
+                        with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+                            json.dump(llddata, wf)
+            except:
+                continue
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, default="q-future/q-align-image")
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--device", type=str, default="cuda:0")
+    parser.add_argument("--conv-mode", type=str, default=None)
+    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
+    parser.add_argument("--load-8bit", action="store_true")
+    parser.add_argument("--load-4bit", action="store_true")
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+    args = parser.parse_args()
+    main(args)
q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py ADDED
@@ -0,0 +1,156 @@
+import argparse
+import torch
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.conversation import conv_templates, SeparatorStyle
+from q_align.model.builder import load_pretrained_model
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+from transformers import TextStreamer
+
+import json
+from tqdm import tqdm
+from collections import defaultdict
+
+import os
+
+
+
+
+def disable_torch_init():
+    """
+    Disable the redundant torch default initialization to accelerate model creation.
+    """
+    import torch
+    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def load_image(image_file):
+    if image_file.startswith('http://') or image_file.startswith('https://'):
+        response = requests.get(image_file)
+        image = Image.open(BytesIO(response.content)).convert('RGB')
+    else:
+        image = Image.open(image_file).convert('RGB')
+    return image
+
+
+def main(args):
+    # Model
+    disable_torch_init()
+
+    model_name = get_model_name_from_path(args.model_path)
+    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+    import json
+
+
+    image_path = "playground/data/"
+
+
+    json_prefix = "playground/data/test_jsons/"
+    jsons = [
+        json_prefix + "test_imagerewarddb.json",
+        json_prefix + "test_koniq.json",
+        json_prefix + "test_spaq.json",
+        json_prefix + "test_kadid.json",
+        json_prefix + "livec.json",
+        json_prefix + "agi.json",
+        json_prefix + "live.json",
+        json_prefix + "csiq.json",
+    ]
+
+    os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+    conv_mode = "mplug_owl2"
+
+    inp = "Evaluate the image quality of the following image."  #"How would you rate the quality of this image?"
+
+    conv = conv_templates[conv_mode].copy()
+    inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+    conv.append_message(conv.roles[0], inp)
+    image = None
+
+    conv.append_message(conv.roles[1], None)
+    prompt = conv.get_prompt() + " The quality of the image is"
+
+    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+    print(toks)
+    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+    print(ids_)
+
+    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+    for json_ in jsons:
+        with open(json_) as f:
+            iqadata = json.load(f)
+
+            image_tensors = []
+            batch_data = []
+
+            for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+                if True:
+                    try:
+                        filename = llddata["image"]
+                    except:
+                        filename = llddata["img_path"]
+                    llddata["logits"] = defaultdict(float)
+
+                    image = load_image(image_path + filename)
+                    def expand2square(pil_img, background_color):
+                        width, height = pil_img.size
+                        if width == height:
+                            return pil_img
+                        elif width > height:
+                            result = Image.new(pil_img.mode, (width, width), background_color)
+                            result.paste(pil_img, (0, (width - height) // 2))
+                            return result
+                        else:
+                            result = Image.new(pil_img.mode, (height, height), background_color)
+                            result.paste(pil_img, ((height - width) // 2, 0))
+                            return result
+                    image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+                    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+                    image_tensors.append(image_tensor)
+                    batch_data.append(llddata)
+
+                    if i % 8 == 7 or i == len(iqadata) - 1:
+                        with torch.inference_mode():
+                            output_logits = model(input_ids.repeat(len(image_tensors), 1),
+                                                  images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+                        for j, xllddata in enumerate(batch_data):
+                            for tok, id_ in zip(toks, ids_):
+                                xllddata["logits"][tok] += output_logits[j, id_].item()
+                            # print(llddata)
+                            json_ = json_.replace("combined/", "combined-")
+                            with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf:
+                                json.dump(xllddata, wf)
+
+                        image_tensors = []
+                        batch_data = []
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, default="q-future/one-align")
+    parser.add_argument("--model-base", type=str, default=None)
+    parser.add_argument("--device", type=str, default="cuda:0")
+    parser.add_argument("--conv-mode", type=str, default=None)
+    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
+    parser.add_argument("--load-8bit", action="store_true")
+    parser.add_argument("--load-4bit", action="store_true")
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+    args = parser.parse_args()
+    main(args)
q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py ADDED
@@ -0,0 +1,155 @@
+from PIL import Image
+
+import torch.nn as nn
+import torch
+
+from typing import List
+
+from q_align.model.builder import load_pretrained_model
+
+from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+def load_video(video_file):
+    from decord import VideoReader
+    vr = VideoReader(video_file)
+
+    # Get video frame rate
+    fps = vr.get_avg_fps()
+
+    # Calculate frame indices for 1fps
+    frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+    frames = vr.get_batch(frame_indices).asnumpy()
+    return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+class QAlignScorer(nn.Module):
+    def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+        super().__init__()
+        if model is None:
+            tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+        prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"
+
+        self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+        self.weight_tensor = torch.Tensor([1, 0.75, 0.5, 0.25, 0.]).half().to(model.device)
+
+        self.tokenizer = tokenizer
+        self.model = model
+        self.image_processor = image_processor
+        self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+    def expand2square(self, pil_img, background_color):
+        width, height = pil_img.size
+        if width == height:
+            return pil_img
+        elif width > height:
+            result = Image.new(pil_img.mode, (width, width), background_color)
+            result.paste(pil_img, (0, (width - height) // 2))
+            return result
+        else:
+            result = Image.new(pil_img.mode, (height, height), background_color)
+            result.paste(pil_img, ((height - width) // 2, 0))
+            return result
+
+    def forward(self, image: List[Image.Image]):
+        image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
+        with torch.inference_mode():
+            image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
+            output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
+                                       images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+
+            return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+
+class QAlignAestheticScorer(nn.Module):
+    def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+        super().__init__()
+        if model is None:
+            tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+        prompt = "USER: How would you rate the aesthetics of this image?\n<|image|>\nASSISTANT: The aesthetics of the image is"
+
+        self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+        self.weight_tensor = torch.Tensor([1, 0.75, 0.5, 0.25, 0.]).half().to(model.device)
+
+        self.tokenizer = tokenizer
+        self.model = model
+        self.image_processor = image_processor
+        self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+    def expand2square(self, pil_img, background_color):
+        width, height = pil_img.size
+        if width == height:
+            return pil_img
+        elif width > height:
+            result = Image.new(pil_img.mode, (width, width), background_color)
+            result.paste(pil_img, (0, (width - height) // 2))
+            return result
+        else:
+            result = Image.new(pil_img.mode, (height, height), background_color)
+            result.paste(pil_img, ((height - width) // 2, 0))
+            return result
+
+    def forward(self, image: List[Image.Image]):
+        image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
+        with torch.inference_mode():
+            image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
+            output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
+                                       images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+
+            return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+class QAlignVideoScorer(nn.Module):
+    def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+        super().__init__()
+        if model is None:
+            tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+        prompt = "USER: How would you rate the quality of this video?\n<|image|>\nASSISTANT: The quality of the video is"
+
+        self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+        self.weight_tensor = torch.Tensor([1, 0.75, 0.5, 0.25, 0.]).half().to(model.device)
+
+        self.tokenizer = tokenizer
+        self.model = model
+        self.image_processor = image_processor
+        self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+    def expand2square(self, pil_img, background_color):
+        width, height = pil_img.size
+        if width == height:
+            return pil_img
+        elif width > height:
+            result = Image.new(pil_img.mode, (width, width), background_color)
+            result.paste(pil_img, (0, (width - height) // 2))
+            return result
+        else:
+            result = Image.new(pil_img.mode, (height, height), background_color)
+            result.paste(pil_img, ((height - width) // 2, 0))
+            return result
+
+    def forward(self, video: List[List[Image.Image]]):
+        video = [[self.expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in video]
+        with torch.inference_mode():
+            video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video]
+            output_logits = self.model(self.input_ids.repeat(len(video_tensors), 1),
+                                       images=video_tensors)["logits"][:,-1, self.preferential_ids_]
+            return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-path", type=str, default="q-future/one-align")
+    parser.add_argument("--device", type=str, default="cuda:0")
+    parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
+    parser.add_argument("--aesthetic", action="store_true")
+    parser.add_argument("--video", action="store_true")
+    args = parser.parse_args()
+
+    if args.video:
+        scorer = QAlignVideoScorer(pretrained=args.model_path, device=args.device)
+        print(scorer([load_video(args.img_path)]).tolist())
+    else:
+        scorer = QAlignScorer(pretrained=args.model_path, device=args.device) if not args.aesthetic else QAlignAestheticScorer(pretrained=args.model_path, device=args.device)
+        print(scorer([Image.open(args.img_path)]).tolist())
+
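A minimal usage sketch (not part of the commit), mirroring the __main__ block above: score a single image with the IQA scorer, which returns a softmax over the excellent/good/fair/poor/bad levels. The image path is the script's own default:

from PIL import Image
from q_align.evaluate.scorer import QAlignScorer

scorer = QAlignScorer(pretrained="q-future/one-align", device="cuda:0")
probs = scorer([Image.open("fig/singapore_flyer.jpg")])  # shape: [1, 5]
print(probs.tolist())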
q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+
4
+ from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
5
+ from q_align.conversation import conv_templates, SeparatorStyle
6
+ from q_align.model.builder import load_pretrained_model
7
+ from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
8
+
9
+ from PIL import Image
10
+
11
+ import requests
12
+ from PIL import Image
13
+ from io import BytesIO
+ from transformers import TextStreamer
+
+
+ from scipy.stats import spearmanr, pearsonr
+
+ import json
+ from tqdm import tqdm
+ from collections import defaultdict
+
+ import os
+
+ def wa5(logits):
+     import numpy as np
+     logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]])
+     probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
+     return np.inner(probs, np.array([1,0.75,0.5,0.25,0.]))
+
+
+
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+ def load_video(video_file):
+     from decord import VideoReader
+     vr = VideoReader(video_file)
+
+     # Get video frame rate
+     fps = vr.get_avg_fps()
+
+     # Calculate frame indices for 1fps
+     frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+     frames = vr.get_batch(frame_indices).asnumpy()
+     return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+ def main(args):
+     # Model
+     disable_torch_init()
+
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+     import json
+
+
+     image_paths = [
+         #"playground/data/",
+         #"playground/data/",
+         "playground/data/KoNViD_1k_videos/",
+         "playground/data/maxwell/",
+
+     ]
+
+     json_prefix = "playground/data/test_jsons/"
+     jsons = [
+         #json_prefix + "test_lsvq.json",
+         #json_prefix + "test_lsvq_1080p.json",
+         json_prefix + "konvid.json",
+         json_prefix + "maxwell_test.json",
+     ]
+
+     os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+     conv_mode = "mplug_owl2"
+
+     inp = "How would you rate the quality of this video?"
+
+     conv = conv_templates[conv_mode].copy()
+     inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+     conv.append_message(conv.roles[0], inp)
+     image = None
+
+     conv.append_message(conv.roles[1], None)
+     prompt = conv.get_prompt() + " The quality of the video is"
+
+     toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+     print(toks)
+     ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+     print(ids_)
+
+     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+     for image_path, json_ in zip(image_paths, jsons):
+         with open(json_) as f:
+             iqadata = json.load(f)
+         prs, gts = [], []
+         for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+             try:
+                 try:
+                     filename = llddata["img_path"]
+                 except:
+                     filename = llddata["image"]
+                 llddata["logits"] = defaultdict(float)
+
+                 image = load_video(image_path + filename)
+                 def expand2square(pil_img, background_color):
+                     width, height = pil_img.size
+                     if width == height:
+                         return pil_img
+                     elif width > height:
+                         result = Image.new(pil_img.mode, (width, width), background_color)
+                         result.paste(pil_img, (0, (width - height) // 2))
+                         return result
+                     else:
+                         result = Image.new(pil_img.mode, (height, height), background_color)
+                         result.paste(pil_img, ((height - width) // 2, 0))
+                         return result
+                 image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image]
+                 image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+                 if True:
+                     with torch.inference_mode():
+                         output_logits = model(input_ids,
+                                               images=[image_tensor])["logits"][:,-1]
+                     for tok, id_ in zip(toks, ids_):
+                         llddata["logits"][tok] += output_logits.mean(0)[id_].item()
+                     llddata["score"] = wa5(llddata["logits"])
+                     # print(llddata)
+                     prs.append(llddata["score"])
+                     gts.append(llddata["gt_score"])
+                 # print(llddata)
+                 json_ = json_.replace("combined/", "combined-")
+                 with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf:
+                     json.dump(llddata, wf)
+
+                 if i > 0 and i % 200 == 0:
+                     print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0])
+             except:
+                 continue
+         print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0])
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default="q-future/one-align")
+     parser.add_argument("--model-base", type=str, default=None)
+     parser.add_argument("--device", type=str, default="cuda:0")
+     parser.add_argument("--conv-mode", type=str, default=None)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--max-new-tokens", type=int, default=512)
+     parser.add_argument("--load-8bit", action="store_true")
+     parser.add_argument("--load-4bit", action="store_true")
+     parser.add_argument("--debug", action="store_true")
+     parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+     args = parser.parse_args()
+     main(args)
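Note: `wa5` above collapses the logits of the five rating words into a single score in [0, 1] via a softmax-weighted average. A minimal standalone check (the logits below are made-up example values, not numbers from this repository):

    score = wa5({"excellent": 1.0, "good": 0.5, "fair": 0.0, "poor": -0.5, "bad": -1.0})  # ~0.73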
q_align/evaluate/__pycache__/scorer.cpython-311.pyc ADDED
Binary file (14.6 kB). View file
 
q_align/evaluate/eval.py ADDED
@@ -0,0 +1,138 @@
+ import argparse
+ import torch
+
+ from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from q_align.conversation import conv_templates, SeparatorStyle
+ from q_align.model.builder import load_pretrained_model
+ from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+ from PIL import Image
+
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ from transformers import TextStreamer
+
+ import json
+ from tqdm import tqdm
+ from collections import defaultdict
+
+ import os
+
+
+
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+ def load_image(image_file):
+     if image_file.startswith('http://') or image_file.startswith('https://'):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert('RGB')
+     else:
+         image = Image.open(image_file).convert('RGB')
+     return image
+
+
+ def main(args):
+     # Model
+     disable_torch_init()
+
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+     import json
+
+
+     image_path = "playground/data/"
+
+     json_prefix = "playground/data/labels/mos_simple/"
+     jsons = [
+         json_prefix + "test_flive.json",
+         json_prefix + "combined/kadid_ref.json",
+         json_prefix + "combined/livec.json",
+         json_prefix + "test_koniq.json",
+         json_prefix + "test_spaq.json",
+         json_prefix + "combined/agi.json",
+         json_prefix + "combined/kadid.json",
+     ]
+
+     os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+     conv_mode = "mplug_owl2"
+
+     inp = "How would you rate the quality of this image?"
+
+     conv = conv_templates[conv_mode].copy()
+     inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+     conv.append_message(conv.roles[0], inp)
+     image = None
+
+     conv.append_message(conv.roles[1], None)
+     prompt = conv.get_prompt() + " The quality of the image is"
+
+     toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+     print(toks)
+     ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+     print(ids_)
+
+     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+     for json_ in jsons:
+         with open(json_) as f:
+             iqadata = json.load(f)
+
+         image_tensors = []
+         batch_data = []
+
+         for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+             #print(f"Evaluating image {i}")
+             #print(prompt)
+             filename = llddata["image"]
+             llddata["logits"] = defaultdict(float)
+
+             image = load_image(image_path + filename)
+             image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
+
+             image_tensors.append(image_tensor)
+             batch_data.append(llddata)
+
+             if i % 8 == 7 or i == len(iqadata) - 1:
+                 with torch.inference_mode():
+                     output_logits = model(input_ids.repeat(len(image_tensors), 1),
+                                           images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+                 for j, xllddata in enumerate(batch_data):
+                     for tok, id_ in zip(toks, ids_):
+                         xllddata["logits"][tok] += output_logits[j,id_].item()
+                     # print(llddata)
+                     json_ = json_.replace("combined/", "combined-")
+                     # print(f"results/mix-mplug-owl-2-boost_iqa_wu_v2/{json_.split('/')[-1]}")
+                     with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+                         json.dump(xllddata, wf)
+
+                 image_tensors = []
+                 batch_data = []
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default="q-future/q-align-koniq-spaq-v0")
+     parser.add_argument("--model-base", type=str, default=None)
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--conv-mode", type=str, default=None)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--max-new-tokens", type=int, default=512)
+     parser.add_argument("--load-8bit", action="store_true")
+     parser.add_argument("--load-4bit", action="store_true")
+     parser.add_argument("--debug", action="store_true")
+     parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+     args = parser.parse_args()
+     main(args)
q_align/evaluate/iaa_eval.py ADDED
@@ -0,0 +1,164 @@
+ import argparse
+ import torch
+
+ from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from q_align.conversation import conv_templates, SeparatorStyle
+ from q_align.model.builder import load_pretrained_model
+ from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+ from PIL import Image
+ from PIL import ImageFile
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ from transformers import TextStreamer
+
+ from scipy.stats import spearmanr, pearsonr
+
+
+ import json
+ from tqdm import tqdm
+ from collections import defaultdict
+
+ import os
+
+ def wa5(logits):
+     import numpy as np
+     logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]])
+     probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
+     return np.inner(probs, np.array([1,0.75,0.5,0.25,0.]))
+
+
+
+
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+ def load_image(image_file):
+     if image_file.startswith('http://') or image_file.startswith('https://'):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert('RGB')
+     else:
+         image = Image.open(image_file).convert('RGB')
+     return image
+
+
+ def main(args):
+     # Model
+     disable_torch_init()
+
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+     import json
+
+
+     image_path = "playground/data/"
+
+
+     json_prefix = "playground/data/test_jsons/"
+     jsons = [
+         json_prefix + "test_ava.json",
+     ]
+
+     os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+     conv_mode = "mplug_owl2"
+
+     inp = "How would you rate the aesthetics of this image?"
+
+     conv = conv_templates[conv_mode].copy()
+     inp = DEFAULT_IMAGE_TOKEN + inp
+     conv.append_message(conv.roles[0], inp)
+     image = None
+
+     conv.append_message(conv.roles[1], None)
+     prompt = conv.get_prompt() + " The aesthetics of the image is"
+
+     toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+     print(toks)
+     ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+     print(ids_)
+
+     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+     for json_ in jsons:
+         with open(json_) as f:
+             iqadata = json.load(f)
+
+         image_tensors = []
+         batch_data = []
+         prs, gts = [], []
+         for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+             filename = llddata["image"]
+             llddata["logits"] = defaultdict(float)
+
+
+
+             image = load_image(image_path + filename)
+             def expand2square(pil_img, background_color):
+                 width, height = pil_img.size
+                 if width == height:
+                     return pil_img
+                 elif width > height:
+                     result = Image.new(pil_img.mode, (width, width), background_color)
+                     result.paste(pil_img, (0, (width - height) // 2))
+                     return result
+                 else:
+                     result = Image.new(pil_img.mode, (height, height), background_color)
+                     result.paste(pil_img, ((height - width) // 2, 0))
+                     return result
+             image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+             image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+             image_tensors.append(image_tensor)
+             batch_data.append(llddata)
+
+             if i % 8 == 7 or i == len(iqadata) - 1:
+                 with torch.inference_mode():
+                     output_logits = model(input_ids.repeat(len(image_tensors), 1),
+                                           images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+                 for j, xllddata in enumerate(batch_data):
+                     for tok, id_ in zip(toks, ids_):
+                         xllddata["logits"][tok] += output_logits[j,id_].item()
+                     xllddata["score"] = wa5(xllddata["logits"])
+                     # print(llddata)
+                     prs.append(xllddata["score"])
+                     gts.append(xllddata["gt_score"])
+                     json_ = json_.replace("combined/", "combined-")
+                     with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+                         json.dump(xllddata, wf)
+
+                 image_tensors = []
+                 batch_data = []
+
+         #if i > 0 and i % 200 == 0:
+         #    print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0])
+         print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0])
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default="q-future/one-align")
+     parser.add_argument("--model-base", type=str, default=None)
+     parser.add_argument("--device", type=str, default="cuda:0")
+     parser.add_argument("--conv-mode", type=str, default=None)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--max-new-tokens", type=int, default=512)
+     parser.add_argument("--load-8bit", action="store_true")
+     parser.add_argument("--load-4bit", action="store_true")
+     parser.add_argument("--debug", action="store_true")
+     parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+     args = parser.parse_args()
+     main(args)
q_align/evaluate/iqa4vqa_eval.py ADDED
@@ -0,0 +1,150 @@
+ import argparse
+ import torch
+
+ from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from q_align.conversation import conv_templates, SeparatorStyle
+ from q_align.model.builder import load_pretrained_model
+ from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+ from PIL import Image
+
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ from transformers import TextStreamer
+
+ from decord import VideoReader
+
+
+ import json
+ from tqdm import tqdm
+ from collections import defaultdict
+
+ import os
+
+
+
+
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+ def load_video(video_file):
+     vr = VideoReader(video_file)
+
+     # Get video frame rate
+     fps = vr.get_avg_fps()
+
+     # Calculate frame indices for 1fps
+     frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+     frames = vr.get_batch(frame_indices).asnumpy()
+     return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+ def main(args):
+     # Model
+     disable_torch_init()
+
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+     import json
+
+
+     image_paths = [
+         "playground/data/",
+         "playground/data/",
+         "playground/data/KoNViD_1k_videos/",
+         "playground/data/maxwell/",
+     ]
+
+     json_prefix = "playground/data/test_jsons/"
+     jsons = [
+         json_prefix + "test_lsvq.json",
+         json_prefix + "test_lsvq_1080p.json",
+         json_prefix + "konvid.json",
+         json_prefix + "maxwell_test.json",
+     ]
+
+     os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+     conv_mode = "mplug_owl2"
+
+     inp = "How would you rate the quality of this image?"
+
+     conv = conv_templates[conv_mode].copy()
+     inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+     conv.append_message(conv.roles[0], inp)
+     image = None
+
+     conv.append_message(conv.roles[1], None)
+     prompt = conv.get_prompt() + " The quality of the image is"
+
+     toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+     print(toks)
+     ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+     print(ids_)
+
+     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+     for image_path, json_ in zip(image_paths, jsons):
+         with open(json_) as f:
+             iqadata = json.load(f)
+         try:
+             for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+                 filename = llddata["img_path"]
+                 llddata["logits"] = defaultdict(float)
+
+                 image = load_video(image_path + filename)
+                 def expand2square(pil_img, background_color):
+                     width, height = pil_img.size
+                     if width == height:
+                         return pil_img
+                     elif width > height:
+                         result = Image.new(pil_img.mode, (width, width), background_color)
+                         result.paste(pil_img, (0, (width - height) // 2))
+                         return result
+                     else:
+                         result = Image.new(pil_img.mode, (height, height), background_color)
+                         result.paste(pil_img, ((height - width) // 2, 0))
+                         return result
+                 image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image]
+                 image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+
+                 if True:
+                     with torch.inference_mode():
+                         output_logits = model(input_ids.repeat(image_tensor.shape[0], 1),
+                                               images=image_tensor)["logits"][:,-1]
+
+                     for tok, id_ in zip(toks, ids_):
+                         llddata["logits"][tok] += output_logits.mean(0)[id_].item()
+                     # print(llddata)
+                     json_ = json_.replace("combined/", "combined-")
+                     with open(f"results/{args.model_path}/{json_.split('/')[-1]}", "a") as wf:
+                         json.dump(llddata, wf)
+         except:
+             continue
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default="q-future/q-align-image")
+     parser.add_argument("--model-base", type=str, default=None)
+     parser.add_argument("--device", type=str, default="cuda:0")
+     parser.add_argument("--conv-mode", type=str, default=None)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--max-new-tokens", type=int, default=512)
+     parser.add_argument("--load-8bit", action="store_true")
+     parser.add_argument("--load-4bit", action="store_true")
+     parser.add_argument("--debug", action="store_true")
+     parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+     args = parser.parse_args()
+     main(args)
q_align/evaluate/iqa_eval.py ADDED
@@ -0,0 +1,156 @@
+ import argparse
+ import torch
+
+ from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from q_align.conversation import conv_templates, SeparatorStyle
+ from q_align.model.builder import load_pretrained_model
+ from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+ from PIL import Image
+
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ from transformers import TextStreamer
+
+ import json
+ from tqdm import tqdm
+ from collections import defaultdict
+
+ import os
+
+
+
+
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+ def load_image(image_file):
+     if image_file.startswith('http://') or image_file.startswith('https://'):
+         response = requests.get(image_file)
+         image = Image.open(BytesIO(response.content)).convert('RGB')
+     else:
+         image = Image.open(image_file).convert('RGB')
+     return image
+
+
+ def main(args):
+     # Model
+     disable_torch_init()
+
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+     import json
+
+
+     image_path = "playground/data/"
+
+
+     json_prefix = "playground/data/test_jsons/"
+     jsons = [
+         json_prefix + "test_imagerewarddb.json",
+         json_prefix + "test_koniq.json",
+         json_prefix + "test_spaq.json",
+         json_prefix + "test_kadid.json",
+         json_prefix + "livec.json",
+         json_prefix + "agi.json",
+         json_prefix + "live.json",
+         json_prefix + "csiq.json",
+     ]
+
+     os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+     conv_mode = "mplug_owl2"
+
+     inp = "Evaluate the image quality of the following image."  #"How would you rate the quality of this image?"
+
+     conv = conv_templates[conv_mode].copy()
+     inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+     conv.append_message(conv.roles[0], inp)
+     image = None
+
+     conv.append_message(conv.roles[1], None)
+     prompt = conv.get_prompt() + " The quality of the image is"
+
+     toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+     print(toks)
+     ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+     print(ids_)
+
+     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+     for json_ in jsons:
+         with open(json_) as f:
+             iqadata = json.load(f)
+
+         image_tensors = []
+         batch_data = []
+
+         for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+             if True:
+                 try:
+                     filename = llddata["image"]
+                 except:
+                     filename = llddata["img_path"]
+                 llddata["logits"] = defaultdict(float)
+
+                 image = load_image(image_path + filename)
+                 def expand2square(pil_img, background_color):
+                     width, height = pil_img.size
+                     if width == height:
+                         return pil_img
+                     elif width > height:
+                         result = Image.new(pil_img.mode, (width, width), background_color)
+                         result.paste(pil_img, (0, (width - height) // 2))
+                         return result
+                     else:
+                         result = Image.new(pil_img.mode, (height, height), background_color)
+                         result.paste(pil_img, ((height - width) // 2, 0))
+                         return result
+                 image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+                 image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+                 image_tensors.append(image_tensor)
+                 batch_data.append(llddata)
+
+                 if i % 8 == 7 or i == len(iqadata) - 1:
+                     with torch.inference_mode():
+                         output_logits = model(input_ids.repeat(len(image_tensors), 1),
+                                               images=torch.cat(image_tensors, 0))["logits"][:,-1]
+
+                     for j, xllddata in enumerate(batch_data):
+                         for tok, id_ in zip(toks, ids_):
+                             xllddata["logits"][tok] += output_logits[j,id_].item()
+                         # print(llddata)
+                         json_ = json_.replace("combined/", "combined-")
+                         with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf:
+                             json.dump(xllddata, wf)
+
+                     image_tensors = []
+                     batch_data = []
+
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default="q-future/one-align")
+     parser.add_argument("--model-base", type=str, default=None)
+     parser.add_argument("--device", type=str, default="cuda:0")
+     parser.add_argument("--conv-mode", type=str, default=None)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--max-new-tokens", type=int, default=512)
+     parser.add_argument("--load-8bit", action="store_true")
+     parser.add_argument("--load-4bit", action="store_true")
+     parser.add_argument("--debug", action="store_true")
+     parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+     args = parser.parse_args()
+     main(args)
q_align/evaluate/scorer.py ADDED
@@ -0,0 +1,155 @@
+ from PIL import Image
+
+ import torch.nn as nn
+ import torch
+
+ from typing import List
+
+ from q_align.model.builder import load_pretrained_model
+
+ from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+ def load_video(video_file):
+     from decord import VideoReader
+     vr = VideoReader(video_file)
+
+     # Get video frame rate
+     fps = vr.get_avg_fps()
+
+     # Calculate frame indices for 1fps
+     frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+     frames = vr.get_batch(frame_indices).asnumpy()
+     return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+ class QAlignScorer(nn.Module):
+     def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+         super().__init__()
+         if model is None:
+             tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+         prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"
+
+         self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+         self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
+
+         self.tokenizer = tokenizer
+         self.model = model
+         self.image_processor = image_processor
+         self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+     def expand2square(self, pil_img, background_color):
+         width, height = pil_img.size
+         if width == height:
+             return pil_img
+         elif width > height:
+             result = Image.new(pil_img.mode, (width, width), background_color)
+             result.paste(pil_img, (0, (width - height) // 2))
+             return result
+         else:
+             result = Image.new(pil_img.mode, (height, height), background_color)
+             result.paste(pil_img, ((height - width) // 2, 0))
+             return result
+
+     def forward(self, image: List[Image.Image]):
+         image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
+         with torch.inference_mode():
+             image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
+             output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
+                                        images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+
+             # probabilities over ["excellent", "good", "fair", "poor", "bad"];
+             # multiplying by self.weight_tensor would collapse them into a scalar score
+             return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+
+ class QAlignAestheticScorer(nn.Module):
+     def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+         super().__init__()
+         if model is None:
+             tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+         prompt = "USER: How would you rate the aesthetics of this image?\n<|image|>\nASSISTANT: The aesthetics of the image is"
+
+         self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+         self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
+
+         self.tokenizer = tokenizer
+         self.model = model
+         self.image_processor = image_processor
+         self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+     def expand2square(self, pil_img, background_color):
+         width, height = pil_img.size
+         if width == height:
+             return pil_img
+         elif width > height:
+             result = Image.new(pil_img.mode, (width, width), background_color)
+             result.paste(pil_img, (0, (width - height) // 2))
+             return result
+         else:
+             result = Image.new(pil_img.mode, (height, height), background_color)
+             result.paste(pil_img, ((height - width) // 2, 0))
+             return result
+
+     def forward(self, image: List[Image.Image]):
+         image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
+         with torch.inference_mode():
+             image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
+             output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
+                                        images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+
+             return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+ class QAlignVideoScorer(nn.Module):
+     def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
+         super().__init__()
+         if model is None:
+             tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
+         prompt = "USER: How would you rate the quality of this video?\n<|image|>\nASSISTANT: The quality of the video is"
+
+         self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
+         self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
+
+         self.tokenizer = tokenizer
+         self.model = model
+         self.image_processor = image_processor
+         self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
+
+     def expand2square(self, pil_img, background_color):
+         width, height = pil_img.size
+         if width == height:
+             return pil_img
+         elif width > height:
+             result = Image.new(pil_img.mode, (width, width), background_color)
+             result.paste(pil_img, (0, (width - height) // 2))
+             return result
+         else:
+             result = Image.new(pil_img.mode, (height, height), background_color)
+             result.paste(pil_img, ((height - width) // 2, 0))
+             return result
+
+     def forward(self, video: List[List[Image.Image]]):
+         video = [[self.expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in video]
+         with torch.inference_mode():
+             video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video]
+             output_logits = self.model(self.input_ids.repeat(len(video_tensors), 1),
+                                        images=video_tensors)["logits"][:,-1, self.preferential_ids_]
+             return torch.softmax(output_logits, -1) #@ self.weight_tensor
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default="q-future/one-align")
+     parser.add_argument("--device", type=str, default="cuda:0")
+     parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
+     parser.add_argument("--aesthetic", action="store_true")
+     parser.add_argument("--video", action="store_true")
+     args = parser.parse_args()
+
+     if args.video:
+         scorer = QAlignVideoScorer(pretrained=args.model_path, device=args.device)
+         print(scorer([load_video(args.img_path)]).tolist())
+     else:
+         scorer = QAlignScorer(pretrained=args.model_path, device=args.device) if not args.aesthetic else QAlignAestheticScorer(pretrained=args.model_path, device=args.device)
+         print(scorer([Image.open(args.img_path)]).tolist())
+
+
q_align/evaluate/vqa_eval.py ADDED
@@ -0,0 +1,167 @@
+ import argparse
+ import torch
+
+ from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+ from q_align.conversation import conv_templates, SeparatorStyle
+ from q_align.model.builder import load_pretrained_model
+ from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+
+ from PIL import Image
+
+ import requests
+ from PIL import Image
+ from io import BytesIO
+ from transformers import TextStreamer
+
+
+ from scipy.stats import spearmanr, pearsonr
+
+ import json
+ from tqdm import tqdm
+ from collections import defaultdict
+
+ import os
+
+ def wa5(logits):
+     import numpy as np
+     logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]])
+     probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
+     return np.inner(probs, np.array([1,0.75,0.5,0.25,0.]))
+
+
+
+ def disable_torch_init():
+     """
+     Disable the redundant torch default initialization to accelerate model creation.
+     """
+     import torch
+     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+ def load_video(video_file):
+     from decord import VideoReader
+     vr = VideoReader(video_file)
+
+     # Get video frame rate
+     fps = vr.get_avg_fps()
+
+     # Calculate frame indices for 1fps
+     frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
+     frames = vr.get_batch(frame_indices).asnumpy()
+     return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
+
+
+ def main(args):
+     # Model
+     disable_torch_init()
+
+     model_name = get_model_name_from_path(args.model_path)
+     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
+
+
+     import json
+
+
+     image_paths = [
+         #"playground/data/",
+         #"playground/data/",
+         "playground/data/KoNViD_1k_videos/",
+         "playground/data/maxwell/",
+
+     ]
+
+     json_prefix = "playground/data/test_jsons/"
+     jsons = [
+         #json_prefix + "test_lsvq.json",
+         #json_prefix + "test_lsvq_1080p.json",
+         json_prefix + "konvid.json",
+         json_prefix + "maxwell_test.json",
+     ]
+
+     os.makedirs(f"results/{args.model_path}/", exist_ok=True)
+
+
+     conv_mode = "mplug_owl2"
+
+     inp = "How would you rate the quality of this video?"
+
+     conv = conv_templates[conv_mode].copy()
+     inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
+     conv.append_message(conv.roles[0], inp)
+     image = None
+
+     conv.append_message(conv.roles[1], None)
+     prompt = conv.get_prompt() + " The quality of the video is"
+
+     toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
+     print(toks)
+     ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
+     print(ids_)
+
+     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)
+
+     for image_path, json_ in zip(image_paths, jsons):
+         with open(json_) as f:
+             iqadata = json.load(f)
+         prs, gts = [], []
+         for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
+             try:
+                 try:
+                     filename = llddata["img_path"]
+                 except:
+                     filename = llddata["image"]
+                 llddata["logits"] = defaultdict(float)
+
+                 image = load_video(image_path + filename)
+                 def expand2square(pil_img, background_color):
+                     width, height = pil_img.size
+                     if width == height:
+                         return pil_img
+                     elif width > height:
+                         result = Image.new(pil_img.mode, (width, width), background_color)
+                         result.paste(pil_img, (0, (width - height) // 2))
+                         return result
+                     else:
+                         result = Image.new(pil_img.mode, (height, height), background_color)
+                         result.paste(pil_img, ((height - width) // 2, 0))
+                         return result
+                 image = [expand2square(img, tuple(int(x*255) for x in image_processor.image_mean)) for img in image]
+                 image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)
+
+                 if True:
+                     with torch.inference_mode():
+                         output_logits = model(input_ids,
+                                               images=[image_tensor])["logits"][:,-1]
+                     for tok, id_ in zip(toks, ids_):
+                         llddata["logits"][tok] += output_logits.mean(0)[id_].item()
+                     llddata["score"] = wa5(llddata["logits"])
+                     # print(llddata)
+                     prs.append(llddata["score"])
+                     gts.append(llddata["gt_score"])
+                 # print(llddata)
+                 json_ = json_.replace("combined/", "combined-")
+                 with open(f"results/{args.model_path}/2{json_.split('/')[-1]}", "a") as wf:
+                     json.dump(llddata, wf)
+
+                 if i > 0 and i % 200 == 0:
+                     print(spearmanr(prs,gts)[0], pearsonr(prs,gts)[0])
+             except:
+                 continue
+         print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0])
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model-path", type=str, default="q-future/one-align")
+     parser.add_argument("--model-base", type=str, default=None)
+     parser.add_argument("--device", type=str, default="cuda:0")
+     parser.add_argument("--conv-mode", type=str, default=None)
+     parser.add_argument("--temperature", type=float, default=0.2)
+     parser.add_argument("--max-new-tokens", type=int, default=512)
+     parser.add_argument("--load-8bit", action="store_true")
+     parser.add_argument("--load-4bit", action="store_true")
+     parser.add_argument("--debug", action="store_true")
+     parser.add_argument("--image-aspect-ratio", type=str, default='pad')
+     args = parser.parse_args()
+     main(args)
q_align/mm_utils.py ADDED
@@ -0,0 +1,112 @@
+ from PIL import Image
+ from io import BytesIO
+ import base64
+
+ import torch
+ from transformers import StoppingCriteria
+ from q_align.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
+ from icecream import ic
+
+
+ def load_image_from_base64(image):
+     return Image.open(BytesIO(base64.b64decode(image)))
+
+
+ def expand2square(pil_img, background_color):
+     width, height = pil_img.size
+     if width == height:
+         return pil_img
+     elif width > height:
+         result = Image.new(pil_img.mode, (width, width), background_color)
+         result.paste(pil_img, (0, (width - height) // 2))
+         return result
+     else:
+         result = Image.new(pil_img.mode, (height, height), background_color)
+         result.paste(pil_img, ((height - width) // 2, 0))
+         return result
+
+
+ def process_images(images, image_processor, model_cfg=None):
+     if model_cfg is not None:
+         image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+     else:
+         image_aspect_ratio = 'resize'
+     new_images = []
+     if image_aspect_ratio == 'pad':
+         for image in images:
+             image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+             image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+             new_images.append(image)
+     elif image_aspect_ratio == 'resize':
+         for image in images:
+             max_edge = max(image.size)
+             image = image.resize((max_edge, max_edge))
+             image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+             new_images.append(image)
+     else:
+         return image_processor(images, return_tensors='pt')['pixel_values']
+     if all(x.shape == new_images[0].shape for x in new_images):
+         new_images = torch.stack(new_images, dim=0)
+     return new_images
+
+
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+     prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
+
+     def insert_separator(X, sep):
+         return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+     input_ids = []
+     offset = 0
+     if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+         offset = 1
+         input_ids.append(prompt_chunks[0][0])
+
+     for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+         input_ids.extend(x[offset:])
+
+     if return_tensors is not None:
+         if return_tensors == 'pt':
+             return torch.tensor(input_ids, dtype=torch.long)
+         raise ValueError(f'Unsupported tensor type: {return_tensors}')
+     return input_ids
+
+
+ def get_model_name_from_path(model_path):
+     model_path = model_path.strip("/")
+     model_paths = model_path.split("/")
+     if model_paths[-1].startswith('checkpoint-'):
+         return model_paths[-2] + "_" + model_paths[-1]
+     else:
+         return model_paths[-1]
+
+
+
+
+ class KeywordsStoppingCriteria(StoppingCriteria):
+     def __init__(self, keywords, tokenizer, input_ids):
+         self.keywords = keywords
+         self.keyword_ids = []
+         self.max_keyword_len = 0
+         for keyword in keywords:
+             cur_keyword_ids = tokenizer(keyword).input_ids
+             if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+                 cur_keyword_ids = cur_keyword_ids[1:]
+             if len(cur_keyword_ids) > self.max_keyword_len:
+                 self.max_keyword_len = len(cur_keyword_ids)
+             self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+         self.tokenizer = tokenizer
+         self.start_len = input_ids.shape[1]
+
+     def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
+         offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+         self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+         for keyword_id in self.keyword_ids:
+             if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
+                 return True
+         outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+         for keyword in self.keywords:
+             if keyword in outputs:
+                 return True
+         return False
q_align/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
+ from .configuration_mplug_owl2 import MPLUGOwl2Config
q_align/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (297 Bytes). View file
 
q_align/model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (318 Bytes). View file
 
q_align/model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (268 Bytes). View file
 
q_align/model/__pycache__/builder.cpython-310.pyc ADDED
Binary file (3.8 kB). View file
 
q_align/model/__pycache__/builder.cpython-311.pyc ADDED
Binary file (7.42 kB). View file
 
q_align/model/__pycache__/builder.cpython-39.pyc ADDED
Binary file (3.81 kB). View file
 
q_align/model/__pycache__/configuration_mplug_owl2.cpython-310.pyc ADDED
Binary file (13.2 kB). View file
 
q_align/model/__pycache__/configuration_mplug_owl2.cpython-311.pyc ADDED
Binary file (16.7 kB). View file
 
q_align/model/__pycache__/configuration_mplug_owl2.cpython-39.pyc ADDED
Binary file (13.2 kB). View file
 
q_align/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc ADDED
Binary file (7.59 kB). View file
 
q_align/model/__pycache__/modeling_attn_mask_utils.cpython-311.pyc ADDED
Binary file (11.2 kB). View file
 
q_align/model/__pycache__/modeling_attn_mask_utils.cpython-39.pyc ADDED
Binary file (7.49 kB). View file
 
q_align/model/__pycache__/modeling_llama2.cpython-310.pyc ADDED
Binary file (13.3 kB). View file
 
q_align/model/__pycache__/modeling_llama2.cpython-311.pyc ADDED
Binary file (24.3 kB). View file
 
q_align/model/__pycache__/modeling_llama2.cpython-39.pyc ADDED
Binary file (13.1 kB). View file
 
q_align/model/__pycache__/modeling_mplug_owl2.cpython-310.pyc ADDED
Binary file (9.69 kB). View file
 
q_align/model/__pycache__/modeling_mplug_owl2.cpython-311.pyc ADDED
Binary file (21 kB). View file
 
q_align/model/__pycache__/modeling_mplug_owl2.cpython-39.pyc ADDED
Binary file (9.63 kB). View file