Spaces:
Runtime error
Runtime error
haoning.wu
committed on
Commit
·
e63f3e2
1
Parent(s):
a23f4af
Scorer Starts
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +117 -1
- q_align/.ipynb_checkpoints/utils-checkpoint.py +128 -0
- q_align/__init__.py +1 -0
- q_align/__pycache__/__init__.cpython-310.pyc +0 -0
- q_align/__pycache__/__init__.cpython-311.pyc +0 -0
- q_align/__pycache__/__init__.cpython-39.pyc +0 -0
- q_align/__pycache__/constants.cpython-310.pyc +0 -0
- q_align/__pycache__/constants.cpython-311.pyc +0 -0
- q_align/__pycache__/constants.cpython-39.pyc +0 -0
- q_align/__pycache__/conversation.cpython-310.pyc +0 -0
- q_align/__pycache__/conversation.cpython-311.pyc +0 -0
- q_align/__pycache__/conversation.cpython-39.pyc +0 -0
- q_align/__pycache__/mm_utils.cpython-310.pyc +0 -0
- q_align/__pycache__/mm_utils.cpython-311.pyc +0 -0
- q_align/__pycache__/mm_utils.cpython-39.pyc +0 -0
- q_align/__pycache__/utils.cpython-311.pyc +0 -0
- q_align/constants.py +9 -0
- q_align/conversation.py +301 -0
- q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py +164 -0
- q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py +150 -0
- q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py +156 -0
- q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py +155 -0
- q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py +167 -0
- q_align/evaluate/__pycache__/scorer.cpython-311.pyc +0 -0
- q_align/evaluate/eval.py +138 -0
- q_align/evaluate/iaa_eval.py +164 -0
- q_align/evaluate/iqa4vqa_eval.py +150 -0
- q_align/evaluate/iqa_eval.py +156 -0
- q_align/evaluate/scorer.py +155 -0
- q_align/evaluate/vqa_eval.py +167 -0
- q_align/mm_utils.py +112 -0
- q_align/model/__init__.py +2 -0
- q_align/model/__pycache__/__init__.cpython-310.pyc +0 -0
- q_align/model/__pycache__/__init__.cpython-311.pyc +0 -0
- q_align/model/__pycache__/__init__.cpython-39.pyc +0 -0
- q_align/model/__pycache__/builder.cpython-310.pyc +0 -0
- q_align/model/__pycache__/builder.cpython-311.pyc +0 -0
- q_align/model/__pycache__/builder.cpython-39.pyc +0 -0
- q_align/model/__pycache__/configuration_mplug_owl2.cpython-310.pyc +0 -0
- q_align/model/__pycache__/configuration_mplug_owl2.cpython-311.pyc +0 -0
- q_align/model/__pycache__/configuration_mplug_owl2.cpython-39.pyc +0 -0
- q_align/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc +0 -0
- q_align/model/__pycache__/modeling_attn_mask_utils.cpython-311.pyc +0 -0
- q_align/model/__pycache__/modeling_attn_mask_utils.cpython-39.pyc +0 -0
- q_align/model/__pycache__/modeling_llama2.cpython-310.pyc +0 -0
- q_align/model/__pycache__/modeling_llama2.cpython-311.pyc +0 -0
- q_align/model/__pycache__/modeling_llama2.cpython-39.pyc +0 -0
- q_align/model/__pycache__/modeling_mplug_owl2.cpython-310.pyc +0 -0
- q_align/model/__pycache__/modeling_mplug_owl2.cpython-311.pyc +0 -0
- q_align/model/__pycache__/modeling_mplug_owl2.cpython-39.pyc +0 -0
app.py
CHANGED
@@ -1,3 +1,119 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
import argparse
|
4 |
+
import datetime
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import requests
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
from q_align.model.builder import load_pretrained_model
|
14 |
+
|
15 |
+
from q_align.conversation import (default_conversation, conv_templates,
|
16 |
+
SeparatorStyle)
|
17 |
+
from q_align.constants import LOGDIR
|
18 |
+
from q_align.utils import (build_logger, server_error_msg,
|
19 |
+
violates_moderation, moderation_msg)
|
20 |
+
|
21 |
+
from q_align.evaluate.scorer import QAlignScorer, QAlignAestheticScorer, QAlignVideoScorer
|
22 |
+
|
23 |
+
import gradio as gr
|
24 |
+
|
25 |
+
def load_video(video_file):
    """Sample roughly one frame per second from a video as PIL images.

    Args:
        video_file: Video source handed unchanged to ``decord.VideoReader``.

    Returns:
        A list of ``PIL.Image.Image`` frames, one per whole second of video.
        Clips shorter than one second yield their first frame instead of an
        empty list.
    """
    # Imported lazily so decord is only needed when a video is actually scored.
    from decord import VideoReader
    vr = VideoReader(video_file)

    total_frames = len(vr)
    # Get video frame rate
    fps = vr.get_avg_fps()

    # Number of whole seconds covered by the clip; guard against a zero frame
    # rate, which would otherwise raise ZeroDivisionError.
    num_seconds = int(total_frames / fps) if fps > 0 else 0

    if num_seconds > 0:
        # Calculate frame indices for 1fps
        frame_indices = [int(fps * i) for i in range(num_seconds)]
    else:
        # Sub-second clip: fall back to the first frame so the caller never
        # receives an empty list.
        frame_indices = [0]

    frames = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(frame) for frame in frames]
|
36 |
+
|
37 |
+
|
38 |
+
# Checkpoint identifier handed to load_pretrained_model.
pretrained="q-future/one-align"
# Device string passed to the loader.
device="cuda:0"
# The model is loaded once, at import time, and the same tokenizer / model /
# image processor are shared by all three scorers below. The fourth return
# value of the loader is not used here.
tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)

iqa_scorer = QAlignScorer(tokenizer=tokenizer, model=model, image_processor=image_processor)
iaa_scorer = QAlignAestheticScorer(tokenizer=tokenizer, model=model, image_processor=image_processor)
vqa_scorer = QAlignVideoScorer(tokenizer=tokenizer, model=model, image_processor=image_processor)

# Maps the task label offered in the UI to the scorer that handles it.
scorers = {"Image Aesthetics (IAA)": iaa_scorer, "Image Quality (IQA)": iqa_scorer, "Video Quality (VQA)": vqa_scorer}

# Display names of the five rating levels, listed in the same order as `scores`.
LEVELS = ["excellent (5)", "good (4)", "fair (3)", "poor (2)", "bad (1)"]
# Numeric value of each level, used to weight the level probabilities.
scores = [5,4,3,2,1]
|
50 |
+
def image_classifier(input_img, input_vid, scorer_type):
    """Score an uploaded image or video with the selected scorer.

    Args:
        input_img: Uploaded image, or ``None``. Ignored when a video is given.
        input_vid: Uploaded video, or ``None``. Takes precedence over the image.
        scorer_type: A key of ``scorers``; ``None`` selects
            ``"Image Quality (IQA)"``.

    Returns:
        A ``(prob_dict, score)`` tuple. ``prob_dict`` maps each entry of
        ``LEVELS`` to its probability; ``score`` is the sum of each
        probability multiplied by the matching entry of ``scores``.

    Raises:
        ValueError: If neither an image nor a video was provided.
        KeyError: If ``scorer_type`` is not a key of ``scorers``.
    """
    if scorer_type is None:
        scorer_type = "Image Quality (IQA)"
    this_scorer = scorers[scorer_type]

    if input_vid is not None:
        input_ = load_video(input_vid)
    elif input_img is not None:
        input_ = [input_img]
    else:
        # Previously this path fell through with `input_` unbound and crashed
        # with an UnboundLocalError; report the real problem instead.
        raise ValueError("Please upload an image or a video before submitting.")

    if "Video" in scorer_type:
        # The video scorer receives the whole frame list wrapped as one item.
        input_ = [input_]

    probs = this_scorer(input_).mean(0).tolist()
    prob_dict = dict(zip(LEVELS, probs))
    score = sum(prob * level_score for level_score, prob in zip(scores, probs))
    return prob_dict, score
|
64 |
+
|
65 |
+
title_markdown = ("""
|
66 |
+
|
67 |
+
<h1 align="center">Q-Align: Teaching LMMs for Visual Scoring via Discrete Text-Defined Levels</h1>
|
68 |
+
|
69 |
+
<h3 align="center"> One Unified Model for Visual scoring. </h3>
|
70 |
+
|
71 |
+
<h5 align="center">
|
72 |
+
<a href="https://teowu.github.io/" target="_blank">Haoning Wu</a><sup>1</sup><sup>*</sup><sup>+</sup>,
|
73 |
+
<a href="https://github.com/zzc-1998" target="_blank">Zicheng Zhang</a><sup>2</sup><sup>*</sup>,
|
74 |
+
<a href="https://sites.google.com/view/r-panda" target="_blank">Weixia Zhang</a><sup>2</sup>,
|
75 |
+
<a href="https://chaofengc.github.io" target="_blank">Chaofeng Chen</a><sup>1</sup>,
|
76 |
+
<a href="https://liaoliang92.github.io" target="_blank">Liang Liao</a><sup>1</sup>,
|
77 |
+
<a href="https://github.com/lcysyzxdxc" target="_blank">Chunyi Li</a><sup>2</sup>,
|
78 |
+
</h5>
|
79 |
+
|
80 |
+
|
81 |
+
<h5 align="center">
|
82 |
+
<a href="https://github.com/YixuanGao98" target="_blank">Yixuan Gao</a><sup>2</sup>,
|
83 |
+
<a href="https://github.com/AnnanWangDaniel" target="_blank">Annan Wang</a><sup>1</sup>,
|
84 |
+
<a href="https://github.com/ZhangErliCarl/" target="_blank">Erli Zhang</a><sup>1</sup>,
|
85 |
+
<a href="https://wenxiusun.com" target="_blank">Wenxiu Sun</a><sup>3</sup>,
|
86 |
+
<a href="https://scholar.google.com/citations?user=uT9CtPYAAAAJ&hl=en" target="_blank">Qiong Yan</a><sup>3</sup>,
|
87 |
+
<a href="https://sites.google.com/site/minxiongkuo/" target="_blank">Xiongkuo Min</a><sup>2</sup>,
|
88 |
+
<a href="https://ee.sjtu.edu.cn/en/FacultyDetail.aspx?id=24&infoid=153&flag=153" target="_blank">Guangtao Zhai</a><sup>2</sup><sup>#</sup>,
|
89 |
+
<a href="https://personal.ntu.edu.sg/wslin/Home.html" target="_blank">Weisi Lin</a><sup>1</sup><sup>#</sup>
|
90 |
+
</h5>
|
91 |
+
|
92 |
+
<h5 align="center">
|
93 |
+
<sup>1</sup>Nanyang Technological University, <sup>2</sup>Shanghai Jiao Tong University, <sup>3</sup>Sensetime Research
|
94 |
+
</h5>
|
95 |
+
<h5 align="center">
|
96 |
+
<sup>*</sup>Equal contribution. <sup>+</sup>Project Lead. <sup>#</sup>Corresponding author(s).
|
97 |
+
</h5>
|
98 |
+
|
99 |
+
<h4 align="center"> If you like the OneScorer, please give us a star ✨ on <a href='https://github.com/Q-Future/Q-Align'>GitHub</a> for latest update. </h4>
|
100 |
+
|
101 |
+
<h5 align="center">
|
102 |
+
<div style="display:flex; gap: 0.25rem;" align="center">
|
103 |
+
<a href='https://q-align.github.io'><img src='https://img.shields.io/badge/Homepage-green'></a>
|
104 |
+
<a href='https://github.com/Q-Future/Q-Align'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
|
105 |
+
<a href="https://Q-Future.github.io/Q-Align/fig/Q_Align_v0_1_preview.pdf"><img src="https://img.shields.io/badge/Technical-Report-red"></a>
|
106 |
+
<a href='https://github.com/Q-Future/Q-Align/stargazers'><img src='https://img.shields.io/github/stars/Q-Future/Q-Align.svg?style=social'></a>
|
107 |
+
</div>
|
108 |
+
</h5>
|
109 |
+
|
110 |
+
""")
|
111 |
+
|
112 |
+
|
113 |
+
# Input widgets. The image is handed to the callback as a PIL image.
input_img = gr.Image(type='pil', label="Upload an Image")
# NOTE(review): `info=` is intentionally not passed to gr.Video. It is not
# among the documented gr.Video parameters and only repeated the label text —
# confirm against the installed Gradio version.
input_vid = gr.Video(label="Upload a Video (will IGNORE the image if a video is uploaded)")

# Output widgets: per-level probabilities and the weighted score.
labels = gr.Label(label="Probabilities of rating levels:")
number = gr.Number(label="Output score:", info="Range in [1,5]. Higher is better.")

# The radio choices must stay identical to the keys of `scorers`.
demo = gr.Interface(
    fn=image_classifier,
    inputs=[
        input_img,
        input_vid,
        gr.Radio(["Image Aesthetics (IAA)", "Image Quality (IQA)", "Video Quality (VQA)"], label="Task", info="Which Scorer will you need?"),
    ],
    outputs=[labels, number],
    title="OneScorer",
    description=title_markdown,
)
demo.launch(share=True)
|
q_align/.ipynb_checkpoints/utils-checkpoint.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
import logging
|
3 |
+
import logging.handlers
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import requests
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
from q_align.constants import LOGDIR
|
12 |
+
|
13 |
+
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
|
14 |
+
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
|
15 |
+
|
16 |
+
handler = None
|
17 |
+
|
18 |
+
|
19 |
+
def build_logger(logger_name, logger_filename):
    """Return a named logger and route process output into a shared log file.

    Side effects:
        * Applies a common formatter to the first root handler (creating one
          via ``logging.basicConfig`` if the root logger has none).
        * Replaces ``sys.stdout`` and ``sys.stderr`` with ``StreamToLogger``
          objects bound to the "stdout" and "stderr" loggers.
        * On the first call only, creates ``LOGDIR`` and a daily-rotating file
          handler, and attaches that handler to every logger registered at
          that moment.

    Args:
        logger_name: Name of the logger to return.
        logger_filename: File name, relative to ``LOGDIR``, of the log file.
            Only the value from the first call is used.

    Returns:
        The ``logging.Logger`` named ``logger_name``, set to INFO level.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        # An explicit UTF-8 encoding keeps non-ASCII log messages from raising
        # UnicodeEncodeError when the platform default encoding is narrower.
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='utf-8')
        handler.setFormatter(formatter)

        # loggerDict also holds placeholder entries; only real loggers can
        # take a handler.
        for item in logging.root.manager.loggerDict.values():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger
|
60 |
+
|
61 |
+
|
62 |
+
class StreamToLogger(object):
    """
    File-like adapter that forwards every complete line written to it to a logger.
    """
    def __init__(self, logger, log_level=logging.INFO):
        # Remember the stream that was active at construction time; unknown
        # attribute lookups are delegated to it.
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        # Holds a trailing partial line until its newline arrives.
        self.linebuf = ''

    def __getattr__(self, attr):
        # Only invoked for attributes this object does not define itself.
        return getattr(self.terminal, attr)

    def write(self, buf):
        """Log each newline-terminated line in ``buf``; keep any remainder buffered."""
        pending = self.linebuf + buf
        self.linebuf = ''
        for piece in pending.splitlines(True):
            # splitlines(True) keeps the line ending, so a piece that ends
            # with '\n' is a finished line. Anything else stays buffered.
            if piece.endswith('\n'):
                self.logger.log(self.log_level, piece.rstrip())
                continue
            self.linebuf += piece

    def flush(self):
        """Emit whatever partial line is still buffered, then clear the buffer."""
        if not self.linebuf:
            return
        self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''
|
93 |
+
|
94 |
+
|
95 |
+
def disable_torch_init():
    """
    Turn the default parameter initialization of torch Linear and LayerNorm
    layers into a no-op, to accelerate model creation.
    """
    import torch

    def _skip_reset(self):
        return None

    # Patch the classes themselves, so every instance created afterwards is affected.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = _skip_reset
|
102 |
+
|
103 |
+
|
104 |
+
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Args:
        text: The text to check. Newlines are removed before it is sent.

    Returns:
        The ``flagged`` value of the first moderation result. ``False`` when
        the request fails or the response does not have the expected shape
        (best effort: a failed check does not block the input).

    Raises:
        KeyError: If the ``OPENAI_API_KEY`` environment variable is not set.
    """
    # Local import: this block is the only user and the file's import header
    # is left untouched.
    import json

    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # Build the body with json.dumps so quotes, backslashes and control
    # characters in the text are escaped. Gluing the text between literal
    # quotes produced invalid JSON for such input.
    data = json.dumps({"input": text}, ensure_ascii=False)
    data = data.encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError, TypeError, ValueError):
        # Error payloads, empty result lists or non-JSON bodies.
        flagged = False

    return flagged
|
123 |
+
|
124 |
+
|
125 |
+
def pretty_print_semaphore(semaphore):
    """Return a one-line description of ``semaphore``, or "None" when it is None."""
    if semaphore is not None:
        return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
    return "None"
|
q_align/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model import MPLUGOwl2LlamaForCausalLM
|
q_align/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (213 Bytes). View file
|
|
q_align/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (218 Bytes). View file
|
|
q_align/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (184 Bytes). View file
|
|
q_align/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (358 Bytes). View file
|
|
q_align/__pycache__/constants.cpython-311.pyc
ADDED
Binary file (371 Bytes). View file
|
|
q_align/__pycache__/constants.cpython-39.pyc
ADDED
Binary file (329 Bytes). View file
|
|
q_align/__pycache__/conversation.cpython-310.pyc
ADDED
Binary file (8.55 kB). View file
|
|
q_align/__pycache__/conversation.cpython-311.pyc
ADDED
Binary file (14.8 kB). View file
|
|
q_align/__pycache__/conversation.cpython-39.pyc
ADDED
Binary file (8.53 kB). View file
|
|
q_align/__pycache__/mm_utils.cpython-310.pyc
ADDED
Binary file (4.57 kB). View file
|
|
q_align/__pycache__/mm_utils.cpython-311.pyc
ADDED
Binary file (8.4 kB). View file
|
|
q_align/__pycache__/mm_utils.cpython-39.pyc
ADDED
Binary file (4.51 kB). View file
|
|
q_align/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (6.92 kB). View file
|
|
q_align/constants.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# NOTE(review): presumably heartbeat timings (seconds?) for a controller/worker
# serving setup; nothing in the visible code reads them — confirm before relying on units.
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

# Directory in which the logging helper creates its log file.
LOGDIR = "./demo_logs"

# Model Constants
# NOTE(review): presumably the label value excluded from the training loss — confirm.
IGNORE_INDEX = -100
# Token id passed to tokenizer_image_token to stand for an image in the input ids.
IMAGE_TOKEN_INDEX = -200
# Placeholder text that marks the image position inside a prompt string.
DEFAULT_IMAGE_TOKEN = "<|image|>"
|
q_align/conversation.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
from q_align.constants import DEFAULT_IMAGE_TOKEN
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
    """Different separator style."""
    # System text, then "role: message" turns, each followed by `sep`.
    SINGLE = auto()
    # Like SINGLE, but successive turns alternate between `sep` and `sep2`.
    TWO = auto()
    # Like TWO, but the system text is left out of the prompt.
    TWO_NO_SYS = auto()
    # System text, then role immediately followed by the message and `sep`.
    MPT = auto()
    # System text, then bare messages (no role prefix) alternating `sep` / `sep2`.
    PLAIN = auto()
    # "[INST] ... [/INST]" wrapping, with the system text in a <<SYS>> block.
    LLAMA_2 = auto()
|
14 |
+
|
15 |
+
|
16 |
+
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    # System text; most separator styles put it at the start of the prompt.
    system: str
    # Speaker names; roles[0] is expected to open the conversation (see LLAMA_2 asserts).
    roles: List[str]
    # [role, message] pairs. A message is a string, a falsy value for a turn
    # still to be generated, or a (text, image, image_process_mode) tuple.
    messages: List[List[str]]
    # Number of leading messages skipped by get_images / to_gradio_chatbot.
    offset: int
    # Selects the prompt layout produced by get_prompt.
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    # Primary separator; `sep2` is the alternate one for two-separator styles.
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    # Not read anywhere in this class. NOTE(review): presumably a flag for the
    # serving UI — confirm against callers.
    skip_next: bool = False

    def get_prompt(self):
        """Render the conversation as one prompt string according to `sep_style`.

        Returns:
            The prompt text.

        Raises:
            ValueError: If `sep_style` is not one of the handled styles.
        """
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # The first message carries an image: work on a shallow copy so
            # self.messages is untouched, and rewrite the first entry so the
            # image token leads the text exactly once.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            # init_msg = init_msg[0].replace("<image>", "").strip()
            # if 'mmtag' in self.version:
            #     messages[0] = (init_role, init_msg)
            #     messages.insert(0, (self.roles[0], "<Image><image></Image>"))
            #     messages.insert(1, (self.roles[1], "Received."))
            # else:
            #     messages[0] = (init_role, "<image>\n" + init_msg)
            init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
            messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        # Keep only the text part of an image message.
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    # Empty message: leave the role open for generation.
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    # Even-indexed turns end with sep, odd-indexed with sep2.
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
            seps = [self.sep, self.sep2]
            # Same as TWO, except the system text is not emitted.
            ret = ""
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    # No ": " between role and message in this style.
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    # The system block is folded into the first message.
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            # NOTE(review): str.lstrip removes any leading characters found in
            # `sep`, not the `sep` string as a prefix — confirm this is intended.
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    # Roles are not written in this style.
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a [role, message] pair to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect the images attached to even-indexed messages after `offset`.

        Args:
            return_pil: When true, return the image objects themselves;
                otherwise return base64-encoded PNG strings.

        Returns:
            A list with one entry per image-carrying message.

        Raises:
            ValueError: If a message has an unknown image_process_mode.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            # Only even positions (counted from `offset`) are inspected.
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Center the image on a square canvas whose side is
                            # the longer edge, filled with background_color.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    # Cap the size: shortest edge at most 400 and longest edge
                    # at most about 800, keeping the aspect ratio.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    # Resize only when the computed size differs from the current one.
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Convert the history after `offset` into [user_message, reply] pairs.

        Image messages are rendered as an inline ``<img>`` tag followed by the
        message text with the image token removed.

        Returns:
            A list of two-item lists; the reply slot is ``None`` until the
            following odd-indexed message fills it.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    # Same size cap as in get_images, but the resize is always applied.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    # NOTE(review): the bytes are JPEG but the data URI below
                    # declares image/png — confirm whether this mismatch is intended.
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<|image|>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # Odd position: fill the reply slot of the preceding pair.
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a new Conversation with a fresh message list.

        Each [role, message] pair is rebuilt as a new list; `skip_next` is not
        carried over and falls back to its default.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Return the conversation as a plain dict.

        When the conversation holds images, each tuple message is reduced to
        its first element (the text); otherwise messages are returned as stored.
        """
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
|
232 |
+
|
233 |
+
|
234 |
+
conv_vicuna_v0 = Conversation(
|
235 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
236 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
237 |
+
roles=("Human", "Assistant"),
|
238 |
+
messages=(
|
239 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
240 |
+
("Assistant",
|
241 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
242 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
243 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
244 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
245 |
+
"renewable and non-renewable energy sources:\n"
|
246 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
247 |
+
"energy sources are finite and will eventually run out.\n"
|
248 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
249 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
250 |
+
"and other negative effects.\n"
|
251 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
252 |
+
"have lower operational costs than non-renewable sources.\n"
|
253 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
254 |
+
"locations than non-renewable sources.\n"
|
255 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
256 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
257 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
258 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
259 |
+
),
|
260 |
+
offset=2,
|
261 |
+
sep_style=SeparatorStyle.SINGLE,
|
262 |
+
sep="###",
|
263 |
+
)
|
264 |
+
|
265 |
+
# Empty template using the TWO style: system text, then turns alternating
# between " " and "</s>" as separators.
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# Empty template using TWO_NO_SYS: the `system` text below is stored but is
# not included in the prompt produced by get_prompt.
conv_mplug_owl2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO_NO_SYS,
    sep=" ",
    sep2="</s>",
)

# default_conversation = conv_vicuna_v1
default_conversation = conv_mplug_owl2
# Template registry keyed by name. Note that the "default" key maps to
# conv_vicuna_v0, which differs from `default_conversation` above.
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "mplug_owl2": conv_mplug_owl2,
}


# When run as a script, print the prompt rendered from the default template.
if __name__ == "__main__":
    print(default_conversation.get_prompt())
|
q_align/evaluate/.ipynb_checkpoints/iaa_eval-checkpoint.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from q_align.conversation import conv_templates, SeparatorStyle
|
6 |
+
from q_align.model.builder import load_pretrained_model
|
7 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
from PIL import ImageFile
|
11 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
12 |
+
|
13 |
+
import requests
|
14 |
+
from PIL import Image
|
15 |
+
from io import BytesIO
|
16 |
+
from transformers import TextStreamer
|
17 |
+
|
18 |
+
from scipy.stats import spearmanr, pearsonr
|
19 |
+
|
20 |
+
|
21 |
+
import json
|
22 |
+
from tqdm import tqdm
|
23 |
+
from collections import defaultdict
|
24 |
+
|
25 |
+
import os
|
26 |
+
|
27 |
+
def wa5(logits):
    """Turn five rating-level logits into a weighted-average score.

    Args:
        logits: Mapping with the keys "excellent", "good", "fair", "poor"
            and "bad", each holding that level's logit.

    Returns:
        The softmax of the five logits dotted with the weights
        [1, 0.75, 0.5, 0.25, 0], i.e. a value between 0 and 1.
    """
    import numpy as np
    logprobs = np.array([logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]])
    # Softmax is unchanged by subtracting a constant; shifting by the maximum
    # keeps exp() from overflowing to inf (and the ratio from becoming nan)
    # when a logit is large.
    exps = np.exp(logprobs - np.max(logprobs))
    probs = exps / np.sum(exps)
    return np.inner(probs, np.array([1,0.75,0.5,0.25,0.]))
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
def disable_torch_init():
    """
    Replace the default parameter initialization of torch Linear and LayerNorm
    layers with a no-op, to accelerate model creation.
    """
    import torch

    def _noop_reset(self):
        return None

    # Assigned on the classes, so every layer built afterwards skips its reset.
    torch.nn.Linear.reset_parameters = _noop_reset
    torch.nn.LayerNorm.reset_parameters = _noop_reset
|
43 |
+
|
44 |
+
|
45 |
+
def load_image(image_file, timeout=30):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with "http://" or
            "https://".
        timeout: Seconds to wait for the remote server; only used for URLs.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.RequestException: If the download times out, the connection
            fails, or the server answers with an error status.
    """
    if image_file.startswith(('http://', 'https://')):
        # Without a timeout a stalled server would block the run forever.
        response = requests.get(image_file, timeout=timeout)
        # Surface HTTP errors here rather than handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    return Image.open(image_file).convert('RGB')
|
52 |
+
|
53 |
+
|
54 |
+
def main(args):
    """Score the aesthetics test set and print SRCC / PLCC against the labels.

    Loads the model named by ``args.model_path``, reads the next-token logits
    for a fixed list of rating words for every image in each test JSON, turns
    them into a score with ``wa5`` and appends each record as one JSON line to
    ``results/<model_path>/<json file name>``.

    Args:
        args: Parsed command-line namespace; this function reads
            ``model_path``, ``model_base``, ``load_8bit``, ``load_4bit`` and
            ``device``.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        args.model_path, args.model_base, model_name,
        args.load_8bit, args.load_4bit, device=args.device)

    image_path = "playground/data/"

    json_prefix = "playground/data/test_jsons/"
    jsons = [
        json_prefix + "test_ava.json",
    ]

    result_dir = f"results/{args.model_path}/"
    os.makedirs(result_dir, exist_ok=True)

    conv_mode = "mplug_owl2"
    conv = conv_templates[conv_mode].copy()
    conv.append_message(
        conv.roles[0],
        DEFAULT_IMAGE_TOKEN + "How would you rate the aesthetics of this image?")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt() + " The aesthetics of the image is"

    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
    print(toks)
    # Index 1 skips whatever the tokenizer emits first (presumably a BOS
    # token -- TODO confirm against the tokenizer in use).
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)

    # Loop invariants: the padding colour and the padding helper do not depend
    # on the individual image, so build them once.
    background = tuple(int(x * 255) for x in image_processor.image_mean)
    batch_size = 8

    def expand2square(pil_img, background_color):
        """Pad *pil_img* to a centred square filled with *background_color*."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    for json_ in jsons:
        with open(json_) as f:
            iqadata = json.load(f)

        result_name = json_.replace("combined/", "combined-").split("/")[-1]
        result_path = result_dir + result_name

        image_tensors = []
        batch_data = []
        prs, gts = [], []
        progress = tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))
        for i, llddata in enumerate(progress):
            llddata["logits"] = defaultdict(float)

            image = expand2square(load_image(image_path + llddata["image"]), background)
            image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)

            image_tensors.append(image_tensor)
            batch_data.append(llddata)

            # Run the model once a batch is full, or on the final image.
            if i % batch_size != batch_size - 1 and i != len(iqadata) - 1:
                continue

            with torch.inference_mode():
                output_logits = model(input_ids.repeat(len(image_tensors), 1),
                                      images=torch.cat(image_tensors, 0))["logits"][:, -1]

            with open(result_path, "a") as wf:
                for j, xllddata in enumerate(batch_data):
                    for tok, id_ in zip(toks, ids_):
                        xllddata["logits"][tok] += output_logits[j, id_].item()
                    xllddata["score"] = wa5(xllddata["logits"])
                    prs.append(xllddata["score"])
                    gts.append(xllddata["gt_score"])
                    json.dump(xllddata, wf)
                    # One record per line; without a delimiter the appended
                    # objects run together and the file cannot be split back
                    # into records by ordinary JSON tooling.
                    wf.write("\n")

            image_tensors = []
            batch_data = []

        print("Spearmanr", spearmanr(prs, gts)[0], "Pearson", pearsonr(prs, gts)[0])
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == "__main__":
    # Command-line options in declaration order. main() in this file reads
    # model-path, model-base, device, load-8bit and load-4bit; the remaining
    # options are accepted but not read by it.
    cli_options = (
        ("--model-path", dict(type=str, default="q-future/one-align")),
        ("--model-base", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda:0")),
        ("--conv-mode", dict(type=str, default=None)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--image-aspect-ratio", dict(type=str, default='pad')),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
|
q_align/evaluate/.ipynb_checkpoints/iqa4vqa_eval-checkpoint.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from q_align.conversation import conv_templates, SeparatorStyle
|
6 |
+
from q_align.model.builder import load_pretrained_model
|
7 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from PIL import Image
|
13 |
+
from io import BytesIO
|
14 |
+
from transformers import TextStreamer
|
15 |
+
|
16 |
+
from decord import VideoReader
|
17 |
+
|
18 |
+
|
19 |
+
import json
|
20 |
+
from tqdm import tqdm
|
21 |
+
from collections import defaultdict
|
22 |
+
|
23 |
+
import os
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
def disable_torch_init():
    """
    Turn the default parameter initialisation of Linear and LayerNorm into a
    no-op so that model creation does not spend time on it.
    """
    import torch

    # Replace reset_parameters on both layer classes with a do-nothing stub.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", lambda self: None)
|
35 |
+
|
36 |
+
|
37 |
+
def load_video(video_file):
    """Return frames sampled at about one per second as a list of PIL images."""
    reader = VideoReader(video_file)

    # Average frame rate reported by the reader.
    frame_rate = reader.get_avg_fps()

    # One index per whole second of footage: 0, fps, 2*fps, ...
    whole_seconds = int(len(reader) / frame_rate)
    picks = [int(frame_rate * second) for second in range(whole_seconds)]
    sampled = reader.get_batch(picks).asnumpy()
    return [Image.fromarray(sampled[k]) for k in range(whole_seconds)]
|
47 |
+
|
48 |
+
|
49 |
+
def main(args):
    """Rate videos with the image-quality prompt, averaging over sampled frames.

    For every entry of each test JSON the video is sampled with
    ``load_video``, every frame is passed through the model, and the
    next-token logits of a fixed list of rating words are averaged over the
    frames. Each record is appended as one JSON line to
    ``results/<model_path>/<json file name>``.

    Args:
        args: Parsed command-line namespace; this function reads
            ``model_path``, ``model_base``, ``load_8bit``, ``load_4bit`` and
            ``device``.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        args.model_path, args.model_base, model_name,
        args.load_8bit, args.load_4bit, device=args.device)

    # image_paths[k] is the media root for jsons[k].
    image_paths = [
        "playground/data/",
        "playground/data/",
        "playground/data/KoNViD_1k_videos/",
        "playground/data/maxwell/",
    ]

    json_prefix = "playground/data/test_jsons/"
    jsons = [
        json_prefix + "test_lsvq.json",
        json_prefix + "test_lsvq_1080p.json",
        json_prefix + "konvid.json",
        json_prefix + "maxwell_test.json",
    ]

    result_dir = f"results/{args.model_path}/"
    os.makedirs(result_dir, exist_ok=True)

    conv_mode = "mplug_owl2"
    conv = conv_templates[conv_mode].copy()
    conv.append_message(
        conv.roles[0],
        "How would you rate the quality of this image?" + "\n" + DEFAULT_IMAGE_TOKEN)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt() + " The quality of the image is"

    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
    print(toks)
    # Index 1 skips whatever the tokenizer emits first (presumably a BOS
    # token -- TODO confirm against the tokenizer in use).
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)

    # Loop invariants: built once instead of once per video.
    background = tuple(int(x * 255) for x in image_processor.image_mean)

    def expand2square(pil_img, background_color):
        """Pad *pil_img* to a centred square filled with *background_color*."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    for image_path, json_ in zip(image_paths, jsons):
        with open(json_) as f:
            iqadata = json.load(f)

        result_name = json_.replace("combined/", "combined-").split("/")[-1]
        result_path = result_dir + result_name

        progress = tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))
        for i, llddata in enumerate(progress):
            # Best effort per entry: a video that fails to load or score is
            # reported and skipped, and the rest of the dataset still runs.
            # Exception (not a bare except) lets Ctrl-C stop the run.
            try:
                filename = llddata["img_path"]
                llddata["logits"] = defaultdict(float)

                frames = [expand2square(img, background)
                          for img in load_video(image_path + filename)]
                image_tensor = image_processor.preprocess(frames, return_tensors='pt')['pixel_values'].half().to(args.device)

                with torch.inference_mode():
                    output_logits = model(input_ids.repeat(image_tensor.shape[0], 1),
                                          images=image_tensor)["logits"][:, -1]

                # Average the per-frame logits once, then pick the rating words.
                frame_mean = output_logits.mean(0)
                for tok, id_ in zip(toks, ids_):
                    llddata["logits"][tok] += frame_mean[id_].item()

                with open(result_path, "a") as wf:
                    json.dump(llddata, wf)
                    # One record per line so the appended file can be split
                    # back into records.
                    wf.write("\n")
            except Exception as err:
                print(f"Skipping entry {i} of {json_}: {err!r}")
                continue
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == "__main__":
    # Command-line options in declaration order. main() in this file reads
    # model-path, model-base, device, load-8bit and load-4bit; the remaining
    # options are accepted but not read by it.
    cli_options = (
        ("--model-path", dict(type=str, default="q-future/q-align-image")),
        ("--model-base", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda:0")),
        ("--conv-mode", dict(type=str, default=None)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--image-aspect-ratio", dict(type=str, default='pad')),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
|
q_align/evaluate/.ipynb_checkpoints/iqa_eval-checkpoint.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from q_align.conversation import conv_templates, SeparatorStyle
|
6 |
+
from q_align.model.builder import load_pretrained_model
|
7 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from PIL import Image
|
13 |
+
from io import BytesIO
|
14 |
+
from transformers import TextStreamer
|
15 |
+
|
16 |
+
import json
|
17 |
+
from tqdm import tqdm
|
18 |
+
from collections import defaultdict
|
19 |
+
|
20 |
+
import os
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def disable_torch_init():
    """
    Turn the default parameter initialisation of Linear and LayerNorm into a
    no-op so that model creation does not spend time on it.
    """
    import torch

    # Replace reset_parameters on both layer classes with a do-nothing stub.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", lambda self: None)
|
32 |
+
|
33 |
+
|
34 |
+
def load_image(image_file, timeout=30):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with "http://" or
            "https://".
        timeout: Seconds to wait for the remote server; only used for URLs.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.RequestException: If the download times out, the connection
            fails, or the server answers with an error status.
    """
    if image_file.startswith(('http://', 'https://')):
        # Without a timeout a stalled server would block the run forever.
        response = requests.get(image_file, timeout=timeout)
        # Surface HTTP errors here rather than handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    return Image.open(image_file).convert('RGB')
|
41 |
+
|
42 |
+
|
43 |
+
def main(args):
    """Collect rating-word logits for every image of the IQA test sets.

    Images are processed in batches of eight; for each image the next-token
    logits of a fixed list of rating words are stored under ``"logits"`` and
    the record is appended as one JSON line to
    ``results/<model_path>/2<json file name>``.

    Args:
        args: Parsed command-line namespace; this function reads
            ``model_path``, ``model_base``, ``load_8bit``, ``load_4bit`` and
            ``device``.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        args.model_path, args.model_base, model_name,
        args.load_8bit, args.load_4bit, device=args.device)

    image_path = "playground/data/"

    json_prefix = "playground/data/test_jsons/"
    jsons = [
        json_prefix + "test_imagerewarddb.json",
        json_prefix + "test_koniq.json",
        json_prefix + "test_spaq.json",
        json_prefix + "test_kadid.json",
        json_prefix + "livec.json",
        json_prefix + "agi.json",
        json_prefix + "live.json",
        json_prefix + "csiq.json",
    ]

    result_dir = f"results/{args.model_path}/"
    os.makedirs(result_dir, exist_ok=True)

    conv_mode = "mplug_owl2"
    # Alternative wording kept from the source:
    # "How would you rate the quality of this image?"
    inp = "Evaluate the image quality of the following image."
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], inp + "\n" + DEFAULT_IMAGE_TOKEN)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt() + " The quality of the image is"

    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
    print(toks)
    # Index 1 skips whatever the tokenizer emits first (presumably a BOS
    # token -- TODO confirm against the tokenizer in use).
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)

    # Loop invariants: built once instead of once per image.
    background = tuple(int(x * 255) for x in image_processor.image_mean)
    batch_size = 8

    def expand2square(pil_img, background_color):
        """Pad *pil_img* to a centred square filled with *background_color*."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    for json_ in jsons:
        with open(json_) as f:
            iqadata = json.load(f)

        result_name = json_.replace("combined/", "combined-").split("/")[-1]
        result_path = f"{result_dir}2{result_name}"

        image_tensors = []
        batch_data = []

        progress = tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))
        for i, llddata in enumerate(progress):
            # Records name their file under either "image" or "img_path".
            try:
                filename = llddata["image"]
            except KeyError:
                filename = llddata["img_path"]
            llddata["logits"] = defaultdict(float)

            image = expand2square(load_image(image_path + filename), background)
            image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)

            image_tensors.append(image_tensor)
            batch_data.append(llddata)

            # Run the model once a batch is full, or on the final image.
            if i % batch_size != batch_size - 1 and i != len(iqadata) - 1:
                continue

            with torch.inference_mode():
                output_logits = model(input_ids.repeat(len(image_tensors), 1),
                                      images=torch.cat(image_tensors, 0))["logits"][:, -1]

            with open(result_path, "a") as wf:
                for j, xllddata in enumerate(batch_data):
                    for tok, id_ in zip(toks, ids_):
                        xllddata["logits"][tok] += output_logits[j, id_].item()
                    json.dump(xllddata, wf)
                    # One record per line so the appended file can be split
                    # back into records.
                    wf.write("\n")

            image_tensors = []
            batch_data = []
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == "__main__":
    # Command-line options in declaration order. main() in this file reads
    # model-path, model-base, device, load-8bit and load-4bit; the remaining
    # options are accepted but not read by it.
    cli_options = (
        ("--model-path", dict(type=str, default="q-future/one-align")),
        ("--model-base", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda:0")),
        ("--conv-mode", dict(type=str, default=None)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--image-aspect-ratio", dict(type=str, default='pad')),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
|
q_align/evaluate/.ipynb_checkpoints/scorer-checkpoint.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
from q_align.model.builder import load_pretrained_model
|
9 |
+
|
10 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
11 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
12 |
+
|
13 |
+
def load_video(video_file):
    """Return frames sampled at about one per second as a list of PIL images."""
    from decord import VideoReader
    reader = VideoReader(video_file)

    # Average frame rate reported by the reader.
    frame_rate = reader.get_avg_fps()

    # One index per whole second of footage: 0, fps, 2*fps, ...
    whole_seconds = int(len(reader) / frame_rate)
    picks = [int(frame_rate * second) for second in range(whole_seconds)]
    sampled = reader.get_batch(picks).asnumpy()
    return [Image.fromarray(sampled[k]) for k in range(whole_seconds)]
|
24 |
+
|
25 |
+
|
26 |
+
class QAlignScorer(nn.Module):
    """Image-quality scorer built on a Q-Align model.

    ``forward`` returns, for each input image, a softmax over the next-token
    logits of the five rating words "excellent", "good", "fair", "poor" and
    "bad" (in that order).
    """

    def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
        """Load the model (unless one is passed in) and tokenize the prompt.

        Args:
            pretrained: Model path or id handed to ``load_pretrained_model``.
            device: Device handed to ``load_pretrained_model``.
            tokenizer, model, image_processor: Already-loaded components; they
                are used as given whenever ``model`` is not None.
        """
        super().__init__()
        if model is None:
            tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
        prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"

        # Index 1 skips whatever the tokenizer emits first (presumably a BOS
        # token -- TODO confirm against the tokenizer in use).
        self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
        # Level weights for turning the probabilities into one score; forward
        # does not apply them itself.
        self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    def expand2square(self, pil_img, background_color):
        """Pad *pil_img* to a centred square filled with *background_color*."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    def forward(self, image: List[Image.Image]):
        """Return one row of five rating probabilities per input image."""
        # The fill colour is a three-channel value taken from the processor's
        # mean, so it is computed once and every image is brought to RGB
        # first; padding a grayscale or palette image with it would not match
        # the image mode.
        background = tuple(int(x*255) for x in self.image_processor.image_mean)
        image = [
            self.expand2square(img if img.mode == "RGB" else img.convert("RGB"), background)
            for img in image
        ]
        with torch.inference_mode():
            image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
            output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
                        images=image_tensor)["logits"][:,-1, self.preferential_ids_]

            return torch.softmax(output_logits, -1) #@ self.weight_tensor
|
62 |
+
|
63 |
+
|
64 |
+
class QAlignAestheticScorer(nn.Module):
    """Image-aesthetics scorer built on a Q-Align model.

    ``forward`` returns, for each input image, a softmax over the next-token
    logits of the five rating words "excellent", "good", "fair", "poor" and
    "bad" (in that order).
    """

    def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
        """Load the model (unless one is passed in) and tokenize the prompt.

        Args:
            pretrained: Model path or id handed to ``load_pretrained_model``.
            device: Device handed to ``load_pretrained_model``.
            tokenizer, model, image_processor: Already-loaded components; they
                are used as given whenever ``model`` is not None.
        """
        super().__init__()
        if model is None:
            tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
        prompt = "USER: How would you rate the aesthetics of this image?\n<|image|>\nASSISTANT: The aesthetics of the image is"

        # Index 1 skips whatever the tokenizer emits first (presumably a BOS
        # token -- TODO confirm against the tokenizer in use).
        self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
        # Level weights for turning the probabilities into one score; forward
        # does not apply them itself.
        self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    def expand2square(self, pil_img, background_color):
        """Pad *pil_img* to a centred square filled with *background_color*."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    def forward(self, image: List[Image.Image]):
        """Return one row of five rating probabilities per input image."""
        # The fill colour is a three-channel value taken from the processor's
        # mean, so it is computed once and every image is brought to RGB
        # first; padding a grayscale or palette image with it would not match
        # the image mode.
        background = tuple(int(x*255) for x in self.image_processor.image_mean)
        image = [
            self.expand2square(img if img.mode == "RGB" else img.convert("RGB"), background)
            for img in image
        ]
        with torch.inference_mode():
            image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
            output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
                        images=image_tensor)["logits"][:,-1, self.preferential_ids_]

            return torch.softmax(output_logits, -1) #@ self.weight_tensor
|
100 |
+
|
101 |
+
class QAlignVideoScorer(nn.Module):
    """Video-quality scorer built on a Q-Align model.

    ``forward`` takes a list of videos, each a list of PIL frames, and returns
    for each video a softmax over the next-token logits of the five rating
    words "excellent", "good", "fair", "poor" and "bad" (in that order).
    """

    def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
        """Load the model (unless one is passed in) and tokenize the prompt.

        Args:
            pretrained: Model path or id handed to ``load_pretrained_model``.
            device: Device handed to ``load_pretrained_model``.
            tokenizer, model, image_processor: Already-loaded components; they
                are used as given whenever ``model`` is not None.
        """
        super().__init__()
        if model is None:
            tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
        prompt = "USER: How would you rate the quality of this video?\n<|image|>\nASSISTANT: The quality of the video is"

        # Index 1 skips whatever the tokenizer emits first (presumably a BOS
        # token -- TODO confirm against the tokenizer in use).
        self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
        # Level weights for turning the probabilities into one score; forward
        # does not apply them itself.
        self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    def expand2square(self, pil_img, background_color):
        """Pad *pil_img* to a centred square filled with *background_color*."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    def forward(self, video: List[List[Image.Image]]):
        """Return one row of five rating probabilities per input video."""
        # The fill colour is a three-channel value taken from the processor's
        # mean, so it is computed once (not once per frame) and every frame is
        # brought to RGB before padding.
        background = tuple(int(x*255) for x in self.image_processor.image_mean)
        video = [
            [
                self.expand2square(frame if frame.mode == "RGB" else frame.convert("RGB"), background)
                for frame in vid
            ]
            for vid in video
        ]
        with torch.inference_mode():
            # One tensor per video; the list is handed to the model as-is.
            video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video]
            output_logits = self.model(self.input_ids.repeat(len(video_tensors), 1),
                        images=video_tensors)["logits"][:,-1, self.preferential_ids_]
            return torch.softmax(output_logits, -1) #@ self.weight_tensor
|
136 |
+
|
137 |
+
|
138 |
+
if __name__ == "__main__":
    import argparse

    # Demo entry point: score one image (quality or aesthetics) or one video
    # and print the five rating probabilities.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="q-future/one-align")
    parser.add_argument("--device", type=str, default="cuda:0")
    # Path of the input; also used as the video path when --video is given.
    parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
    parser.add_argument("--aesthetic", action="store_true")
    parser.add_argument("--video", action="store_true")
    args = parser.parse_args()

    if args.video:
        scorer = QAlignVideoScorer(pretrained=args.model_path, device=args.device)
        print(scorer([load_video(args.img_path)]).tolist())
    else:
        if args.aesthetic:
            scorer = QAlignAestheticScorer(pretrained=args.model_path, device=args.device)
        else:
            scorer = QAlignScorer(pretrained=args.model_path, device=args.device)
        # Decode inside a with-block so the file handle is closed, and convert
        # to RGB so grayscale/palette/alpha inputs match the three-channel
        # padding colour the scorer uses.
        with Image.open(args.img_path) as opened:
            picture = opened.convert("RGB")
        print(scorer([picture]).tolist())
|
155 |
+
|
q_align/evaluate/.ipynb_checkpoints/vqa_eval-checkpoint.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from q_align.conversation import conv_templates, SeparatorStyle
|
6 |
+
from q_align.model.builder import load_pretrained_model
|
7 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from PIL import Image
|
13 |
+
from io import BytesIO
|
14 |
+
from transformers import TextStreamer
|
15 |
+
|
16 |
+
|
17 |
+
from scipy.stats import spearmanr, pearsonr
|
18 |
+
|
19 |
+
import json
|
20 |
+
from tqdm import tqdm
|
21 |
+
from collections import defaultdict
|
22 |
+
|
23 |
+
import os
|
24 |
+
|
25 |
+
def wa5(logits):
    """Collapse five rating-word logits into one weighted-average score.

    Args:
        logits: Mapping that holds a logit for each of the keys
            "excellent", "good", "fair", "poor" and "bad".

    Returns:
        The softmax-weighted average of the level weights
        1, 0.75, 0.5, 0.25 and 0, i.e. a value between 0 and 1.
    """
    import numpy as np

    levels = ("excellent", "good", "fair", "poor", "bad")
    logprobs = np.array([logits[level] for level in levels], dtype=np.float64)
    # Softmax is invariant to a constant shift, so subtract the maximum before
    # exponentiating: the largest exponent becomes 0, which keeps np.exp from
    # overflowing to inf (or every term underflowing to 0) and yielding NaN.
    shifted = np.exp(logprobs - logprobs.max())
    probs = shifted / shifted.sum()
    return np.inner(probs, np.array([1, 0.75, 0.5, 0.25, 0.]))
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def disable_torch_init():
    """
    Turn the default parameter initialisation of Linear and LayerNorm into a
    no-op so that model creation does not spend time on it.
    """
    import torch

    # Replace reset_parameters on both layer classes with a do-nothing stub.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", lambda self: None)
|
40 |
+
|
41 |
+
|
42 |
+
def load_video(video_file):
    """Return frames sampled at about one per second as a list of PIL images."""
    from decord import VideoReader
    reader = VideoReader(video_file)

    # Average frame rate reported by the reader.
    frame_rate = reader.get_avg_fps()

    # One index per whole second of footage: 0, fps, 2*fps, ...
    whole_seconds = int(len(reader) / frame_rate)
    picks = [int(frame_rate * second) for second in range(whole_seconds)]
    sampled = reader.get_batch(picks).asnumpy()
    return [Image.fromarray(sampled[k]) for k in range(whole_seconds)]
|
53 |
+
|
54 |
+
|
55 |
+
def main(args):
    """Score the video-quality test sets and print SRCC / PLCC per dataset.

    Each video is sampled with ``load_video`` and passed to the model as one
    clip; the next-token logits of a fixed list of rating words are turned
    into a score with ``wa5``. Every scored record is appended as one JSON
    line to ``results/<model_path>/2<json file name>``.

    Args:
        args: Parsed command-line namespace; this function reads
            ``model_path``, ``model_base``, ``load_8bit``, ``load_4bit`` and
            ``device``.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        args.model_path, args.model_base, model_name,
        args.load_8bit, args.load_4bit, device=args.device)

    # image_paths[k] is the media root for jsons[k]; the commented-out
    # entries must stay in step between the two lists.
    image_paths = [
        #"playground/data/",
        #"playground/data/",
        "playground/data/KoNViD_1k_videos/",
        "playground/data/maxwell/",
    ]

    json_prefix = "playground/data/test_jsons/"
    jsons = [
        #json_prefix + "test_lsvq.json",
        #json_prefix + "test_lsvq_1080p.json",
        json_prefix + "konvid.json",
        json_prefix + "maxwell_test.json",
    ]

    result_dir = f"results/{args.model_path}/"
    os.makedirs(result_dir, exist_ok=True)

    conv_mode = "mplug_owl2"
    conv = conv_templates[conv_mode].copy()
    conv.append_message(
        conv.roles[0],
        "How would you rate the quality of this video?" + "\n" + DEFAULT_IMAGE_TOKEN)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt() + " The quality of the video is"

    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
    print(toks)
    # Index 1 skips whatever the tokenizer emits first (presumably a BOS
    # token -- TODO confirm against the tokenizer in use).
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)

    # Loop invariants: built once instead of once per video.
    background = tuple(int(x * 255) for x in image_processor.image_mean)

    def expand2square(pil_img, background_color):
        """Pad *pil_img* to a centred square filled with *background_color*."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        result = Image.new(pil_img.mode, (side, side), background_color)
        result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return result

    for image_path, json_ in zip(image_paths, jsons):
        with open(json_) as f:
            iqadata = json.load(f)

        result_name = json_.replace("combined/", "combined-").split("/")[-1]
        result_path = f"{result_dir}2{result_name}"

        prs, gts = [], []
        progress = tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))
        for i, llddata in enumerate(progress):
            # Best effort per entry: a failing video is reported and skipped.
            # Exception (not a bare except) lets Ctrl-C stop the run.
            try:
                # Records name their file under either "img_path" or "image".
                try:
                    filename = llddata["img_path"]
                except KeyError:
                    filename = llddata["image"]
                # Read the label up front so a record without one is skipped
                # before anything is appended; prs and gts stay the same length.
                gt_score = llddata["gt_score"]
                llddata["logits"] = defaultdict(float)

                frames = [expand2square(img, background)
                          for img in load_video(image_path + filename)]
                image_tensor = image_processor.preprocess(frames, return_tensors='pt')['pixel_values'].half().to(args.device)

                with torch.inference_mode():
                    # The whole clip goes in as one list element.
                    output_logits = model(input_ids,
                                          images=[image_tensor])["logits"][:, -1]
                    clip_logits = output_logits.mean(0)
                    for tok, id_ in zip(toks, ids_):
                        llddata["logits"][tok] += clip_logits[id_].item()

                llddata["score"] = wa5(llddata["logits"])
                prs.append(llddata["score"])
                gts.append(gt_score)

                with open(result_path, "a") as wf:
                    json.dump(llddata, wf)
                    # One record per line so the appended file can be split
                    # back into records.
                    wf.write("\n")

                # Running correlation every 200 entries.
                if i > 0 and i % 200 == 0:
                    print(spearmanr(prs, gts)[0], pearsonr(prs, gts)[0])
            except Exception as err:
                print(f"Skipping entry {i} of {json_}: {err!r}")
                continue
        print("Spearmanr", spearmanr(prs, gts)[0], "Pearson", pearsonr(prs, gts)[0])
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
    # Script entry point: the command-line options are declared in one table
    # (in help-output order), parsed, and the namespace is handed to main().
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--model-path", dict(type=str, default="q-future/one-align")),
        ("--model-base", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda:0")),
        ("--conv-mode", dict(type=str, default=None)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--image-aspect-ratio", dict(type=str, default='pad')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
|
q_align/evaluate/__pycache__/scorer.cpython-311.pyc
ADDED
Binary file (14.6 kB). View file
|
|
q_align/evaluate/eval.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from q_align.conversation import conv_templates, SeparatorStyle
|
6 |
+
from q_align.model.builder import load_pretrained_model
|
7 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from PIL import Image
|
13 |
+
from io import BytesIO
|
14 |
+
from transformers import TextStreamer
|
15 |
+
|
16 |
+
import json
|
17 |
+
from tqdm import tqdm
|
18 |
+
from collections import defaultdict
|
19 |
+
|
20 |
+
import os
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
def disable_torch_init():
    """
    Turn the default parameter initialization of ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` into a no-op to accelerate model creation.
    """
    import torch

    # Patch reset_parameters on both layer classes with a do-nothing method.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
|
31 |
+
|
32 |
+
|
33 |
+
def load_image(image_file):
    """
    Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: File-system path, or a URL starting with ``http://`` or
            ``https://``.

    Returns:
        The decoded image converted to RGB.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
    """
    if image_file.startswith(('http://', 'https://')):
        # Bound the wait and fail on HTTP errors instead of hanging forever or
        # handing an error page to PIL.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    return Image.open(source).convert('RGB')
|
40 |
+
|
41 |
+
|
42 |
+
def main(args):
    """
    Run the image-quality prompt over the listed test sets and dump the logits.

    Every entry of each annotation file is preprocessed and scored in batches
    of up to 8 images. The last-position logits of the rating words are stored
    under ``"logits"`` in the entry, which is then appended to
    ``results/<model_path>/<annotation file name>``.

    Args:
        args: Parsed command-line namespace; ``model_path``, ``model_base``,
            ``load_8bit``, ``load_4bit`` and ``device`` are read here.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    image_path = "playground/data/"

    json_prefix = "playground/data/labels/mos_simple/"
    jsons = [
        json_prefix + "test_flive.json",
        json_prefix + "combined/kadid_ref.json",
        json_prefix + "combined/livec.json",
        json_prefix + "test_koniq.json",
        json_prefix + "test_spaq.json",
        json_prefix + "combined/agi.json",
        json_prefix + "combined/kadid.json",
    ]

    os.makedirs(f"results/{args.model_path}/", exist_ok=True)

    conv_mode = "mplug_owl2"

    inp = "How would you rate the quality of this image?"

    conv = conv_templates[conv_mode].copy()
    inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt() + " The quality of the image is"

    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
    print(toks)
    # Takes the second id of every encoded word; presumably the first one is a
    # BOS token -- confirm against the tokenizer in use.
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    # Follow --device (as the sibling evaluation scripts do) instead of a
    # hard-coded .cuda().
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)

    for json_ in jsons:
        with open(json_) as f:
            iqadata = json.load(f)

        # "…/combined/x.json" is written to "combined-x.json".
        result_name = json_.replace("combined/", "combined-").split("/")[-1]
        result_file = f"results/{args.model_path}/{result_name}"

        image_tensors = []
        batch_data = []

        for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
            filename = llddata["image"]
            llddata["logits"] = defaultdict(float)

            image = load_image(image_path + filename)
            image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)

            image_tensors.append(image_tensor)
            batch_data.append(llddata)

            # Flush after every 8th entry and after the final one.
            if i % 8 == 7 or i == len(iqadata) - 1:
                with torch.inference_mode():
                    output_logits = model(input_ids.repeat(len(image_tensors), 1),
                        images=torch.cat(image_tensors, 0))["logits"][:,-1]

                # NOTE: records are appended back to back with no separator,
                # exactly as before, so existing result readers keep working.
                with open(result_file, "a") as wf:
                    for j, xllddata in enumerate(batch_data):
                        for tok, id_ in zip(toks, ids_):
                            xllddata["logits"][tok] += output_logits[j,id_].item()
                        json.dump(xllddata, wf)

                image_tensors = []
                batch_data = []
|
123 |
+
|
124 |
+
|
125 |
+
if __name__ == "__main__":
    # Script entry point: the command-line options are declared in one table
    # (in help-output order), parsed, and the namespace is handed to main().
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--model-path", dict(type=str, default="q-future/q-align-koniq-spaq-v0")),
        ("--model-base", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda")),
        ("--conv-mode", dict(type=str, default=None)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--image-aspect-ratio", dict(type=str, default='pad')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
|
q_align/evaluate/iaa_eval.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from q_align.conversation import conv_templates, SeparatorStyle
|
6 |
+
from q_align.model.builder import load_pretrained_model
|
7 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
from PIL import ImageFile
|
11 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
12 |
+
|
13 |
+
import requests
|
14 |
+
from PIL import Image
|
15 |
+
from io import BytesIO
|
16 |
+
from transformers import TextStreamer
|
17 |
+
|
18 |
+
from scipy.stats import spearmanr, pearsonr
|
19 |
+
|
20 |
+
|
21 |
+
import json
|
22 |
+
from tqdm import tqdm
|
23 |
+
from collections import defaultdict
|
24 |
+
|
25 |
+
import os
|
26 |
+
|
27 |
+
def wa5(logits):
    """
    Collapse the five rating-word logits into a single score in [0, 1].

    A softmax is taken over the logits of "excellent", "good", "fair", "poor"
    and "bad", and the resulting probabilities are combined with the weights
    1, 0.75, 0.5, 0.25 and 0.

    Args:
        logits: Mapping holding a numeric logit for each of the five words.

    Returns:
        The weighted average as a NumPy float.
    """
    import numpy as np
    logprobs = np.array(
        [logits["excellent"], logits["good"], logits["fair"], logits["poor"], logits["bad"]],
        dtype=float,
    )
    # Subtracting the maximum leaves the softmax unchanged but keeps exp()
    # from overflowing to inf, which would turn the score into NaN.
    exps = np.exp(logprobs - np.max(logprobs))
    probs = exps / np.sum(exps)
    return np.inner(probs, np.array([1,0.75,0.5,0.25,0.]))
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
def disable_torch_init():
    """
    Turn the default parameter initialization of ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` into a no-op to accelerate model creation.
    """
    import torch

    # Patch reset_parameters on both layer classes with a do-nothing method.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
|
43 |
+
|
44 |
+
|
45 |
+
def load_image(image_file):
    """
    Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: File-system path, or a URL starting with ``http://`` or
            ``https://``.

    Returns:
        The decoded image converted to RGB.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
    """
    if image_file.startswith(('http://', 'https://')):
        # Bound the wait and fail on HTTP errors instead of hanging forever or
        # handing an error page to PIL.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    return Image.open(source).convert('RGB')
|
52 |
+
|
53 |
+
|
54 |
+
def main(args):
    """
    Score the aesthetics test set and report rank/linear correlation.

    Every entry of each annotation file is padded to a square, preprocessed and
    scored in batches of up to 8 images. The last-position logits of the rating
    words are stored under ``"logits"``, their ``wa5`` value under ``"score"``,
    and the entry is appended to ``results/<model_path>/<annotation file name>``.
    After each file the Spearman and Pearson correlations between the scores
    and the ``"gt_score"`` values are printed.

    Args:
        args: Parsed command-line namespace; ``model_path``, ``model_base``,
            ``load_8bit``, ``load_4bit`` and ``device`` are read here.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    image_path = "playground/data/"

    json_prefix = "playground/data/test_jsons/"
    jsons = [
        json_prefix + "test_ava.json",
    ]

    os.makedirs(f"results/{args.model_path}/", exist_ok=True)

    conv_mode = "mplug_owl2"

    inp = "How would you rate the aesthetics of this image?"

    conv = conv_templates[conv_mode].copy()
    inp = DEFAULT_IMAGE_TOKEN + inp
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt() + " The aesthetics of the image is"

    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
    print(toks)
    # Takes the second id of every encoded word; presumably the first one is a
    # BOS token -- confirm against the tokenizer in use.
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)

    def expand2square(pil_img, background_color):
        # Centre the image on a square canvas of its longer side.
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    # Padding colour derived from the processor's per-channel mean; computed
    # once because it does not depend on the image.
    background = tuple(int(x*255) for x in image_processor.image_mean)

    for json_ in jsons:
        with open(json_) as f:
            iqadata = json.load(f)

        # "…/combined/x.json" is written to "combined-x.json".
        result_name = json_.replace("combined/", "combined-").split("/")[-1]
        result_file = f"results/{args.model_path}/{result_name}"

        image_tensors = []
        batch_data = []
        prs, gts = [], []
        for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
            filename = llddata["image"]
            llddata["logits"] = defaultdict(float)

            image = load_image(image_path + filename)
            image = expand2square(image, background)
            image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)

            image_tensors.append(image_tensor)
            batch_data.append(llddata)

            # Flush after every 8th entry and after the final one.
            if i % 8 == 7 or i == len(iqadata) - 1:
                with torch.inference_mode():
                    output_logits = model(input_ids.repeat(len(image_tensors), 1),
                        images=torch.cat(image_tensors, 0))["logits"][:,-1]

                # NOTE: records are appended back to back with no separator,
                # exactly as before, so existing result readers keep working.
                with open(result_file, "a") as wf:
                    for j, xllddata in enumerate(batch_data):
                        for tok, id_ in zip(toks, ids_):
                            xllddata["logits"][tok] += output_logits[j,id_].item()
                        xllddata["score"] = wa5(xllddata["logits"])
                        prs.append(xllddata["score"])
                        gts.append(xllddata["gt_score"])
                        json.dump(xllddata, wf)

                image_tensors = []
                batch_data = []

        print("Spearmanr", spearmanr(prs,gts)[0], "Pearson", pearsonr(prs,gts)[0])
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == "__main__":
    # Script entry point: the command-line options are declared in one table
    # (in help-output order), parsed, and the namespace is handed to main().
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--model-path", dict(type=str, default="q-future/one-align")),
        ("--model-base", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda:0")),
        ("--conv-mode", dict(type=str, default=None)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--image-aspect-ratio", dict(type=str, default='pad')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
|
q_align/evaluate/iqa4vqa_eval.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from q_align.conversation import conv_templates, SeparatorStyle
|
6 |
+
from q_align.model.builder import load_pretrained_model
|
7 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from PIL import Image
|
13 |
+
from io import BytesIO
|
14 |
+
from transformers import TextStreamer
|
15 |
+
|
16 |
+
from decord import VideoReader
|
17 |
+
|
18 |
+
|
19 |
+
import json
|
20 |
+
from tqdm import tqdm
|
21 |
+
from collections import defaultdict
|
22 |
+
|
23 |
+
import os
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
def disable_torch_init():
    """
    Turn the default parameter initialization of ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` into a no-op to accelerate model creation.
    """
    import torch

    # Patch reset_parameters on both layer classes with a do-nothing method.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
|
35 |
+
|
36 |
+
|
37 |
+
def load_video(video_file):
    """
    Sample a video at roughly one frame per second.

    Args:
        video_file: Path handed to ``VideoReader``.

    Returns:
        A list of PIL images, one per sampled frame. A clip shorter than one
        second still yields its first frame rather than an empty list.
    """
    vr = VideoReader(video_file)

    # Get video frame rate
    fps = vr.get_avg_fps()

    # Whole seconds covered by the clip; keep at least one sample so that
    # sub-second clips do not produce an empty batch downstream.
    num_samples = max(1, int(len(vr) / fps))

    # Calculate frame indices for 1fps
    frame_indices = [int(fps * i) for i in range(num_samples)]
    frames = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(frame) for frame in frames]
|
47 |
+
|
48 |
+
|
49 |
+
def main(args):
    """
    Score videos with the image-quality prompt and dump the logits.

    Each video is sampled into frames by ``load_video``; the frames are padded
    to squares and run through the model as one batch. The frame-averaged
    last-position logits of the rating words are stored under ``"logits"`` in
    the entry, which is appended to ``results/<model_path>/<annotation file name>``.

    A video that fails is reported and skipped. Previously a bare ``except``
    around the whole loop silently abandoned the rest of the data set (and
    swallowed Ctrl-C) after the first failure.

    Args:
        args: Parsed command-line namespace; ``model_path``, ``model_base``,
            ``load_8bit``, ``load_4bit`` and ``device`` are read here.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    # image_paths[k] is the media root for jsons[k].
    image_paths = [
        "playground/data/",
        "playground/data/",
        "playground/data/KoNViD_1k_videos/",
        "playground/data/maxwell/",
    ]

    json_prefix = "playground/data/test_jsons/"
    jsons = [
        json_prefix + "test_lsvq.json",
        json_prefix + "test_lsvq_1080p.json",
        json_prefix + "konvid.json",
        json_prefix + "maxwell_test.json",
    ]

    os.makedirs(f"results/{args.model_path}/", exist_ok=True)

    conv_mode = "mplug_owl2"

    inp = "How would you rate the quality of this image?"

    conv = conv_templates[conv_mode].copy()
    inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt() + " The quality of the image is"

    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
    print(toks)
    # Takes the second id of every encoded word; presumably the first one is a
    # BOS token -- confirm against the tokenizer in use.
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)

    def expand2square(pil_img, background_color):
        # Centre the image on a square canvas of its longer side.
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    # Padding colour derived from the processor's per-channel mean; computed
    # once because it does not depend on the frame.
    background = tuple(int(x*255) for x in image_processor.image_mean)

    for image_path, json_ in zip(image_paths, jsons):
        with open(json_) as f:
            iqadata = json.load(f)

        # "…/combined/x.json" is written to "combined-x.json".
        result_name = json_.replace("combined/", "combined-").split("/")[-1]
        result_file = f"results/{args.model_path}/{result_name}"

        for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
            try:
                filename = llddata["img_path"]
                llddata["logits"] = defaultdict(float)

                image = load_video(image_path + filename)
                image = [expand2square(img, background) for img in image]
                image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)

                # One prompt copy per sampled frame.
                with torch.inference_mode():
                    output_logits = model(input_ids.repeat(image_tensor.shape[0], 1),
                        images=image_tensor)["logits"][:,-1]

                # Average over the frames before reading the word logits.
                mean_logits = output_logits.mean(0)
                for tok, id_ in zip(toks, ids_):
                    llddata["logits"][tok] += mean_logits[id_].item()

                # NOTE: records are appended back to back with no separator,
                # exactly as before, so existing result readers keep working.
                with open(result_file, "a") as wf:
                    json.dump(llddata, wf)
            except Exception as err:
                # Best effort: report the failing entry and move on.
                print(f"[{result_name}] skipping entry {i}: {err!r}")
                continue
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == "__main__":
    # Script entry point: the command-line options are declared in one table
    # (in help-output order), parsed, and the namespace is handed to main().
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--model-path", dict(type=str, default="q-future/q-align-image")),
        ("--model-base", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda:0")),
        ("--conv-mode", dict(type=str, default=None)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--image-aspect-ratio", dict(type=str, default='pad')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
|
q_align/evaluate/iqa_eval.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from q_align.conversation import conv_templates, SeparatorStyle
|
6 |
+
from q_align.model.builder import load_pretrained_model
|
7 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from PIL import Image
|
13 |
+
from io import BytesIO
|
14 |
+
from transformers import TextStreamer
|
15 |
+
|
16 |
+
import json
|
17 |
+
from tqdm import tqdm
|
18 |
+
from collections import defaultdict
|
19 |
+
|
20 |
+
import os
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def disable_torch_init():
    """
    Turn the default parameter initialization of ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` into a no-op to accelerate model creation.
    """
    import torch

    # Patch reset_parameters on both layer classes with a do-nothing method.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
|
32 |
+
|
33 |
+
|
34 |
+
def load_image(image_file):
    """
    Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: File-system path, or a URL starting with ``http://`` or
            ``https://``.

    Returns:
        The decoded image converted to RGB.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
    """
    if image_file.startswith(('http://', 'https://')):
        # Bound the wait and fail on HTTP errors instead of hanging forever or
        # handing an error page to PIL.
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    return Image.open(source).convert('RGB')
|
41 |
+
|
42 |
+
|
43 |
+
def main(args):
    """
    Run the image-quality prompt over the listed test sets and dump the logits.

    Every entry of each annotation file is padded to a square, preprocessed and
    scored in batches of up to 8 images. The last-position logits of the rating
    words are stored under ``"logits"`` in the entry, which is then appended to
    ``results/<model_path>/2<annotation file name>``.

    Args:
        args: Parsed command-line namespace; ``model_path``, ``model_base``,
            ``load_8bit``, ``load_4bit`` and ``device`` are read here.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    image_path = "playground/data/"

    json_prefix = "playground/data/test_jsons/"
    jsons = [
        json_prefix + "test_imagerewarddb.json",
        json_prefix + "test_koniq.json",
        json_prefix + "test_spaq.json",
        json_prefix + "test_kadid.json",
        json_prefix + "livec.json",
        json_prefix + "agi.json",
        json_prefix + "live.json",
        json_prefix + "csiq.json",
    ]

    os.makedirs(f"results/{args.model_path}/", exist_ok=True)

    conv_mode = "mplug_owl2"

    inp = "Evaluate the image quality of the following image."

    conv = conv_templates[conv_mode].copy()
    inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt() + " The quality of the image is"

    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
    print(toks)
    # Takes the second id of every encoded word; presumably the first one is a
    # BOS token -- confirm against the tokenizer in use.
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)

    def expand2square(pil_img, background_color):
        # Centre the image on a square canvas of its longer side.
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    # Padding colour derived from the processor's per-channel mean; computed
    # once because it does not depend on the image.
    background = tuple(int(x*255) for x in image_processor.image_mean)

    for json_ in jsons:
        with open(json_) as f:
            iqadata = json.load(f)

        # "…/combined/x.json" is written to "2combined-x.json".
        result_name = json_.replace("combined/", "combined-").split("/")[-1]
        result_file = f"results/{args.model_path}/2{result_name}"

        image_tensors = []
        batch_data = []

        for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(json_.split("/")[-1]))):
            # Annotation files name the media field either "image" or "img_path".
            try:
                filename = llddata["image"]
            except KeyError:
                filename = llddata["img_path"]
            llddata["logits"] = defaultdict(float)

            image = load_image(image_path + filename)
            image = expand2square(image, background)
            image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(args.device)

            image_tensors.append(image_tensor)
            batch_data.append(llddata)

            # Flush after every 8th entry and after the final one.
            if i % 8 == 7 or i == len(iqadata) - 1:
                with torch.inference_mode():
                    output_logits = model(input_ids.repeat(len(image_tensors), 1),
                        images=torch.cat(image_tensors, 0))["logits"][:,-1]

                # NOTE: records are appended back to back with no separator,
                # exactly as before, so existing result readers keep working.
                with open(result_file, "a") as wf:
                    for j, xllddata in enumerate(batch_data):
                        for tok, id_ in zip(toks, ids_):
                            xllddata["logits"][tok] += output_logits[j,id_].item()
                        json.dump(xllddata, wf)

                image_tensors = []
                batch_data = []
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == "__main__":
    # Script entry point: the command-line options are declared in one table
    # (in help-output order), parsed, and the namespace is handed to main().
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--model-path", dict(type=str, default="q-future/one-align")),
        ("--model-base", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda:0")),
        ("--conv-mode", dict(type=str, default=None)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--image-aspect-ratio", dict(type=str, default='pad')),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
|
q_align/evaluate/scorer.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
from q_align.model.builder import load_pretrained_model
|
9 |
+
|
10 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
11 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
12 |
+
|
13 |
+
def load_video(video_file):
    """
    Sample a video at roughly one frame per second.

    Args:
        video_file: Path handed to ``decord.VideoReader``.

    Returns:
        A list of PIL images, one per sampled frame. A clip shorter than one
        second still yields its first frame rather than an empty list.
    """
    # Imported here so that decord is only needed when videos are scored.
    from decord import VideoReader
    vr = VideoReader(video_file)

    # Get video frame rate
    fps = vr.get_avg_fps()

    # Whole seconds covered by the clip; keep at least one sample so that
    # sub-second clips do not produce an empty batch downstream.
    num_samples = max(1, int(len(vr) / fps))

    # Calculate frame indices for 1fps
    frame_indices = [int(fps * i) for i in range(num_samples)]
    frames = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(frame) for frame in frames]
|
24 |
+
|
25 |
+
|
26 |
+
class QAlignScorer(nn.Module):
    """
    Image-quality scorer around a model loaded under the "mplug_owl2" name.

    ``forward`` returns, for every input image, a softmax over the logits of
    the five rating words "excellent", "good", "fair", "poor" and "bad" taken
    at the last position of a fixed quality prompt.
    """

    def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
        """
        Args:
            pretrained: Checkpoint handed to the loader when ``model`` is None.
            device: Device handed to the loader in that case.
            tokenizer: Tokenizer to use together with a supplied ``model``.
            model: Ready-made model; when None, tokenizer, model and image
                processor are all loaded from ``pretrained``.
            image_processor: Image processor to use with a supplied ``model``.
        """
        super().__init__()
        if model is None:
            tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor

        # Second id of every encoded rating word; presumably the first one is
        # a BOS token -- confirm against the tokenizer in use.
        rating_words = ["excellent","good","fair","poor","bad"]
        self.preferential_ids_ = [encoded[1] for encoded in tokenizer(rating_words)["input_ids"]]
        # Per-level weights; kept as an attribute, not applied in forward().
        self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)

        prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"
        prompt_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        self.input_ids = prompt_ids.unsqueeze(0).to(model.device)

    def expand2square(self, pil_img, background_color):
        """Centre ``pil_img`` on a square canvas of its longer side filled with ``background_color``."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        # Only the shorter axis gets a non-zero offset.
        offset = (0, (width - height) // 2) if width > height else ((height - width) // 2, 0)
        canvas = Image.new(pil_img.mode, (side, side), background_color)
        canvas.paste(pil_img, offset)
        return canvas

    def forward(self, image: List[Image.Image]):
        """
        Score a list of PIL images.

        Returns:
            A tensor with one row per image holding the softmax over the five
            rating-word logits.
        """
        squared = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
        with torch.inference_mode():
            pixel_values = self.image_processor.preprocess(squared, return_tensors="pt")["pixel_values"]
            image_tensor = pixel_values.half().to(self.model.device)
            # One prompt copy per image in the batch.
            batch_ids = self.input_ids.repeat(image_tensor.shape[0], 1)
            all_logits = self.model(batch_ids, images=image_tensor)["logits"]
            output_logits = all_logits[:,-1, self.preferential_ids_]

            return torch.softmax(output_logits, -1) #@ self.weight_tensor
|
62 |
+
|
63 |
+
|
64 |
+
class QAlignAestheticScorer(nn.Module):
|
65 |
+
def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
|
66 |
+
super().__init__()
|
67 |
+
if model is None:
|
68 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
|
69 |
+
prompt = "USER: How would you rate the aesthetics of this image?\n<|image|>\nASSISTANT: The aesthetics of the image is"
|
70 |
+
|
71 |
+
self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
|
72 |
+
self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
|
73 |
+
|
74 |
+
self.tokenizer = tokenizer
|
75 |
+
self.model = model
|
76 |
+
self.image_processor = image_processor
|
77 |
+
self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
78 |
+
|
79 |
+
def expand2square(self, pil_img, background_color):
|
80 |
+
width, height = pil_img.size
|
81 |
+
if width == height:
|
82 |
+
return pil_img
|
83 |
+
elif width > height:
|
84 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
85 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
86 |
+
return result
|
87 |
+
else:
|
88 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
89 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
90 |
+
return result
|
91 |
+
|
92 |
+
def forward(self, image: List[Image.Image]):
|
93 |
+
image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
|
94 |
+
with torch.inference_mode():
|
95 |
+
image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
|
96 |
+
output_logits = self.model(self.input_ids.repeat(image_tensor.shape[0], 1),
|
97 |
+
images=image_tensor)["logits"][:,-1, self.preferential_ids_]
|
98 |
+
|
99 |
+
return torch.softmax(output_logits, -1) #@ self.weight_tensor
|
100 |
+
|
101 |
+
class QAlignVideoScorer(nn.Module):
|
102 |
+
def __init__(self, pretrained="q-future/one-align", device="cuda:0", tokenizer=None, model=None, image_processor=None):
|
103 |
+
super().__init__()
|
104 |
+
if model is None:
|
105 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
|
106 |
+
prompt = "USER: How would you rate the quality of this video?\n<|image|>\nASSISTANT: The quality of the video is"
|
107 |
+
|
108 |
+
self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
|
109 |
+
self.weight_tensor = torch.Tensor([1,0.75,0.5,0.25,0.]).half().to(model.device)
|
110 |
+
|
111 |
+
self.tokenizer = tokenizer
|
112 |
+
self.model = model
|
113 |
+
self.image_processor = image_processor
|
114 |
+
self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
115 |
+
|
116 |
+
def expand2square(self, pil_img, background_color):
|
117 |
+
width, height = pil_img.size
|
118 |
+
if width == height:
|
119 |
+
return pil_img
|
120 |
+
elif width > height:
|
121 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
122 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
123 |
+
return result
|
124 |
+
else:
|
125 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
126 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
127 |
+
return result
|
128 |
+
|
129 |
+
def forward(self, video: List[List[Image.Image]]):
|
130 |
+
video = [[self.expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in video]
|
131 |
+
with torch.inference_mode():
|
132 |
+
video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video]
|
133 |
+
output_logits = self.model(self.input_ids.repeat(len(video_tensors), 1),
|
134 |
+
images=video_tensors)["logits"][:,-1, self.preferential_ids_]
|
135 |
+
return torch.softmax(output_logits, -1) #@ self.weight_tensor
|
136 |
+
|
137 |
+
|
138 |
+
if __name__ == "__main__":
    # Command-line demo: score a single image (quality or aesthetics) or a single video.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--model-path", type=str, default="q-future/one-align")
    cli.add_argument("--device", type=str, default="cuda:0")
    cli.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
    cli.add_argument("--aesthetic", action="store_true")
    cli.add_argument("--video", action="store_true")
    args = cli.parse_args()

    # --video takes precedence over --aesthetic.
    if args.video:
        scorer_cls, load_input = QAlignVideoScorer, load_video
    elif args.aesthetic:
        scorer_cls, load_input = QAlignAestheticScorer, Image.open
    else:
        scorer_cls, load_input = QAlignScorer, Image.open

    # The model is built first, then the input is loaded and scored as a batch of one.
    scorer = scorer_cls(pretrained=args.model_path, device=args.device)
    print(scorer([load_input(args.img_path)]).tolist())
|
155 |
+
|
q_align/evaluate/vqa_eval.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from q_align.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
5 |
+
from q_align.conversation import conv_templates, SeparatorStyle
|
6 |
+
from q_align.model.builder import load_pretrained_model
|
7 |
+
from q_align.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from PIL import Image
|
13 |
+
from io import BytesIO
|
14 |
+
from transformers import TextStreamer
|
15 |
+
|
16 |
+
|
17 |
+
from scipy.stats import spearmanr, pearsonr
|
18 |
+
|
19 |
+
import json
|
20 |
+
from tqdm import tqdm
|
21 |
+
from collections import defaultdict
|
22 |
+
|
23 |
+
import os
|
24 |
+
|
25 |
+
def wa5(logits):
    """Return the weighted-average score of the five rating levels.

    Args:
        logits: mapping that holds a logit under each of the keys
            "excellent", "good", "fair", "poor" and "bad"; other keys are
            ignored.

    Returns:
        The softmax of the five logits, weighted by 1, 0.75, 0.5, 0.25 and 0
        respectively, i.e. a value in [0, 1].
    """
    import numpy as np

    levels = ("excellent", "good", "fair", "poor", "bad")
    logprobs = np.array([logits[level] for level in levels], dtype=np.float64)
    # Subtract the maximum before exponentiating: the softmax is unchanged,
    # but large logits no longer overflow to inf and turn the score into NaN.
    shifted = np.exp(logprobs - logprobs.max())
    probs = shifted / np.sum(shifted)
    return np.inner(probs, np.array([1, 0.75, 0.5, 0.25, 0.]))
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
def disable_torch_init():
    """
    Replace ``reset_parameters`` of ``torch.nn.Linear`` and ``torch.nn.LayerNorm``
    with a no-op, so the default initialization is skipped when layers are created.

    The patch is applied to the classes themselves and therefore stays in
    effect for the rest of the process.
    """
    import torch

    def _skip_init(self):
        return None

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = _skip_init
|
40 |
+
|
41 |
+
|
42 |
+
def load_video(video_file):
    """Read ``video_file`` with decord and return one PIL frame per whole second.

    The frame picked for second ``k`` is the one at index ``int(fps * k)``,
    with ``fps`` being the average frame rate reported by the reader.
    """
    from decord import VideoReader

    video = VideoReader(video_file)

    # Average frame rate of the clip, used to step through it at 1 fps.
    frame_rate = video.get_avg_fps()
    seconds = int(len(video) / frame_rate)

    picked = [int(frame_rate * sec) for sec in range(seconds)]
    frame_array = video.get_batch(picked).asnumpy()

    pil_frames = []
    for position in range(seconds):
        pil_frames.append(Image.fromarray(frame_array[position]))
    return pil_frames
|
53 |
+
|
54 |
+
|
55 |
+
def _expand2square(pil_img, background_color):
    """Pad ``pil_img`` to a centered square filled with ``background_color``."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def main(args):
    """Evaluate video quality prediction on the configured test sets.

    For every video listed in each test JSON, the model's next-token logits
    for a set of rating words are read, reduced to one score with ``wa5`` and
    compared against the record's ``gt_score``. Spearman and Pearson
    correlations are printed every 200 records and at the end of each set.

    Args:
        args: parsed command-line namespace; ``model_path``, ``model_base``,
            ``load_8bit``, ``load_4bit`` and ``device`` are read.

    Side effects:
        Creates ``results/<model_path>/`` and appends every scored record to
        ``results/<model_path>/2<json file name>``.
        NOTE(review): records are appended with ``json.dump`` and no
        separator, so the output is a run of concatenated JSON objects rather
        than one JSON document; kept as-is to stay compatible with existing
        result files.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    # `image_paths` and `jsons` are paired by position; keep both lists aligned.
    image_paths = [
        #"playground/data/",
        #"playground/data/",
        "playground/data/KoNViD_1k_videos/",
        "playground/data/maxwell/",
    ]

    json_prefix = "playground/data/test_jsons/"
    jsons = [
        #json_prefix + "test_lsvq.json",
        #json_prefix + "test_lsvq_1080p.json",
        json_prefix + "konvid.json",
        json_prefix + "maxwell_test.json",
    ]

    os.makedirs(f"results/{args.model_path}/", exist_ok=True)

    conv_mode = "mplug_owl2"

    inp = "How would you rate the quality of this video?"

    conv = conv_templates[conv_mode].copy()
    inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
    conv.append_message(conv.roles[0], inp)

    conv.append_message(conv.roles[1], None)
    # The prompt stops right before the rating word, so the logits at the
    # last position are those of the candidate rating words.
    prompt = conv.get_prompt() + " The quality of the video is"

    toks = ["good", "poor", "high", "fair", "low", "excellent", "bad", "fine", "moderate", "decent", "average", "medium", "acceptable"]
    print(toks)
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(args.device)

    # Padding colour is the same for every frame, so it is computed once.
    background = tuple(int(x*255) for x in image_processor.image_mean)

    for image_path, json_ in zip(image_paths, jsons):
        with open(json_) as f:
            iqadata = json.load(f)

        dataset_name = json_.split("/")[-1]
        result_name = json_.replace("combined/", "combined-").split("/")[-1]
        result_path = f"results/{args.model_path}/2{result_name}"

        prs, gts = [], []
        for i, llddata in enumerate(tqdm(iqadata, desc="Evaluating [{}]".format(dataset_name))):
            try:
                try:
                    filename = llddata["img_path"]
                except KeyError:
                    filename = llddata["image"]
                llddata["logits"] = defaultdict(float)

                frames = [_expand2square(frame, background) for frame in load_video(image_path + filename)]
                image_tensor = image_processor.preprocess(frames, return_tensors='pt')['pixel_values'].half().to(args.device)

                with torch.inference_mode():
                    output_logits = model(input_ids,
                                          images=[image_tensor])["logits"][:, -1]
                    mean_logits = output_logits.mean(0)
                    for tok, id_ in zip(toks, ids_):
                        llddata["logits"][tok] += mean_logits[id_].item()

                llddata["score"] = wa5(llddata["logits"])
                # Read the ground truth before touching the lists, so a record
                # without it cannot leave `prs` and `gts` with different lengths.
                gt_score = llddata["gt_score"]
                prs.append(llddata["score"])
                gts.append(gt_score)
                with open(result_path, "a") as wf:
                    json.dump(llddata, wf)

                if i > 0 and i % 200 == 0:
                    print(spearmanr(prs, gts)[0], pearsonr(prs, gts)[0])
            except Exception as exc:
                # Best effort: a broken record is skipped, but reported, and
                # Ctrl-C / SystemExit are no longer swallowed.
                tqdm.write("Skipping record {} of {}: {}: {}".format(i, dataset_name, type(exc).__name__, exc))
                continue
        print("Spearmanr", spearmanr(prs, gts)[0], "Pearson", pearsonr(prs, gts)[0])
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
    # (flag, add_argument keyword arguments), listed in registration order.
    cli_options = [
        ("--model-path", dict(type=str, default="q-future/one-align")),
        ("--model-base", dict(type=str, default=None)),
        ("--device", dict(type=str, default="cuda:0")),
        ("--conv-mode", dict(type=str, default=None)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
        ("--image-aspect-ratio", dict(type=str, default='pad')),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    main(args)
|
q_align/mm_utils.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from io import BytesIO
|
3 |
+
import base64
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import StoppingCriteria
|
7 |
+
from q_align.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
|
8 |
+
from icecream import ic
|
9 |
+
|
10 |
+
|
11 |
+
def load_image_from_base64(image):
    """Decode a base64-encoded image payload and open it as a PIL image."""
    raw_bytes = base64.b64decode(image)
    return Image.open(BytesIO(raw_bytes))
|
13 |
+
|
14 |
+
|
15 |
+
def expand2square(pil_img, background_color):
    """Pad ``pil_img`` to a square canvas, keeping the picture centered.

    The canvas side is the longer edge of the image and the added border is
    filled with ``background_color``. An image that is already square is
    returned as-is (the same object, not a copy).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    # Only the shorter dimension gets a non-zero offset.
    top_left = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, top_left)
    return canvas
|
27 |
+
|
28 |
+
|
29 |
+
def process_images(images, image_processor, model_cfg=None):
    """Preprocess a list of PIL images according to the configured aspect-ratio mode.

    Args:
        images: list of PIL images.
        image_processor: processor that provides ``preprocess`` and
            ``image_mean`` and is callable on a list of images.
        model_cfg: optional config object; its ``image_aspect_ratio``
            attribute selects the mode. Without a config the mode is
            ``'resize'``.

    Modes:
        ``'pad'``: pad each image to a square with the processor's mean colour.
        ``'resize'``: stretch each image to a square of its longer edge.
        anything else: hand the whole list to ``image_processor`` and return
        its ``pixel_values`` directly.

    Returns:
        For ``'pad'``/``'resize'``: one stacked tensor when every processed
        image has the same shape, otherwise the list of per-image tensors.
        An empty ``images`` list yields an empty list.
    """
    if model_cfg is not None:
        image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    else:
        image_aspect_ratio = 'resize'

    if image_aspect_ratio not in ('pad', 'resize'):
        return image_processor(images, return_tensors='pt')['pixel_values']

    new_images = []
    for image in images:
        if image_aspect_ratio == 'pad':
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
        else:
            max_edge = max(image.size)
            image = image.resize((max_edge, max_edge))
        image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        new_images.append(image)

    # torch.stack rejects an empty list, so an empty batch is returned unstacked.
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
|
51 |
+
|
52 |
+
|
53 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, replacing every image placeholder by ``image_token_index``.

    The prompt is split on ``DEFAULT_IMAGE_TOKEN``, each non-empty piece is
    tokenized on its own, and the pieces are joined with a single
    ``image_token_index`` between them. When the first piece starts with the
    tokenizer's BOS id, that id is kept once at the front and the first token
    of every piece is dropped (presumably each piece's own BOS).

    Returns:
        A list of ids, or a 1-D ``torch.long`` tensor when
        ``return_tensors == 'pt'``.

    Raises:
        ValueError: for any other non-None ``return_tensors`` value.
    """
    chunks = []
    for piece in prompt.split(DEFAULT_IMAGE_TOKEN):
        chunks.append(tokenizer(piece).input_ids if len(piece) > 0 else [])

    input_ids = []
    offset = 0
    starts_with_bos = (
        len(chunks) > 0
        and len(chunks[0]) > 0
        and chunks[0][0] == tokenizer.bos_token_id
    )
    if starts_with_bos:
        offset = 1
        input_ids.append(chunks[0][0])

    for position, chunk in enumerate(chunks):
        if position > 0:
            # Exactly one image token separates two consecutive pieces.
            input_ids.append(image_token_index)
        input_ids.extend(chunk[offset:])

    if return_tensors is None:
        return input_ids
    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
73 |
+
|
74 |
+
|
75 |
+
def get_model_name_from_path(model_path):
    """Derive a short model name from a model path or hub id.

    Leading and trailing slashes are ignored. The name is the last path
    component, except for ``checkpoint-*`` directories, which are prefixed
    with their parent directory as ``<parent>_<checkpoint-*>``. A bare
    ``checkpoint-*`` path without a parent is returned as-is.
    """
    model_paths = model_path.strip("/").split("/")
    # The length check keeps a lone "checkpoint-N" from indexing a missing parent.
    if model_paths[-1].startswith('checkpoint-') and len(model_paths) > 1:
        return model_paths[-2] + "_" + model_paths[-1]
    return model_paths[-1]
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once any of the given keywords was generated.

    A keyword is detected either as an exact token-id match at the end of the
    sequence, or as a substring of the decoded tail of the generated tokens.
    Only batch size 1 is supported.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for word in keywords:
            word_ids = tokenizer(word).input_ids
            # Drop a leading BOS id so that only the keyword's own tokens are compared.
            has_bos = len(word_ids) > 1 and word_ids[0] == tokenizer.bos_token_id
            if has_bos:
                word_ids = word_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(word_ids))
            self.keyword_ids.append(torch.tensor(word_ids))
        self.tokenizer = tokenizer
        # Prompt length; tokens beyond it are the generated ones.
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        # Number of generated tokens, capped at the longest keyword length.
        # NOTE(review): when this is 0 the slice `-0:` below selects the whole
        # sequence, prompt included -- confirm that this is intended.
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [ids.to(output_ids.device) for ids in self.keyword_ids]

        # First try an exact token-id match against the end of the sequence.
        if any((output_ids[0, -ids.shape[0]:] == ids).all() for ids in self.keyword_ids):
            return True

        # Otherwise look for the keyword text in the decoded tail.
        decoded_tail = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in decoded_tail for keyword in self.keywords)
|
q_align/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
|
2 |
+
from .configuration_mplug_owl2 import MPLUGOwl2Config
|
q_align/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (297 Bytes). View file
|
|
q_align/model/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (318 Bytes). View file
|
|
q_align/model/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (268 Bytes). View file
|
|
q_align/model/__pycache__/builder.cpython-310.pyc
ADDED
Binary file (3.8 kB). View file
|
|
q_align/model/__pycache__/builder.cpython-311.pyc
ADDED
Binary file (7.42 kB). View file
|
|
q_align/model/__pycache__/builder.cpython-39.pyc
ADDED
Binary file (3.81 kB). View file
|
|
q_align/model/__pycache__/configuration_mplug_owl2.cpython-310.pyc
ADDED
Binary file (13.2 kB). View file
|
|
q_align/model/__pycache__/configuration_mplug_owl2.cpython-311.pyc
ADDED
Binary file (16.7 kB). View file
|
|
q_align/model/__pycache__/configuration_mplug_owl2.cpython-39.pyc
ADDED
Binary file (13.2 kB). View file
|
|
q_align/model/__pycache__/modeling_attn_mask_utils.cpython-310.pyc
ADDED
Binary file (7.59 kB). View file
|
|
q_align/model/__pycache__/modeling_attn_mask_utils.cpython-311.pyc
ADDED
Binary file (11.2 kB). View file
|
|
q_align/model/__pycache__/modeling_attn_mask_utils.cpython-39.pyc
ADDED
Binary file (7.49 kB). View file
|
|
q_align/model/__pycache__/modeling_llama2.cpython-310.pyc
ADDED
Binary file (13.3 kB). View file
|
|
q_align/model/__pycache__/modeling_llama2.cpython-311.pyc
ADDED
Binary file (24.3 kB). View file
|
|
q_align/model/__pycache__/modeling_llama2.cpython-39.pyc
ADDED
Binary file (13.1 kB). View file
|
|
q_align/model/__pycache__/modeling_mplug_owl2.cpython-310.pyc
ADDED
Binary file (9.69 kB). View file
|
|
q_align/model/__pycache__/modeling_mplug_owl2.cpython-311.pyc
ADDED
Binary file (21 kB). View file
|
|
q_align/model/__pycache__/modeling_mplug_owl2.cpython-39.pyc
ADDED
Binary file (9.63 kB). View file
|
|