finalf0 commited on
Commit
5491537
1 Parent(s): d43a1f6
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # encoding: utf-8
3
+ import spaces
4
+ import gradio as gr
5
+ from PIL import Image
6
+ import traceback
7
+ import re
8
+ import torch
9
+ import argparse
10
+ from transformers import AutoModel, AutoTokenizer
11
+ from chat import OmniLMM12B
12
+
13
+
14
+
15
+ # Load model
16
+ model_path = 'openbmb/RLAIF-V-12B'
17
+ model = OmniLMM12B(model_path)
18
+
19
+
20
+ ERROR_MSG = "Error, please retry"
21
+ model_name = 'RLAIF-V-12B'
22
+
23
+ form_radio = {
24
+ 'choices': ['Beam Search', 'Sampling'],
25
+ #'value': 'Beam Search',
26
+ 'value': 'Sampling',
27
+ 'interactive': True,
28
+ 'label': 'Decode Type'
29
+ }
30
+ # Beam Form
31
+ num_beams_slider = {
32
+ 'minimum': 0,
33
+ 'maximum': 5,
34
+ 'value': 3,
35
+ 'step': 1,
36
+ 'interactive': True,
37
+ 'label': 'Num Beams'
38
+ }
39
+ repetition_penalty_slider = {
40
+ 'minimum': 0,
41
+ 'maximum': 3,
42
+ 'value': 1.2,
43
+ 'step': 0.01,
44
+ 'interactive': True,
45
+ 'label': 'Repetition Penalty'
46
+ }
47
+ repetition_penalty_slider2 = {
48
+ 'minimum': 0,
49
+ 'maximum': 3,
50
+ 'value': 1.05,
51
+ 'step': 0.01,
52
+ 'interactive': True,
53
+ 'label': 'Repetition Penalty'
54
+ }
55
+ max_new_tokens_slider = {
56
+ 'minimum': 1,
57
+ 'maximum': 4096,
58
+ 'value': 1024,
59
+ 'step': 1,
60
+ 'interactive': True,
61
+ 'label': 'Max New Tokens'
62
+ }
63
+
64
+ top_p_slider = {
65
+ 'minimum': 0,
66
+ 'maximum': 1,
67
+ 'value': 0.8,
68
+ 'step': 0.05,
69
+ 'interactive': True,
70
+ 'label': 'Top P'
71
+ }
72
+ top_k_slider = {
73
+ 'minimum': 0,
74
+ 'maximum': 200,
75
+ 'value': 100,
76
+ 'step': 1,
77
+ 'interactive': True,
78
+ 'label': 'Top K'
79
+ }
80
+ temperature_slider = {
81
+ 'minimum': 0,
82
+ 'maximum': 2,
83
+ 'value': 0.7,
84
+ 'step': 0.05,
85
+ 'interactive': True,
86
+ 'label': 'Temperature'
87
+ }
88
+
89
+
90
+ def create_component(params, comp='Slider'):
91
+ if comp == 'Slider':
92
+ return gr.Slider(
93
+ minimum=params['minimum'],
94
+ maximum=params['maximum'],
95
+ value=params['value'],
96
+ step=params['step'],
97
+ interactive=params['interactive'],
98
+ label=params['label']
99
+ )
100
+ elif comp == 'Radio':
101
+ return gr.Radio(
102
+ choices=params['choices'],
103
+ value=params['value'],
104
+ interactive=params['interactive'],
105
+ label=params['label']
106
+ )
107
+ elif comp == 'Button':
108
+ return gr.Button(
109
+ value=params['value'],
110
+ interactive=True
111
+ )
112
+
113
+ @spaces.GPU(duration=120)
114
+ def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
115
+ if img is None:
116
+ return -1, "Error, invalid image, please upload a new image", None, None
117
+ try:
118
+ image = img.convert('RGB')
119
+ answer = model.chat(
120
+ image=image,
121
+ msgs=msgs,
122
+ )
123
+ return 0, answer, None, None
124
+ except Exception as err:
125
+ print(err)
126
+ traceback.print_exc()
127
+ return -1, ERROR_MSG, None, None
128
+
129
+
130
+ def upload_img(image, _chatbot, _app_session):
131
+ image = Image.fromarray(image)
132
+
133
+ _app_session['sts']=None
134
+ _app_session['ctx']=[]
135
+ _app_session['img']=image
136
+ _chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
137
+ return _chatbot, _app_session
138
+
139
+
140
+ def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
141
+ if _app_cfg.get('ctx', None) is None:
142
+ _chat_bot.append((_question, 'Please upload an image to start'))
143
+ return '', _chat_bot, _app_cfg
144
+
145
+ _context = _app_cfg['ctx'].copy()
146
+ if _context:
147
+ _context.append({"role": "user", "content": _question})
148
+ else:
149
+ _context = [{"role": "user", "content": _question}]
150
+ print('<User>:', _question)
151
+
152
+ if params_form == 'Beam Search':
153
+ params = {
154
+ 'sampling': False,
155
+ 'num_beams': num_beams,
156
+ 'repetition_penalty': repetition_penalty,
157
+ "max_new_tokens": 896
158
+ }
159
+ else:
160
+ params = {
161
+ 'sampling': True,
162
+ 'top_p': top_p,
163
+ 'top_k': top_k,
164
+ 'temperature': temperature,
165
+ 'repetition_penalty': repetition_penalty_2,
166
+ "max_new_tokens": 896
167
+ }
168
+ code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
169
+ print('<Assistant>:', _answer)
170
+
171
+ _context.append({"role": "assistant", "content": _answer})
172
+ _chat_bot.append((_question, _answer))
173
+ if code == 0:
174
+ _app_cfg['ctx']=_context
175
+ _app_cfg['sts']=sts
176
+ return '', _chat_bot, _app_cfg
177
+
178
+
179
+ def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
180
+ if len(_chat_bot) <= 1:
181
+ _chat_bot.append(('Regenerate', 'No question for regeneration.'))
182
+ return '', _chat_bot, _app_cfg
183
+ elif _chat_bot[-1][0] == 'Regenerate':
184
+ return '', _chat_bot, _app_cfg
185
+ else:
186
+ _question = _chat_bot[-1][0]
187
+ _chat_bot = _chat_bot[:-1]
188
+ _app_cfg['ctx'] = _app_cfg['ctx'][:-2]
189
+ return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)
190
+
191
+
192
+
193
+ with gr.Blocks() as demo:
194
+ with gr.Row():
195
+ with gr.Column(scale=1, min_width=300):
196
+ params_form = create_component(form_radio, comp='Radio', visible=True)
197
+ with gr.Accordion("Beam Search") as beams_according:
198
+ num_beams = create_component(num_beams_slider)
199
+ repetition_penalty = create_component(repetition_penalty_slider)
200
+ with gr.Accordion("Sampling") as sampling_according:
201
+ top_p = create_component(top_p_slider)
202
+ top_k = create_component(top_k_slider)
203
+ temperature = create_component(temperature_slider)
204
+ repetition_penalty_2 = create_component(repetition_penalty_slider2)
205
+ regenerate = create_component({'value': 'Regenerate'}, comp='Button')
206
+ with gr.Column(scale=3, min_width=500):
207
+ app_session = gr.State({'sts':None,'ctx':None,'img':None})
208
+ bt_pic = gr.Image(label="Upload an image to start")
209
+ chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
210
+ txt_message = gr.Textbox(label="Input text")
211
+
212
+ regenerate.click(
213
+ regenerate_button_clicked,
214
+ [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
215
+ [txt_message, chat_bot, app_session]
216
+ )
217
+ txt_message.submit(
218
+ respond,
219
+ [txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
220
+ [txt_message, chat_bot, app_session]
221
+ )
222
+ bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic,chat_bot,app_session], outputs=[chat_bot,app_session])
223
+
224
+ # launch
225
+ #demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
226
+ demo.launch()
227
+
chat.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import json
4
+ from PIL import Image
5
+ import base64
6
+ import io
7
+ #from accelerate import load_checkpoint_and_dispatch, init_empty_weights
8
+ from transformers import AutoTokenizer, AutoModel
9
+
10
+ from omnilmm.utils import disable_torch_init
11
+ from omnilmm.model.omnilmm import OmniLMMForCausalLM
12
+ from omnilmm.model.utils import build_transform
13
+ from omnilmm.train.train_utils import omni_preprocess
14
+
15
+ DEFAULT_IMAGE_TOKEN = "<image>"
16
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
17
+ DEFAULT_IM_START_TOKEN = "<im_start>"
18
+ DEFAULT_IM_END_TOKEN = "<im_end>"
19
+
20
+
21
+
22
+ def init_omni_lmm(model_path):
23
+ torch.backends.cuda.matmul.allow_tf32 = True
24
+ disable_torch_init()
25
+ model_name = os.path.expanduser(model_path)
26
+ print(f'Load omni_lmm model and tokenizer from {model_name}')
27
+ tokenizer = AutoTokenizer.from_pretrained(
28
+ model_name, model_max_length=4096)
29
+
30
+ if False:
31
+ # model on multiple devices for small size gpu memory (Nvidia 3090 24G x2)
32
+ with init_empty_weights():
33
+ model = OmniLMMForCausalLM.from_pretrained(model_name, tune_clip=True, torch_dtype=torch.bfloat16)
34
+ model = load_checkpoint_and_dispatch(model, model_name, dtype=torch.bfloat16,
35
+ device_map="auto", no_split_module_classes=['Eva','MistralDecoderLayer', 'ModuleList', 'Resampler']
36
+ )
37
+ else:
38
+ model = OmniLMMForCausalLM.from_pretrained(
39
+ model_name, tune_clip=True, torch_dtype=torch.bfloat16
40
+ ).to(device='cuda', dtype=torch.bfloat16)
41
+
42
+ image_processor = build_transform(
43
+ is_train=False, input_size=model.model.config.image_size, std_mode='OPENAI_CLIP')
44
+
45
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
46
+ assert mm_use_im_start_end
47
+
48
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
49
+ DEFAULT_IM_END_TOKEN], special_tokens=True)
50
+
51
+
52
+ vision_config = model.model.vision_config
53
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
54
+ [DEFAULT_IMAGE_PATCH_TOKEN])[0]
55
+ vision_config.use_im_start_end = mm_use_im_start_end
56
+ vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
57
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
58
+ image_token_len = model.model.config.num_query
59
+
60
+ return model, image_processor, image_token_len, tokenizer
61
+
62
+ def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token):
63
+ if '<image>' in question_text[0]['content']:
64
+ question_text[0]['content'] = question_text[0]['content'].replace(
65
+ '<image>', im_st_token + im_patch_token * image_token_len + im_ed_token)
66
+ else:
67
+ question_text[0]['content'] = im_st_token + im_patch_token * \
68
+ image_token_len + im_ed_token + '\n' + question_text[0]['content']
69
+ return question_text
70
+
71
+ def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
72
+ question = expand_question_into_multimodal(
73
+ question, image_token_len, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN)
74
+
75
+ conversation = question
76
+ data_dict = omni_preprocess(sources=[conversation],
77
+ tokenizer=tokenizer,
78
+ generation=True)
79
+
80
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
81
+ labels=data_dict["labels"][0])
82
+ return data_dict
83
+
84
+
85
+
86
+ class OmniLMM12B:
87
+ def __init__(self, model_path) -> None:
88
+ model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)
89
+ self.model = model
90
+ self.image_token_len = image_token_len
91
+ self.image_transform = img_processor
92
+ self.tokenizer = tokenizer
93
+ self.model.eval()
94
+
95
+ def decode(self, image, input_ids):
96
+ with torch.inference_mode():
97
+ output = self.model.generate_vllm(
98
+ input_ids=input_ids.unsqueeze(0).cuda(),
99
+ images=image.unsqueeze(0).half().cuda(),
100
+ temperature=0.6,
101
+ max_new_tokens=1024,
102
+ # num_beams=num_beams,
103
+ do_sample=True,
104
+ output_scores=True,
105
+ return_dict_in_generate=True,
106
+ repetition_penalty=1.1,
107
+ top_k=30,
108
+ top_p=0.9,
109
+ )
110
+
111
+ response = self.tokenizer.decode(
112
+ output.sequences[0], skip_special_tokens=True)
113
+ response = response.strip()
114
+ return response
115
+
116
+ def chat(self, image, msgs):
117
+ #image = input['image']
118
+ #msgs = json.loads(input['question'])
119
+ input_ids = wrap_question_for_omni_lmm(
120
+ msgs, self.image_token_len, self.tokenizer)['input_ids']
121
+ input_ids = torch.as_tensor(input_ids)
122
+ #print('input_ids', input_ids)
123
+ image = self.image_transform(image)
124
+
125
+ out = self.decode(image, input_ids)
126
+
127
+ return out
128
+
129
+
130
+ def img2base64(file_name):
131
+ with open(file_name, 'rb') as f:
132
+ encoded_string = base64.b64encode(f.read())
133
+ return encoded_string
134
+
135
+ class MiniCPMV:
136
+ def __init__(self, model_path) -> None:
137
+ self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
138
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
139
+ self.model.eval().cuda()
140
+
141
+ def chat(self, input):
142
+ try:
143
+ image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
144
+ except Exception as e:
145
+ return "Image decode error"
146
+
147
+ msgs = json.loads(input['question'])
148
+
149
+ answer, context, _ = self.model.chat(
150
+ image=image,
151
+ msgs=msgs,
152
+ context=None,
153
+ tokenizer=self.tokenizer,
154
+ sampling=True,
155
+ temperature=0.7
156
+ )
157
+ return answer
158
+
159
+ class MiniCPMV2_5:
160
+ def __init__(self, model_path) -> None:
161
+ self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.float16)
162
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
163
+ self.model.eval().cuda()
164
+
165
+ def chat(self, input):
166
+ try:
167
+ image = Image.open(io.BytesIO(base64.b64decode(input['image']))).convert('RGB')
168
+ except Exception as e:
169
+ return "Image decode error"
170
+
171
+ msgs = json.loads(input['question'])
172
+
173
+ answer = self.model.chat(
174
+ image=image,
175
+ msgs=msgs,
176
+ tokenizer=self.tokenizer,
177
+ sampling=True,
178
+ temperature=0.7
179
+ )
180
+ return answer
181
+
182
+
183
+ class MiniCPMVChat:
184
+ def __init__(self, model_path) -> None:
185
+ if '12B' in model_path:
186
+ self.model = OmniLMM12B(model_path)
187
+ elif 'MiniCPM-Llama3-V' in model_path:
188
+ self.model = MiniCPMV2_5(model_path)
189
+ else:
190
+ self.model = MiniCPMV(model_path)
191
+
192
+ def chat(self, input):
193
+ return self.model.chat(input)
194
+
195
+
196
+ if __name__ == '__main__':
197
+
198
+ model_path = 'openbmb/OmniLMM-12B'
199
+ chat_model = MiniCPMVChat(model_path)
200
+
201
+ im_64 = img2base64('./assets/worldmap_ck.jpg')
202
+
203
+ # first round chat
204
+ msgs = [{"role": "user", "content": "What is interesting about this image?"}]
205
+ input = {"image": im_64, "question": json.dumps(msgs, ensure_ascii=True)}
206
+ answer = chat_model.chat(input)
207
+ print(msgs[-1]["content"]+'\n', answer)
208
+
209
+ # second round chat
210
+ msgs.append({"role": "assistant", "content": answer})
211
+ msgs.append({"role": "user", "content": "Where is China in the image"})
212
+ input = {"image": im_64,"question": json.dumps(msgs, ensure_ascii=True)}
213
+ answer = chat_model.chat(input)
214
+ print(msgs[-1]["content"]+'\n', answer)
215
+
omnilmm/__init__.py ADDED
File without changes
omnilmm/constants.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
omnilmm/conversation.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+
11
+
12
+ @dataclasses.dataclass
13
+ class Conversation:
14
+ """A class that keeps all conversation history."""
15
+ system: str
16
+ roles: List[str]
17
+ messages: List[List[str]]
18
+ offset: int
19
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
20
+ sep: str = "###"
21
+ sep2: str = None
22
+ version: str = "Unknown"
23
+
24
+ skip_next: bool = False
25
+
26
+ def get_prompt(self):
27
+ if self.sep_style == SeparatorStyle.SINGLE:
28
+ ret = self.system + self.sep
29
+ for role, message in self.messages:
30
+ if message:
31
+ if type(message) is tuple:
32
+ message, _, _ = message
33
+ ret += role + ": " + message + self.sep
34
+ else:
35
+ ret += role + ":"
36
+ return ret
37
+ elif self.sep_style == SeparatorStyle.TWO:
38
+ seps = [self.sep, self.sep2]
39
+ ret = self.system + seps[0]
40
+ for i, (role, message) in enumerate(self.messages):
41
+ if message:
42
+ if type(message) is tuple:
43
+ message, _, _ = message
44
+ ret += role + ": " + message + seps[i % 2]
45
+ else:
46
+ ret += role + ":"
47
+ return ret
48
+ else:
49
+ raise ValueError(f"Invalid style: {self.sep_style}")
50
+
51
+ def append_message(self, role, message):
52
+ self.messages.append([role, message])
53
+
54
+ def get_images(self, return_pil=False):
55
+ images = []
56
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
57
+ if i % 2 == 0:
58
+ if type(msg) is tuple:
59
+ import base64
60
+ from io import BytesIO
61
+ from PIL import Image
62
+ msg, image, image_process_mode = msg
63
+ if image_process_mode == "Pad":
64
+ def expand2square(pil_img, background_color=(122, 116, 104)):
65
+ width, height = pil_img.size
66
+ if width == height:
67
+ return pil_img
68
+ elif width > height:
69
+ result = Image.new(
70
+ pil_img.mode, (width, width), background_color)
71
+ result.paste(
72
+ pil_img, (0, (width - height) // 2))
73
+ return result
74
+ else:
75
+ result = Image.new(
76
+ pil_img.mode, (height, height), background_color)
77
+ result.paste(
78
+ pil_img, ((height - width) // 2, 0))
79
+ return result
80
+ image = expand2square(image)
81
+ elif image_process_mode == "Crop":
82
+ pass
83
+ elif image_process_mode == "Resize":
84
+ image = image.resize((224, 224))
85
+ else:
86
+ raise ValueError(
87
+ f"Invalid image_process_mode: {image_process_mode}")
88
+ max_hw, min_hw = max(image.size), min(image.size)
89
+ aspect_ratio = max_hw / min_hw
90
+ max_len, min_len = 800, 400
91
+ shortest_edge = int(
92
+ min(max_len / aspect_ratio, min_len, min_hw))
93
+ longest_edge = int(shortest_edge * aspect_ratio)
94
+ W, H = image.size
95
+ if H > W:
96
+ H, W = longest_edge, shortest_edge
97
+ else:
98
+ H, W = shortest_edge, longest_edge
99
+ image = image.resize((W, H))
100
+ if return_pil:
101
+ images.append(image)
102
+ else:
103
+ buffered = BytesIO()
104
+ image.save(buffered, format="JPEG")
105
+ img_b64_str = base64.b64encode(
106
+ buffered.getvalue()).decode()
107
+ images.append(img_b64_str)
108
+ return images
109
+
110
+ def to_gradio_chatbot(self):
111
+ ret = []
112
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
113
+ if i % 2 == 0:
114
+ if type(msg) is tuple:
115
+ import base64
116
+ from io import BytesIO
117
+ msg, image, image_process_mode = msg
118
+ max_hw, min_hw = max(image.size), min(image.size)
119
+ aspect_ratio = max_hw / min_hw
120
+ max_len, min_len = 800, 400
121
+ shortest_edge = int(
122
+ min(max_len / aspect_ratio, min_len, min_hw))
123
+ longest_edge = int(shortest_edge * aspect_ratio)
124
+ W, H = image.size
125
+ if H > W:
126
+ H, W = longest_edge, shortest_edge
127
+ else:
128
+ H, W = shortest_edge, longest_edge
129
+ image = image.resize((W, H))
130
+ # image = image.resize((224, 224))
131
+ buffered = BytesIO()
132
+ image.save(buffered, format="JPEG")
133
+ img_b64_str = base64.b64encode(
134
+ buffered.getvalue()).decode()
135
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
136
+ msg = msg.replace('<image>', img_str)
137
+ ret.append([msg, None])
138
+ else:
139
+ ret[-1][-1] = msg
140
+ return ret
141
+
142
+ def copy(self):
143
+ return Conversation(
144
+ system=self.system,
145
+ roles=self.roles,
146
+ messages=[[x, y] for x, y in self.messages],
147
+ offset=self.offset,
148
+ sep_style=self.sep_style,
149
+ sep=self.sep,
150
+ sep2=self.sep2)
151
+
152
+ def dict(self):
153
+ if len(self.get_images()) > 0:
154
+ return {
155
+ "system": self.system,
156
+ "roles": self.roles,
157
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
158
+ "offset": self.offset,
159
+ "sep": self.sep,
160
+ "sep2": self.sep2,
161
+ }
162
+ return {
163
+ "system": self.system,
164
+ "roles": self.roles,
165
+ "messages": self.messages,
166
+ "offset": self.offset,
167
+ "sep": self.sep,
168
+ "sep2": self.sep2,
169
+ }
170
+
171
+
172
+ conv_v1 = Conversation(
173
+ system="A chat between a curious human and an artificial intelligence assistant. "
174
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
175
+ roles=("Human", "Assistant"),
176
+ messages=(
177
+ ("Human", "Give three tips for staying healthy."),
178
+ ("Assistant",
179
+ "Sure, here are three tips for staying healthy:\n"
180
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
181
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
182
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
183
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
184
+ "activities at least two days per week.\n"
185
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
186
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
187
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
188
+ "and aim to drink plenty of water throughout the day.\n"
189
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
190
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
191
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
192
+ "help improve the quality of your sleep.")
193
+ ),
194
+ offset=2,
195
+ sep_style=SeparatorStyle.SINGLE,
196
+ sep="###",
197
+ )
198
+
199
+ conv_v1_2 = Conversation(
200
+ system="A chat between a curious human and an artificial intelligence assistant. "
201
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
202
+ roles=("Human", "Assistant"),
203
+ messages=(
204
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
205
+ ("Assistant",
206
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
207
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
208
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
209
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
210
+ "renewable and non-renewable energy sources:\n"
211
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
212
+ "energy sources are finite and will eventually run out.\n"
213
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
214
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
215
+ "and other negative effects.\n"
216
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
217
+ "have lower operational costs than non-renewable sources.\n"
218
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
219
+ "locations than non-renewable sources.\n"
220
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
221
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
222
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
223
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
224
+ ),
225
+ offset=2,
226
+ sep_style=SeparatorStyle.SINGLE,
227
+ sep="###",
228
+ )
229
+
230
+ conv_vicuna_v1_1 = Conversation(
231
+ system="A chat between a curious user and an artificial intelligence assistant. "
232
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
233
+ roles=("USER", "ASSISTANT"),
234
+ version="v1",
235
+ messages=(),
236
+ offset=0,
237
+ sep_style=SeparatorStyle.TWO,
238
+ sep=" ",
239
+ sep2="</s>",
240
+ )
241
+
242
+ conv_bair_v1 = Conversation(
243
+ system="BEGINNING OF CONVERSATION:",
244
+ roles=("USER", "GPT"),
245
+ messages=(),
246
+ offset=0,
247
+ sep_style=SeparatorStyle.TWO,
248
+ sep=" ",
249
+ sep2="</s>",
250
+ )
251
+
252
+ simple_conv = Conversation(
253
+ system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture."
254
+ "You are designed to assist human with a variety of tasks using natural language."
255
+ "Follow the instructions carefully.",
256
+ roles=("Human", "Assistant"),
257
+ messages=(
258
+ ("Human", "Hi!"),
259
+ ("Assistant", "Hi there! How can I help you today?\n")
260
+ ),
261
+ offset=2,
262
+ sep_style=SeparatorStyle.SINGLE,
263
+ sep="###",
264
+ )
265
+
266
+ simple_conv_multimodal = Conversation(
267
+ system="A chat between a curious user and an artificial intelligence assistant. "
268
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
269
+ roles=("Human", "Assistant"),
270
+ messages=(
271
+ ),
272
+ offset=0,
273
+ sep_style=SeparatorStyle.SINGLE,
274
+ sep="###",
275
+ )
276
+
277
+ simple_conv_legacy = Conversation(
278
+ system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
279
+ "You are designed to assist human with a variety of tasks using natural language."
280
+ "Follow the instructions carefully.",
281
+ roles=("Human", "Assistant"),
282
+ messages=(
283
+ ("Human", "Hi!\n\n### Response:"),
284
+ ("Assistant", "Hi there! How can I help you today?\n")
285
+ ),
286
+ offset=2,
287
+ sep_style=SeparatorStyle.SINGLE,
288
+ sep="###",
289
+ )
290
+
291
+ conv_llava_v1 = Conversation(
292
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
293
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
294
+ "Follow the instructions carefully and explain your answers in detail.",
295
+ roles=("USER", "ASSISTANT"),
296
+ version="v1",
297
+ messages=(),
298
+ offset=0,
299
+ sep_style=SeparatorStyle.TWO,
300
+ sep=" ",
301
+ sep2="</s>",
302
+ )
303
+
304
+ default_conversation = conv_v1_2
305
+ conv_templates = {
306
+ "default": conv_v1_2,
307
+ "simple": simple_conv,
308
+ "simple_legacy": simple_conv_legacy,
309
+ "multimodal": simple_conv_multimodal,
310
+ "llava_v1": conv_llava_v1,
311
+
312
+ # fastchat
313
+ "v1": conv_v1_2,
314
+ "bair_v1": conv_bair_v1,
315
+ "vicuna_v1_1": conv_vicuna_v1_1,
316
+ }
317
+
318
+
319
+ if __name__ == "__main__":
320
+ print(default_conversation.get_prompt())
omnilmm/model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .omnilmm import OmniLMMForCausalLM
omnilmm/model/omnilmm.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gc
3
+ import math
4
+ import timm
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.nn as nn
8
+ from torch.nn import CrossEntropyLoss
9
+ from typing import List, Optional, Tuple, Union
10
+
11
+ from transformers import AutoConfig, AutoModelForCausalLM
12
+ from transformers import MistralForCausalLM, MistralModel, MistralConfig
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+
15
+ from omnilmm.model.utils import build_transform
16
+ from omnilmm.model.resampler import Resampler
17
+
18
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
19
+ DEFAULT_IM_START_TOKEN = "<im_start>"
20
+ DEFAULT_IM_END_TOKEN = "<im_end>"
21
+
22
+
23
+ class OmniLMMConfig(MistralConfig):
24
+ model_type = "omnilmm"
25
+
26
+
27
+ class Identity(torch.nn.Identity):
28
+ def forward(self, input: Tensor, **kwargs) -> Tensor:
29
+ return super().forward(input)
30
+
31
+
32
+ def create_vision_module(config):
33
+ vision_tower = timm.create_model('eva02_enormous_patch14_clip_224.laion2b_plus',
34
+ pretrained=False,
35
+ num_classes=0,
36
+ dynamic_img_size=True,
37
+ dynamic_img_pad=True)
38
+
39
+ if isinstance(vision_tower, timm.models.VisionTransformer):
40
+ if vision_tower.attn_pool is not None:
41
+ vision_tower.attn_pool = Identity()
42
+
43
+ # use 2nd last layer's output
44
+ vision_tower.blocks[-1] = Identity()
45
+
46
+ embed_dim = config.hidden_size
47
+ resampler = Resampler(
48
+ grid_size=int(math.sqrt(config.num_query)),
49
+ embed_dim=embed_dim,
50
+ num_heads=embed_dim // 128,
51
+ kv_dim=vision_tower.embed_dim,
52
+ )
53
+ return vision_tower, resampler
54
+
55
+
56
+ class OmniLMMModel(MistralModel):
57
+ config_class = OmniLMMConfig
58
+
59
+ def __init__(self, config: OmniLMMConfig, mm_vision_tower=None, mm_hidden_size=None, tune_clip=True):
60
+ super(OmniLMMModel, self).__init__(config)
61
+
62
+ if hasattr(config, "mm_vision_tower"):
63
+ vision_tower, resampler = create_vision_module(config)
64
+
65
+ # print(__file__, 'skip loading vision tower weights')
66
+
67
+ # HACK: for FSDP
68
+ self.vision_tower = [vision_tower]
69
+ self.resampler = resampler
70
+ if tune_clip:
71
+ self.vision_tower = self.vision_tower[0]
72
+
73
+ self.vision_config = lambda x: None
74
+
75
+ def initialize_vision_modules(self, vision_tower, no_randaug, num_query, image_size, tune_clip=False):
76
+ self.config.mm_vision_tower = vision_tower
77
+ self.config.use_mm_proj = True
78
+ self.config.num_query = num_query
79
+ self.config.image_size = image_size
80
+
81
+ if not hasattr(self, 'vision_tower'):
82
+ vision_tower, resampler = create_vision_module(self.config)
83
+ state_dict = torch.load(
84
+ '/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt')
85
+ vision_tower.load_state_dict(state_dict, strict=False)
86
+ del state_dict
87
+ gc.collect()
88
+ else:
89
+ if isinstance(self.vision_tower, list):
90
+ vision_tower = self.vision_tower[0]
91
+ else:
92
+ vision_tower = self.vision_tower
93
+ resampler = self.resampler
94
+ self.vision_tower = vision_tower if tune_clip else [vision_tower]
95
+ self.resampler = resampler
96
+
97
+ train_img_transform = build_transform(
98
+ is_train=True, randaug=not no_randaug, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
99
+ eval_img_transform = build_transform(
100
+ is_train=False, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
101
+
102
+ return dict(
103
+ image_processor=(train_img_transform, eval_img_transform),
104
+ image_token_len=num_query,
105
+ vision_config=self.vision_config
106
+ )
107
+
108
+ def get_vision_embedding(self, pixel_values):
109
+ if isinstance(self.vision_tower, list):
110
+ vision_tower = self.vision_tower[0] # HACK: for FSDP
111
+ else:
112
+ vision_tower = self.vision_tower
113
+
114
+ dtype = vision_tower.pos_embed.data.dtype
115
+ vision_embedding = vision_tower.forward_features(
116
+ pixel_values.type(dtype))
117
+ if hasattr(vision_tower, 'num_prefix_tokens') and vision_tower.num_prefix_tokens > 0:
118
+ vision_embedding = vision_embedding[:,
119
+ vision_tower.num_prefix_tokens:]
120
+ res = self.resampler(vision_embedding)
121
+ return res
122
+
123
+ def get_vllm_embedding(self, data):
124
+
125
+ if 'vision_hidden_states' not in data:
126
+ pixel_values_list = data['pixel_values']
127
+ vision_hidden_states = []
128
+ for pixel_values in pixel_values_list:
129
+ if len(pixel_values) > 0:
130
+ vision_hidden_states.append(self.get_vision_embedding(pixel_values.unsqueeze(0))[0])
131
+ else:
132
+ vision_hidden_states.append([])
133
+ else:
134
+ vision_hidden_states = data['vision_hidden_states']
135
+
136
+ #vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
137
+ inputs_embeds = self.embed_tokens(data['input_ids'])
138
+ vision_hidden_states = [i.type(inputs_embeds.dtype)
139
+ if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
140
+ ]
141
+
142
+
143
+ # HACK: replace back original embeddings for LLaVA pretraining
144
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
145
+
146
+ new_input_embeds = []
147
+ cur_image_idx = 0
148
+ for cur_input_ids, cur_input_embeds in zip(data['input_ids'], inputs_embeds):
149
+ if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
150
+ # multimodal LLM, but the current sample is not multimodal
151
+ cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
152
+ new_input_embeds.append(cur_input_embeds)
153
+ continue
154
+
155
+ if self.vision_config.use_im_start_end:
156
+ cur_image_features = vision_hidden_states[cur_image_idx]
157
+ num_patches = cur_image_features.shape[0]
158
+ if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
159
+ raise ValueError(
160
+ "The number of image start tokens and image end tokens should be the same.")
161
+ image_start_tokens = torch.where(
162
+ cur_input_ids == self.vision_config.im_start_token)[0]
163
+ for image_start_token_pos in image_start_tokens:
164
+ cur_image_features = vision_hidden_states[cur_image_idx].to(
165
+ device=cur_input_embeds.device)
166
+ num_patches = cur_image_features.shape[0]
167
+ if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
168
+ raise ValueError(
169
+ "The image end token should follow the image start token.")
170
+ if orig_embeds_params is not None:
171
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
172
+ cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
173
+ else:
174
+ cur_new_input_embeds = torch.cat(
175
+ (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
176
+ cur_image_idx += 1
177
+ new_input_embeds.append(cur_new_input_embeds)
178
+ else:
179
+ raise NotImplementedError
180
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
181
+
182
+ return inputs_embeds, vision_hidden_states
183
+
184
+ def forward(
185
+ self,
186
+ input_ids: torch.LongTensor = None,
187
+ attention_mask: Optional[torch.Tensor] = None,
188
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
189
+ inputs_embeds: Optional[torch.FloatTensor] = None,
190
+ use_cache: Optional[bool] = None,
191
+ output_attentions: Optional[bool] = None,
192
+ output_hidden_states: Optional[bool] = None,
193
+ images: Optional[torch.FloatTensor] = None,
194
+ return_dict: Optional[bool] = None,
195
+ **kwargs
196
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
197
+
198
+ # HACK: replace back original embeddings for LLaVA pretraining
199
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
200
+
201
+ if inputs_embeds is None and past_key_values is None:
202
+ inputs_embeds = self.embed_tokens(input_ids)
203
+
204
+ vision_tower = getattr(self, 'vision_tower', None)
205
+ if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
206
+
207
+ if type(images) is list:
208
+ image_features = []
209
+ for image in images:
210
+ image_forward_out = self.get_vision_embedding(image.unsqueeze(0))[
211
+ 0]
212
+ image_features.append(image_forward_out)
213
+ else:
214
+ image_features = self.get_vision_embedding(images)
215
+
216
+ dummy_image_features = torch.zeros(
217
+ self.config.num_query,
218
+ self.config.hidden_size,
219
+ device=inputs_embeds.device,
220
+ dtype=inputs_embeds.dtype)
221
+
222
+ new_input_embeds = []
223
+ cur_image_idx = 0
224
+ for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
225
+ if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
226
+ # multimodal LLM, but the current sample is not multimodal
227
+ cur_input_embeds = cur_input_embeds + \
228
+ (0. * dummy_image_features).sum()
229
+ new_input_embeds.append(cur_input_embeds)
230
+ continue
231
+
232
+ if self.vision_config.use_im_start_end:
233
+ cur_image_features = image_features[cur_image_idx]
234
+ num_patches = cur_image_features.shape[0]
235
+ if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
236
+ raise ValueError(
237
+ "The number of image start tokens and image end tokens should be the same.")
238
+ image_start_tokens = torch.where(
239
+ cur_input_ids == self.vision_config.im_start_token)[0]
240
+ for image_start_token_pos in image_start_tokens:
241
+ cur_image_features = image_features[cur_image_idx].to(
242
+ device=cur_input_embeds.device)
243
+ num_patches = cur_image_features.shape[0]
244
+ if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
245
+ raise ValueError(
246
+ "The image end token should follow the image start token.")
247
+ if orig_embeds_params is not None:
248
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
249
+ cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
250
+ else:
251
+ cur_new_input_embeds = torch.cat(
252
+ (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
253
+ cur_image_idx += 1
254
+ new_input_embeds.append(cur_new_input_embeds)
255
+ else:
256
+ raise NotImplementedError
257
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
258
+ input_ids = None
259
+
260
+ return super(OmniLMMModel, self).forward(
261
+ input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values,
262
+ inputs_embeds=inputs_embeds, use_cache=use_cache,
263
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
264
+ return_dict=return_dict,
265
+ **kwargs
266
+ )
267
+
268
+
269
+ class OmniLMMForCausalLM(MistralForCausalLM):
270
+ config_class = OmniLMMConfig
271
+
272
+ def __init__(self, config, mm_vision_tower=None, tune_clip=True):
273
+ super(MistralForCausalLM, self).__init__(config)
274
+ self.model = OmniLMMModel(
275
+ config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip)
276
+
277
+ self.lm_head = nn.Linear(
278
+ config.hidden_size, config.vocab_size, bias=False)
279
+
280
+ # Initialize weights and apply final processing
281
+ self.post_init()
282
+
283
+ def forward(
284
+ self,
285
+ input_ids: torch.LongTensor = None,
286
+ attention_mask: Optional[torch.Tensor] = None,
287
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
288
+ inputs_embeds: Optional[torch.FloatTensor] = None,
289
+ labels: Optional[torch.LongTensor] = None,
290
+ use_cache: Optional[bool] = None,
291
+ output_attentions: Optional[bool] = None,
292
+ output_hidden_states: Optional[bool] = None,
293
+ images: Optional[torch.FloatTensor] = None,
294
+ return_dict: Optional[bool] = None,
295
+ **kwargs
296
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
297
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
298
+ output_hidden_states = (
299
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
300
+ )
301
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
302
+
303
+ # print(f'@@@ At forward, labels: {labels.shape}-{labels}', flush=True)
304
+ # print(f'@@@ At forward, input_ids: {input_ids.shape}-{input_ids}', flush=True)
305
+ # print(f'@@@ At forward, input_ids: {attention_mask.shape}-{attention_mask}', flush=True)
306
+
307
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
308
+ outputs = self.model(
309
+ input_ids=input_ids,
310
+ attention_mask=attention_mask,
311
+ past_key_values=past_key_values,
312
+ inputs_embeds=inputs_embeds,
313
+ use_cache=use_cache,
314
+ output_attentions=output_attentions,
315
+ output_hidden_states=output_hidden_states,
316
+ return_dict=return_dict,
317
+ images=images,
318
+ **kwargs
319
+ )
320
+
321
+ hidden_states = outputs[0]
322
+ logits = self.lm_head(hidden_states)
323
+
324
+ loss = None
325
+ if labels is not None:
326
+ # Shift so that tokens < n predict n
327
+ shift_logits = logits[..., :-1, :].contiguous()
328
+ shift_labels = labels[..., 1:].contiguous()
329
+ # Flatten the tokens
330
+ loss_fct = CrossEntropyLoss()
331
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
332
+ shift_labels = shift_labels.view(-1)
333
+ # Enable model/pipeline parallelism
334
+ shift_labels = shift_labels.to(shift_logits.device)
335
+ loss = loss_fct(shift_logits, shift_labels)
336
+
337
+ if not return_dict:
338
+ output = (logits,) + outputs[1:]
339
+ return (loss,) + output if loss is not None else output
340
+
341
+ return CausalLMOutputWithPast(
342
+ loss=loss,
343
+ logits=logits,
344
+ past_key_values=outputs.past_key_values,
345
+ hidden_states=outputs.hidden_states,
346
+ attentions=outputs.attentions,
347
+ )
348
+
349
+ # TODO could be removed for generate_vllm()
350
+ def prepare_inputs_for_generation(
351
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
352
+ ):
353
+ if past_key_values:
354
+ input_ids = input_ids[:, -1:]
355
+
356
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
357
+ if inputs_embeds is not None and past_key_values is None:
358
+ model_inputs = {"inputs_embeds": inputs_embeds}
359
+ else:
360
+ model_inputs = {"input_ids": input_ids}
361
+
362
+ model_inputs.update(
363
+ {
364
+ "past_key_values": past_key_values,
365
+ "use_cache": kwargs.get("use_cache"),
366
+ "attention_mask": attention_mask,
367
+ "images": kwargs.get("images", None),
368
+ }
369
+ )
370
+ return model_inputs
371
+
372
+ def generate_vllm(
373
+ self,
374
+ input_ids: torch.LongTensor = None,
375
+ images: Optional[torch.FloatTensor] = None,
376
+ vision_hidden_states=None,
377
+ return_vision_hidden_states=False,
378
+ **kwargs
379
+ ):
380
+ model_inputs = {'input_ids': input_ids}
381
+ if vision_hidden_states is None:
382
+ model_inputs['pixel_values'] = images
383
+ else:
384
+ model_inputs['vision_hidden_states'] = vision_hidden_states
385
+
386
+ with torch.inference_mode():
387
+ inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(model_inputs)
388
+
389
+ result = self.generate(
390
+ inputs_embeds=inputs_embeds,
391
+ **kwargs
392
+ )
393
+
394
+ if return_vision_hidden_states:
395
+ return result, vision_hidden_states
396
+
397
+ return result
398
+
399
+
400
+ def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
401
+ tune_mm_mlp_adapter=False):
402
+ self.model.vision_config.use_im_start_end = mm_use_im_start_end
403
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
404
+ self.resize_token_embeddings(len(tokenizer))
405
+
406
+ if mm_use_im_start_end:
407
+ num_new_tokens = tokenizer.add_tokens(
408
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
409
+ self.resize_token_embeddings(len(tokenizer))
410
+ self.model.vision_config.im_start_token, self.model.vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
411
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
412
+
413
+ if num_new_tokens > 0:
414
+ input_embeddings = self.get_input_embeddings().weight.data
415
+ output_embeddings = self.get_output_embeddings().weight.data
416
+
417
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
418
+ dim=0, keepdim=True)
419
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
420
+ dim=0, keepdim=True)
421
+
422
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
423
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
424
+
425
+ # for new sft data
426
+ num_new_tokens = tokenizer.add_tokens(
427
+ ['<box>', '</box>', '<ref>', '</ref>', '<quad>', '</quad>'], special_tokens=True)
428
+ self.resize_token_embeddings(len(tokenizer))
429
+
430
+ if num_new_tokens > 0:
431
+ input_embeddings = self.get_input_embeddings().weight.data
432
+ output_embeddings = self.get_output_embeddings().weight.data
433
+
434
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
435
+ dim=0, keepdim=True)
436
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
437
+ dim=0, keepdim=True)
438
+
439
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
440
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
441
+
442
+ if tune_mm_mlp_adapter:
443
+ self.model.orig_embeds_params = [
444
+ self.get_input_embeddings().weight.data.clone().to(device=device)]
445
+ for p in self.get_input_embeddings().parameters():
446
+ p.requires_grad = True
447
+ for p in self.get_output_embeddings().parameters():
448
+ p.requires_grad = False
449
+
450
+ self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
451
+ [DEFAULT_IMAGE_PATCH_TOKEN])[0]
452
+ print(f'Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, visoin_config: {self.model.vision_config}', flush=True)
453
+ # exit()
454
+
455
+
456
+ AutoConfig.register("omnilmm", OmniLMMConfig)
457
+ AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)
omnilmm/model/resampler.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+ import requests
9
+ from io import BytesIO
10
+ from functools import partial
11
+ from PIL import Image
12
+ from typing import Callable, Optional, Sequence, Tuple, List, Union
13
+ import numpy as np
14
+
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.init import trunc_normal_
19
+ from torchvision import transforms
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+
23
+ def get_abs_pos(abs_pos, tgt_size):
24
+ # abs_pos: L, C
25
+ # tgt_size: M
26
+ # return: M, C
27
+ src_size = int(math.sqrt(abs_pos.size(0)))
28
+ tgt_size = int(math.sqrt(tgt_size))
29
+ dtype = abs_pos.dtype
30
+
31
+ if src_size != tgt_size:
32
+ return F.interpolate(
33
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
34
+ size=(tgt_size, tgt_size),
35
+ mode="bicubic",
36
+ align_corners=False,
37
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
38
+ else:
39
+ return abs_pos
40
+
41
+
42
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
43
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
44
+ """
45
+ grid_size: int of the grid height and width
46
+ return:
47
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
48
+ """
49
+ grid_h = np.arange(grid_size, dtype=np.float32)
50
+ grid_w = np.arange(grid_size, dtype=np.float32)
51
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
52
+ grid = np.stack(grid, axis=0)
53
+
54
+ grid = grid.reshape([2, 1, grid_size, grid_size])
55
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
56
+ if cls_token:
57
+ pos_embed = np.concatenate(
58
+ [np.zeros([1, embed_dim]), pos_embed], axis=0)
59
+ return pos_embed
60
+
61
+
62
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
63
+ assert embed_dim % 2 == 0
64
+
65
+ # use half of dimensions to encode grid_h
66
+ emb_h = get_1d_sincos_pos_embed_from_grid(
67
+ embed_dim // 2, grid[0]) # (H*W, D/2)
68
+ emb_w = get_1d_sincos_pos_embed_from_grid(
69
+ embed_dim // 2, grid[1]) # (H*W, D/2)
70
+
71
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
72
+ return emb
73
+
74
+
75
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
76
+ """
77
+ embed_dim: output dimension for each position
78
+ pos: a list of positions to be encoded: size (M,)
79
+ out: (M, D)
80
+ """
81
+ assert embed_dim % 2 == 0
82
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
83
+ omega /= embed_dim / 2.
84
+ omega = 1. / 10000 ** omega # (D/2,)
85
+
86
+ pos = pos.reshape(-1) # (M,)
87
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
88
+
89
+ emb_sin = np.sin(out) # (M, D/2)
90
+ emb_cos = np.cos(out) # (M, D/2)
91
+
92
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
93
+ return emb
94
+
95
+
96
+ class Resampler(nn.Module):
97
+ """
98
+ A 2D perceiver-resampler network with one cross attention layers by
99
+ (grid_size**2) learnable queries and 2d sincos pos_emb
100
+ Outputs:
101
+ A tensor with the shape of (grid_size**2, embed_dim)
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ grid_size,
107
+ embed_dim,
108
+ num_heads,
109
+ kv_dim=None,
110
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)
111
+ ):
112
+ super().__init__()
113
+ self.num_queries = grid_size ** 2
114
+ self.embed_dim = embed_dim
115
+ self.num_heads = num_heads
116
+
117
+ self.pos_embed = nn.Parameter(
118
+ torch.from_numpy(get_2d_sincos_pos_embed(
119
+ embed_dim, grid_size)).float()
120
+ ).requires_grad_(False)
121
+
122
+ self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
123
+ trunc_normal_(self.query, std=.02)
124
+
125
+ if kv_dim is not None and kv_dim != embed_dim:
126
+ self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
127
+ else:
128
+ self.kv_proj = nn.Identity()
129
+
130
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads)
131
+ self.ln_q = norm_layer(embed_dim)
132
+ self.ln_kv = norm_layer(embed_dim)
133
+
134
+ self.ln_post = norm_layer(embed_dim)
135
+ self.proj = nn.Parameter(
136
+ (embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
137
+
138
+ self.apply(self._init_weights)
139
+
140
+ def _init_weights(self, m):
141
+ if isinstance(m, nn.Linear):
142
+ trunc_normal_(m.weight, std=.02)
143
+ if isinstance(m, nn.Linear) and m.bias is not None:
144
+ nn.init.constant_(m.bias, 0)
145
+ elif isinstance(m, nn.LayerNorm):
146
+ nn.init.constant_(m.bias, 0)
147
+ nn.init.constant_(m.weight, 1.0)
148
+
149
+ def forward(self, x, attn_mask=None):
150
+
151
+ pos_embed = get_abs_pos(self.pos_embed, x.size(1))
152
+
153
+ x = self.kv_proj(x)
154
+ x = self.ln_kv(x).permute(1, 0, 2)
155
+
156
+ N = x.shape[1]
157
+ q = self.ln_q(self.query)
158
+ # print((self._repeat(q, N) + self.pos_embed.unsqueeze(1)).dtype, (x + pos_embed.unsqueeze(1)).dtype, x.dtype)
159
+ out = self.attn(
160
+ self._repeat(q, N) + self.pos_embed.unsqueeze(1),
161
+ x + pos_embed.unsqueeze(1),
162
+ x,
163
+ attn_mask=attn_mask)[0]
164
+ x = out.permute(1, 0, 2)
165
+
166
+ x = self.ln_post(x)
167
+ x = x @ self.proj
168
+ return x
169
+
170
+ def _repeat(self, query, N: int):
171
+ return query.unsqueeze(1).repeat(1, N, 1)
omnilmm/model/utils.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+ from timm.data.transforms import RandomResizedCropAndInterpolation
3
+ from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
4
+ from transformers import AutoConfig
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import torch.distributed as dist
8
+ import numpy as np
9
+ import pickle
10
+ import base64
11
+ import cv2
12
+ import os
13
+ import torch
14
+ from transformers import AutoConfig, StoppingCriteria
15
+
16
+ try:
17
+ from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
18
+ except ImportError:
19
+ OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
20
+ OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
21
+
22
+
23
+ def auto_upgrade(config):
24
+ cfg = AutoConfig.from_pretrained(config)
25
+ if 'llava' in config and cfg.model_type != 'llava':
26
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
27
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
28
+ confirm = input(
29
+ "Please confirm that you want to upgrade the checkpoint. [Y/N]")
30
+ if confirm.lower() in ["y", "yes"]:
31
+ print("Upgrading checkpoint...")
32
+ assert len(cfg.architectures) == 1
33
+ setattr(cfg.__class__, "model_type", "llava")
34
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
35
+ cfg.save_pretrained(config)
36
+ print("Checkpoint upgraded.")
37
+ else:
38
+ print("Checkpoint upgrade aborted.")
39
+ exit(1)
40
+
41
+
42
+ class KeywordsStoppingCriteria(StoppingCriteria):
43
+ def __init__(self, keywords, tokenizer, input_ids):
44
+ self.keywords = keywords
45
+ self.tokenizer = tokenizer
46
+ self.start_len = None
47
+ self.input_ids = input_ids
48
+
49
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
50
+ if self.start_len is None:
51
+ self.start_len = self.input_ids.shape[1]
52
+ else:
53
+ outputs = self.tokenizer.batch_decode(
54
+ output_ids[:, self.start_len:], skip_special_tokens=True)[0]
55
+ for keyword in self.keywords:
56
+ if keyword in outputs:
57
+ return True
58
+ return False
59
+
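A minimal usage sketch for KeywordsStoppingCriteria, assuming a generic Hugging Face causal LM; the checkpoint name and keyword list below are placeholders, not values taken from this repository.

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from omnilmm.model.utils import KeywordsStoppingCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Question: what is in the image?\nAnswer:",
                      return_tensors="pt").input_ids
# Stop generating as soon as any keyword shows up in the newly produced text.
stopping = KeywordsStoppingCriteria(["###"], tokenizer, input_ids)
output_ids = model.generate(input_ids,
                            max_new_tokens=64,
                            stopping_criteria=StoppingCriteriaList([stopping]))
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))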
60
+
61
+ def auto_upgrade(config):
62
+ cfg = AutoConfig.from_pretrained(config)
63
+ if 'llava' in config and cfg.model_type != 'llava':
64
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
65
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
66
+ confirm = input(
67
+ "Please confirm that you want to upgrade the checkpoint. [Y/N]")
68
+ if confirm.lower() in ["y", "yes"]:
69
+ print("Upgrading checkpoint...")
70
+ assert len(cfg.architectures) == 1
71
+ setattr(cfg.__class__, "model_type", "llava")
72
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
73
+ cfg.save_pretrained(config)
74
+ print("Checkpoint upgraded.")
75
+ else:
76
+ print("Checkpoint upgrade aborted.")
77
+ exit(1)
78
+
79
+ # aug functions
80
+
81
+
82
+ def identity_func(img):
83
+ return img
84
+
85
+
86
+ def autocontrast_func(img, cutoff=0):
87
+ '''
88
+ same output as PIL.ImageOps.autocontrast
89
+ '''
90
+ n_bins = 256
91
+
92
+ def tune_channel(ch):
93
+ n = ch.size
94
+ cut = cutoff * n // 100
95
+ if cut == 0:
96
+ high, low = ch.max(), ch.min()
97
+ else:
98
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
99
+ low = np.argwhere(np.cumsum(hist) > cut)
100
+ low = 0 if low.shape[0] == 0 else low[0]
101
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
102
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
103
+ if high <= low:
104
+ table = np.arange(n_bins)
105
+ else:
106
+ scale = (n_bins - 1) / (high - low)
107
+ table = np.arange(n_bins) * scale - low * scale
108
+ table[table < 0] = 0
109
+ table[table > n_bins - 1] = n_bins - 1
110
+ table = table.clip(0, 255).astype(np.uint8)
111
+ return table[ch]
112
+
113
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
114
+ out = cv2.merge(channels)
115
+ return out
116
+
117
+
118
+ def equalize_func(img):
119
+ '''
120
+ same output as PIL.ImageOps.equalize
121
+ PIL's implementation is different from cv2.equalizeHist
122
+ '''
123
+ n_bins = 256
124
+
125
+ def tune_channel(ch):
126
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
127
+ non_zero_hist = hist[hist != 0].reshape(-1)
128
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
129
+ if step == 0:
130
+ return ch
131
+ n = np.empty_like(hist)
132
+ n[0] = step // 2
133
+ n[1:] = hist[:-1]
134
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
135
+ return table[ch]
136
+
137
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
138
+ out = cv2.merge(channels)
139
+ return out
140
+
141
+
142
+ def rotate_func(img, degree, fill=(0, 0, 0)):
143
+ '''
144
+ like PIL, rotate by degree, not radians
145
+ '''
146
+ H, W = img.shape[0], img.shape[1]
147
+ center = W / 2, H / 2
148
+ M = cv2.getRotationMatrix2D(center, degree, 1)
149
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
150
+ return out
151
+
152
+
153
+ def solarize_func(img, thresh=128):
154
+ '''
155
+ same output as PIL.ImageOps.solarize
156
+ '''
157
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
158
+ table = table.clip(0, 255).astype(np.uint8)
159
+ out = table[img]
160
+ return out
161
+
162
+
163
+ def color_func(img, factor):
164
+ '''
165
+ same output as PIL.ImageEnhance.Color
166
+ '''
167
+ # implementation according to PIL definition, quite slow
168
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
169
+ # out = blend(degenerate, img, factor)
170
+ # M = (
171
+ # np.eye(3) * factor
172
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
173
+ # )[np.newaxis, np.newaxis, :]
174
+ M = (
175
+ np.float32([
176
+ [0.886, -0.114, -0.114],
177
+ [-0.587, 0.413, -0.587],
178
+ [-0.299, -0.299, 0.701]]) * factor
179
+ + np.float32([[0.114], [0.587], [0.299]])
180
+ )
181
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
182
+ return out
183
+
184
+
185
+ def contrast_func(img, factor):
186
+ """
187
+ same output as PIL.ImageEnhance.Contrast
188
+ """
189
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
190
+ table = np.array([(
191
+ el - mean) * factor + mean
192
+ for el in range(256)
193
+ ]).clip(0, 255).astype(np.uint8)
194
+ out = table[img]
195
+ return out
196
+
197
+
198
+ def brightness_func(img, factor):
199
+ '''
200
+ same output as PIL.ImageEnhance.Brightness
201
+ '''
202
+ table = (np.arange(256, dtype=np.float32) *
203
+ factor).clip(0, 255).astype(np.uint8)
204
+ out = table[img]
205
+ return out
206
+
207
+
208
+ def sharpness_func(img, factor):
209
+ '''
210
+ The differences between this result and PIL are all on the 4 boundaries; the center
211
+ areas are the same
212
+ '''
213
+ kernel = np.ones((3, 3), dtype=np.float32)
214
+ kernel[1][1] = 5
215
+ kernel /= 13
216
+ degenerate = cv2.filter2D(img, -1, kernel)
217
+ if factor == 0.0:
218
+ out = degenerate
219
+ elif factor == 1.0:
220
+ out = img
221
+ else:
222
+ out = img.astype(np.float32)
223
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
224
+ out[1:-1, 1:-1, :] = degenerate + factor * \
225
+ (out[1:-1, 1:-1, :] - degenerate)
226
+ out = out.astype(np.uint8)
227
+ return out
228
+
229
+
230
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
231
+ H, W = img.shape[0], img.shape[1]
232
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
233
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
234
+ flags=cv2.INTER_LINEAR).astype(np.uint8)
235
+ return out
236
+
237
+
238
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
239
+ '''
240
+ same output as PIL.Image.transform
241
+ '''
242
+ H, W = img.shape[0], img.shape[1]
243
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
244
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
245
+ flags=cv2.INTER_LINEAR).astype(np.uint8)
246
+ return out
247
+
248
+
249
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
250
+ '''
251
+ same output as PIL.Image.transform
252
+ '''
253
+ H, W = img.shape[0], img.shape[1]
254
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
255
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
256
+ flags=cv2.INTER_LINEAR).astype(np.uint8)
257
+ return out
258
+
259
+
260
+ def posterize_func(img, bits):
261
+ '''
262
+ same output as PIL.ImageOps.posterize
263
+ '''
264
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
265
+ return out
266
+
267
+
268
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
269
+ H, W = img.shape[0], img.shape[1]
270
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
271
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
272
+ flags=cv2.INTER_LINEAR).astype(np.uint8)
273
+ return out
274
+
275
+
276
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
277
+ replace = np.array(replace, dtype=np.uint8)
278
+ H, W = img.shape[0], img.shape[1]
279
+ rh, rw = np.random.random(2)
280
+ pad_size = pad_size // 2
281
+ ch, cw = int(rh * H), int(rw * W)
282
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
283
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
284
+ out = img.copy()
285
+ out[x1:x2, y1:y2, :] = replace
286
+ return out
287
+
288
+
289
+ # level to args
290
+ def enhance_level_to_args(MAX_LEVEL):
291
+ def level_to_args(level):
292
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
293
+ return level_to_args
294
+
295
+
296
+ def shear_level_to_args(MAX_LEVEL, replace_value):
297
+ def level_to_args(level):
298
+ level = (level / MAX_LEVEL) * 0.3
299
+ if np.random.random() > 0.5:
300
+ level = -level
301
+ return (level, replace_value)
302
+
303
+ return level_to_args
304
+
305
+
306
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
307
+ def level_to_args(level):
308
+ level = (level / MAX_LEVEL) * float(translate_const)
309
+ if np.random.random() > 0.5:
310
+ level = -level
311
+ return (level, replace_value)
312
+
313
+ return level_to_args
314
+
315
+
316
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
317
+ def level_to_args(level):
318
+ level = int((level / MAX_LEVEL) * cutout_const)
319
+ return (level, replace_value)
320
+
321
+ return level_to_args
322
+
323
+
324
+ def solarize_level_to_args(MAX_LEVEL):
325
+ def level_to_args(level):
326
+ level = int((level / MAX_LEVEL) * 256)
327
+ return (level, )
328
+ return level_to_args
329
+
330
+
331
+ def none_level_to_args(level):
332
+ return ()
333
+
334
+
335
+ def posterize_level_to_args(MAX_LEVEL):
336
+ def level_to_args(level):
337
+ level = int((level / MAX_LEVEL) * 4)
338
+ return (level, )
339
+ return level_to_args
340
+
341
+
342
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
343
+ def level_to_args(level):
344
+ level = (level / MAX_LEVEL) * 30
345
+ if np.random.random() < 0.5:
346
+ level = -level
347
+ return (level, replace_value)
348
+
349
+ return level_to_args
350
+
351
+
352
+ func_dict = {
353
+ 'Identity': identity_func,
354
+ 'AutoContrast': autocontrast_func,
355
+ 'Equalize': equalize_func,
356
+ 'Rotate': rotate_func,
357
+ 'Solarize': solarize_func,
358
+ 'Color': color_func,
359
+ 'Contrast': contrast_func,
360
+ 'Brightness': brightness_func,
361
+ 'Sharpness': sharpness_func,
362
+ 'ShearX': shear_x_func,
363
+ 'TranslateX': translate_x_func,
364
+ 'TranslateY': translate_y_func,
365
+ 'Posterize': posterize_func,
366
+ 'ShearY': shear_y_func,
367
+ }
368
+
369
+ translate_const = 10
370
+ MAX_LEVEL = 10
371
+ replace_value = (128, 128, 128)
372
+ arg_dict = {
373
+ 'Identity': none_level_to_args,
374
+ 'AutoContrast': none_level_to_args,
375
+ 'Equalize': none_level_to_args,
376
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
377
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
378
+ 'Color': enhance_level_to_args(MAX_LEVEL),
379
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
380
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
381
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
382
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
383
+ 'TranslateX': translate_level_to_args(
384
+ translate_const, MAX_LEVEL, replace_value
385
+ ),
386
+ 'TranslateY': translate_level_to_args(
387
+ translate_const, MAX_LEVEL, replace_value
388
+ ),
389
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
390
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
391
+ }
392
+
393
+
394
+ class RandomAugment(object):
395
+
396
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
397
+ self.N = N
398
+ self.M = M
399
+ self.isPIL = isPIL
400
+ if augs:
401
+ self.augs = augs
402
+ else:
403
+ self.augs = list(arg_dict.keys())
404
+
405
+ def get_random_ops(self):
406
+ sampled_ops = np.random.choice(self.augs, self.N)
407
+ return [(op, 0.5, self.M) for op in sampled_ops]
408
+
409
+ def __call__(self, img):
410
+ if self.isPIL:
411
+ img = np.array(img)
412
+ ops = self.get_random_ops()
413
+ for name, prob, level in ops:
414
+ if np.random.random() > prob:
415
+ continue
416
+ args = arg_dict[name](level)
417
+ img = func_dict[name](img, *args)
418
+ return img
419
+
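A short sketch of applying RandomAugment directly to a PIL image; the file names are placeholders.

from PIL import Image
from omnilmm.model.utils import RandomAugment

aug = RandomAugment(N=2, M=7, isPIL=True,
                    augs=['Identity', 'AutoContrast', 'Brightness', 'Rotate'])
img = Image.open("example.jpg").convert("RGB")   # placeholder path
out = aug(img)                                   # uint8 numpy array, shape (H, W, 3)
Image.fromarray(out).save("augmented.jpg")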
420
+
421
+ def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic', std_mode='IMAGENET_INCEPTION'):
422
+ if std_mode == 'IMAGENET_INCEPTION':
423
+ mean = IMAGENET_INCEPTION_MEAN
424
+ std = IMAGENET_INCEPTION_STD
425
+ elif std_mode == 'OPENAI_CLIP':
426
+ mean = OPENAI_CLIP_MEAN
427
+ std = OPENAI_CLIP_STD
428
+ else:
429
+ raise NotImplementedError
430
+
431
+ if is_train:
432
+ crop_scale = float(os.environ.get('TRAIN_CROP_SCALE', 0.9999))
433
+ t = [
434
+ RandomResizedCropAndInterpolation(
435
+ input_size, scale=(crop_scale, 1.0), interpolation='bicubic'),
436
+ # transforms.RandomHorizontalFlip(),
437
+ ]
438
+ if randaug and os.environ.get('TRAIN_DO_AUG', 'False') == 'True':
439
+ print(f'@@@@@ Do random aug during training', flush=True)
440
+ t.append(
441
+ RandomAugment(
442
+ 2, 7, isPIL=True,
443
+ augs=[
444
+ 'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
445
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
446
+ ]))
447
+ else:
448
+ print(f'@@@@@ Skip random aug during training', flush=True)
449
+ t += [
450
+ transforms.ToTensor(),
451
+ transforms.Normalize(mean=mean, std=std),
452
+ ]
453
+ t = transforms.Compose(t)
454
+ else:
455
+ t = transforms.Compose([
456
+ transforms.Resize((input_size, input_size),
457
+ interpolation=transforms.InterpolationMode.BICUBIC),
458
+ transforms.ToTensor(),
459
+ transforms.Normalize(mean=mean, std=std)
460
+ ])
461
+
462
+ return t
463
+
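A sketch of the evaluation-time pipeline that build_transform returns; the 448-pixel input size, CLIP statistics, and image path are example values rather than settings taken from this repository.

from PIL import Image
from omnilmm.model.utils import build_transform

transform = build_transform(is_train=False, input_size=448,
                            std_mode='OPENAI_CLIP')
img = Image.open("example.jpg").convert("RGB")   # placeholder path
tensor = transform(img)                          # float tensor of shape (3, 448, 448)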
464
+
465
+ def img2b64(img_path):
466
+ img = Image.open(img_path) # path to file
467
+ img_buffer = BytesIO()
468
+ img.save(img_buffer, format=img.format)
469
+ byte_data = img_buffer.getvalue()
470
+ base64_str = base64.b64encode(byte_data) # bytes
471
+ base64_str = base64_str.decode("utf-8") # str
472
+ return base64_str
473
+
474
+
475
+ def str2b64(str):
476
+ return base64.b64encode(str.encode('utf-8')).decode('utf-8')
477
+
478
+
479
+ def b642str(b64):
480
+ return base64.b64decode(b64).decode('utf-8')
481
+
482
+
483
+ def is_dist_avail_and_initialized():
484
+ if not dist.is_available():
485
+ return False
486
+ if not dist.is_initialized():
487
+ return False
488
+ return True
489
+
490
+
491
+ def get_world_size():
492
+ if not is_dist_avail_and_initialized():
493
+ return 1
494
+ return dist.get_world_size()
495
+
496
+
497
+ def get_rank():
498
+ if not is_dist_avail_and_initialized():
499
+ return 0
500
+ return dist.get_rank()
501
+
502
+
503
+ def all_gather(data):
504
+ """
505
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
506
+ Args:
507
+ data: any picklable object
508
+ Returns:
509
+ list[data]: list of data gathered from each rank
510
+ """
511
+ world_size = get_world_size()
512
+ if world_size == 1:
513
+ return [data]
514
+
515
+ # serialized to a Tensor
516
+ buffer = pickle.dumps(data)
517
+ storage = torch.ByteStorage.from_buffer(buffer)
518
+ tensor = torch.ByteTensor(storage).to("cuda")
519
+
520
+ # obtain Tensor size of each rank
521
+ local_size = torch.LongTensor([tensor.numel()]).to("cuda")
522
+ size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
523
+ dist.all_gather(size_list, local_size)
524
+ size_list = [int(size.item()) for size in size_list]
525
+ max_size = max(size_list)
526
+
527
+ # receiving Tensor from all ranks
528
+ # we pad the tensor because torch all_gather does not support
529
+ # gathering tensors of different shapes
530
+ tensor_list = []
531
+ for _ in size_list:
532
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
533
+ if local_size != max_size:
534
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
535
+ tensor = torch.cat((tensor, padding), dim=0)
536
+ dist.all_gather(tensor_list, tensor)
537
+
538
+ data_list = []
539
+ for size, tensor in zip(size_list, tensor_list):
540
+ buffer = tensor.cpu().numpy().tobytes()[:size]
541
+ data_list.append(pickle.loads(buffer))
542
+
543
+ return data_list
544
+
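A hedged sketch of all_gather inside a torch.distributed job; it assumes the process group was already initialized with a CUDA-capable backend such as NCCL, which the snippet itself does not set up. With a single process it simply returns a one-element list.

from omnilmm.model.utils import all_gather, get_rank, get_world_size

# Assumes dist.init_process_group(...) has already run and a GPU is available.
local_stats = {"rank": get_rank(), "num_samples": 123}
gathered = all_gather(local_stats)               # one entry per rank
if get_rank() == 0:
    total = sum(item["num_samples"] for item in gathered)
    print(f"samples across {get_world_size()} ranks: {total}")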
545
+
546
+ def mean(lst):
547
+ return sum(lst) / len(lst)
548
+
549
+
550
+ def stop_gradient_by_name(name: str):
551
+ def apply_fn(module):
552
+ if hasattr(module, name):
553
+ getattr(module, name).requires_grad_(False)
554
+
555
+ return apply_fn
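A small sketch of stop_gradient_by_name; the toy module and the attribute name 'vision_tower' are hypothetical and only illustrate the pattern of freezing a named attribute via Module.apply.

import torch
import torch.nn as nn
from omnilmm.model.utils import stop_gradient_by_name

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Parameter(torch.randn(4))   # hypothetical attribute
        self.head = nn.Linear(4, 2)

model = Toy()
model.apply(stop_gradient_by_name('vision_tower'))   # freezes the matching attribute
print(model.vision_tower.requires_grad)               # False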
omnilmm/train/train_utils.py ADDED
@@ -0,0 +1,153 @@
1
+ import os
2
+ import gc
3
+ import copy
4
+ import time
5
+
6
+ import torch
7
+ import warnings
8
+ import transformers
9
+
10
+ import numpy as np
11
+
12
+ from typing import Dict, Optional, Sequence
13
+ from omnilmm import conversation as conversation_lib
14
+
15
+ IGNORE_INDEX = -100
16
+ DEFAULT_IMAGE_TOKEN = "<image>"
17
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
18
+ DEFAULT_IM_START_TOKEN = "<im_start>"
19
+ DEFAULT_IM_END_TOKEN = "<im_end>"
20
+
21
+
22
+ def _tokenize_fn(strings: Sequence[str],
23
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
24
+ """Tokenize a list of strings."""
25
+ tokenized_list = [
26
+ tokenizer(
27
+ text,
28
+ return_tensors="pt",
29
+ padding="longest",
30
+ max_length=tokenizer.model_max_length,
31
+ truncation=True,
32
+ ) for text in strings
33
+ ]
34
+ input_ids = labels = [
35
+ tokenized.input_ids[0] for tokenized in tokenized_list
36
+ ]
37
+ input_ids_lens = labels_lens = [
38
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
39
+ for tokenized in tokenized_list
40
+ ]
41
+ return dict(
42
+ input_ids=input_ids,
43
+ labels=labels,
44
+ input_ids_lens=input_ids_lens,
45
+ labels_lens=labels_lens,
46
+ )
47
+
48
+
49
+
50
+ def omni_preprocess(sources,
51
+ tokenizer: transformers.PreTrainedTokenizer,
52
+ generation=False):
53
+ system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
54
+ ignore_index = -100
55
+
56
+ response_template = '\n<|assistant|>\n'
57
+ instruction_template = '\n<|user|>\n'
58
+ response_token_ids = tokenizer.encode(
59
+ response_template, add_special_tokens=False)
60
+ instruction_token_ids = tokenizer.encode(
61
+ instruction_template, add_special_tokens=False)
62
+
63
+ batch_input_ids = []
64
+ batch_labels = []
65
+ for i in range(len(sources)):
66
+ new_source = []
67
+ prev_role = 'unexpect'
68
+ for conv_turn in sources[i]:
69
+ role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
70
+ content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']
71
+
72
+ role = 'user' if role == 'human' else role
73
+ role = 'assistant' if role == 'gpt' else role
74
+
75
+ assert role in ['user', 'assistant']
76
+ assert role != prev_role, f'role={role}, prev_role={prev_role}'
77
+ prev_role = role
78
+
79
+ new_turn = {
80
+ 'role': role,
81
+ 'content': content
82
+ }
83
+ new_source.append(new_turn)
84
+ if new_source[0]['role'] != 'system':
85
+ new_source.insert(0, {'role': 'system', 'content': system_content})
86
+
87
+ # TODO: this automatically add '\n' to the end
88
+ res_text = tokenizer.apply_chat_template(
89
+ new_source, tokenize=False, add_generation_prompt=generation)
90
+ if not generation:
91
+ res_text = res_text.strip()
92
+
93
+ conversations_tokenized = _tokenize_fn([res_text], tokenizer)
94
+ res_input_ids = conversations_tokenized["input_ids"][0]
95
+
96
+ # since labels and input_ids are reference towards the same object
97
+ res_labels = copy.deepcopy(conversations_tokenized["labels"][0])
98
+
99
+ response_token_ids_idxs = []
100
+ human_token_ids_idxs = []
101
+
102
+ for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
103
+ # find the indexes of the start of a response.
104
+ if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
105
+ response_token_ids)].tolist()
106
+ ):
107
+ response_token_ids_idxs.append(
108
+ assistant_idx + len(response_token_ids))
109
+
110
+ if len(response_token_ids_idxs) == 0:
111
+ warnings.warn(
112
+ f"Could not find response key `{response_template}` in the "
113
+ f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
114
+ f'Raw text is @===>{res_text}<===@'
115
+ f'Raw source is @===>{new_source}<===@'
116
+ f"This instance will be ignored in loss calculation. "
117
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
118
+ )
119
+ res_labels[:] = ignore_index
120
+
121
+ human_token_ids = instruction_token_ids
122
+ for human_idx in np.where(res_labels == human_token_ids[0])[0]:
123
+ # find the indexes of the start of a human answer.
124
+ if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
125
+ human_token_ids_idxs.append(human_idx)
126
+
127
+ if len(human_token_ids_idxs) == 0:
128
+ warnings.warn(
129
+ f"Could not find instruction key `{instruction_template}` in the "
130
+ f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
131
+ f'Raw text is @===>{res_text}<===@'
132
+ f'Raw source is @===>{new_source}<===@'
133
+ f"This instance will be ignored in loss calculation. "
134
+ f"Note, if this happens often, consider increasing the `max_seq_length`."
135
+ )
136
+ res_labels[:] = ignore_index
137
+
138
+ for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
139
+ # Make pytorch loss function ignore all non response tokens
140
+ if idx != 0:
141
+ res_labels[start:end] = ignore_index
142
+ else:
143
+ res_labels[:end] = ignore_index
144
+
145
+ if len(response_token_ids_idxs) < len(human_token_ids_idxs):
146
+ res_labels[human_token_ids_idxs[-1]:] = ignore_index
147
+
148
+ batch_input_ids.append(res_input_ids)
149
+ batch_labels.append(res_labels)
150
+
151
+ return dict(input_ids=batch_input_ids, labels=batch_labels)
152
+
153
+
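A hedged sketch of running one conversation through omni_preprocess; it assumes a tokenizer whose chat template renders the '\n<|user|>\n' and '\n<|assistant|>\n' markers this function searches for (a Zephyr-style template), and the checkpoint name is only a placeholder.

from transformers import AutoTokenizer
from omnilmm.train.train_utils import omni_preprocess

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # placeholder

sources = [[
    {"from": "human", "value": "<image>\nWhat is in this picture?"},
    {"from": "gpt", "value": "A dog running on a beach."},
]]
batch = omni_preprocess(sources, tokenizer, generation=False)
input_ids = batch["input_ids"][0]   # 1-D token id tensor
labels = batch["labels"][0]         # same length, -100 everywhere except assistant spans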
omnilmm/utils.py ADDED
@@ -0,0 +1,127 @@
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ from omnilmm.constants import LOGDIR
10
+
11
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13
+
14
+ handler = None
15
+
16
+
17
+ def build_logger(logger_name, logger_filename):
18
+ global handler
19
+
20
+ formatter = logging.Formatter(
21
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22
+ datefmt="%Y-%m-%d %H:%M:%S",
23
+ )
24
+
25
+ # Set the format of root handlers
26
+ if not logging.getLogger().handlers:
27
+ logging.basicConfig(level=logging.INFO)
28
+ logging.getLogger().handlers[0].setFormatter(formatter)
29
+
30
+ # Redirect stdout and stderr to loggers
31
+ stdout_logger = logging.getLogger("stdout")
32
+ stdout_logger.setLevel(logging.INFO)
33
+ sl = StreamToLogger(stdout_logger, logging.INFO)
34
+ sys.stdout = sl
35
+
36
+ stderr_logger = logging.getLogger("stderr")
37
+ stderr_logger.setLevel(logging.ERROR)
38
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
39
+ sys.stderr = sl
40
+
41
+ # Get logger
42
+ logger = logging.getLogger(logger_name)
43
+ logger.setLevel(logging.INFO)
44
+
45
+ # Add a file handler for all loggers
46
+ if handler is None:
47
+ os.makedirs(LOGDIR, exist_ok=True)
48
+ filename = os.path.join(LOGDIR, logger_filename)
49
+ handler = logging.handlers.TimedRotatingFileHandler(
50
+ filename, when='D', utc=True)
51
+ handler.setFormatter(formatter)
52
+
53
+ for name, item in logging.root.manager.loggerDict.items():
54
+ if isinstance(item, logging.Logger):
55
+ item.addHandler(handler)
56
+
57
+ return logger
58
+
59
+
60
+ class StreamToLogger(object):
61
+ """
62
+ Fake file-like stream object that redirects writes to a logger instance.
63
+ """
64
+
65
+ def __init__(self, logger, log_level=logging.INFO):
66
+ self.terminal = sys.stdout
67
+ self.logger = logger
68
+ self.log_level = log_level
69
+ self.linebuf = ''
70
+
71
+ def __getattr__(self, attr):
72
+ return getattr(self.terminal, attr)
73
+
74
+ def write(self, buf):
75
+ temp_linebuf = self.linebuf + buf
76
+ self.linebuf = ''
77
+ for line in temp_linebuf.splitlines(True):
78
+ # From the io.TextIOWrapper docs:
79
+ # On output, if newline is None, any '\n' characters written
80
+ # are translated to the system default line separator.
81
+ # By default sys.stdout.write() expects '\n' newlines and then
82
+ # translates them so this is still cross platform.
83
+ if line[-1] == '\n':
84
+ self.logger.log(self.log_level, line.rstrip())
85
+ else:
86
+ self.linebuf += line
87
+
88
+ def flush(self):
89
+ if self.linebuf != '':
90
+ self.logger.log(self.log_level, self.linebuf.rstrip())
91
+ self.linebuf = ''
92
+
93
+
94
+ def disable_torch_init():
95
+ """
96
+ Disable the redundant torch default initialization to accelerate model creation.
97
+ """
98
+ import torch
99
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
100
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
101
+
102
+
103
+ def violates_moderation(text):
104
+ """
105
+ Check whether the text violates OpenAI moderation API.
106
+ """
107
+ url = "https://api.openai.com/v1/moderations"
108
+ headers = {"Content-Type": "application/json",
109
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
110
+ text = text.replace("\n", "")
111
+ data = "{" + '"input": ' + f'"{text}"' + "}"
112
+ data = data.encode("utf-8")
113
+ try:
114
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
115
+ flagged = ret.json()["results"][0]["flagged"]
116
+ except requests.exceptions.RequestException as e:
117
+ flagged = False
118
+ except KeyError as e:
119
+ flagged = False
120
+
121
+ return flagged
122
+
123
+
124
+ def pretty_print_semaphore(semaphore):
125
+ if semaphore is None:
126
+ return "None"
127
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
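A brief sketch of the logging and init helpers; the logger and file names are placeholders, and LOGDIR comes from omnilmm.constants.

from omnilmm.utils import build_logger, disable_torch_init

# Skip torch's default weight initialization to speed up loading large checkpoints.
disable_torch_init()

# Mirrors stdout/stderr into loggers and writes everything to LOGDIR/demo.log.
logger = build_logger("demo", "demo.log")
logger.info("server starting")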
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ Pillow==10.1.0
2
+ torch==2.1.2
3
+ torchvision==0.16.2
4
+ transformers==4.40.0
5
+ sentencepiece==0.1.99
6
+ opencv-python
7
+ gradio