abreza committed
Commit
ab8caaa
1 Parent(s): e9b58ad

add mgie-llava

app.py CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
 
 from launch.image_generation import image_generation_ui
 from launch.model_generation import model_generation_ui
-from launch.story_generator import story_generation_ui
+from launch.story_generation import story_generation_ui
+from launch.image_edition import image_edition_ui
 
 
 with gr.Blocks() as demo:
@@ -10,10 +11,13 @@ with gr.Blocks() as demo:
     with gr.Tab("Generate Story"):
         story_generation_ui()
 
-    with gr.Tab("Generate Image and Remove Background"):
-        input_image, processed_image = image_generation_ui()
+    with gr.Tab("2D Character and Assets"):
+        with gr.Tab("Edit Image"):
+            image_edition_ui()
+        with gr.Tab("Generate Image and Remove Background"):
+            input_image, processed_image = image_generation_ui()
 
-    with gr.Tab("Generate 3D Model"):
+    with gr.Tab("3D Model"):
         output_model_obj, output_model_glb = model_generation_ui(
             processed_image)
 
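For reference, a minimal, self-contained sketch of the tab layout the new app.py builds. The stub functions below are placeholders standing in for the real launch.* builders (which also load models and wire up events); only the nesting of gr.Tab blocks is taken from the diff above.

import gradio as gr

def story_generation_ui():
    gr.Markdown("story generator placeholder")

def image_edition_ui():
    gr.Markdown("MGIE image-editing placeholder")

def image_generation_ui():
    # the real module returns (input_image, processed_image) components
    return gr.Image(label="Input"), gr.Image(label="Processed")

def model_generation_ui(processed_image):
    return gr.Model3D(label="OBJ"), gr.Model3D(label="GLB")

with gr.Blocks() as demo:
    with gr.Tab("Generate Story"):
        story_generation_ui()
    with gr.Tab("2D Character and Assets"):
        # nested gr.Tab blocks render as a tab group inside the parent tab
        with gr.Tab("Edit Image"):
            image_edition_ui()
        with gr.Tab("Generate Image and Remove Background"):
            input_image, processed_image = image_generation_ui()
    with gr.Tab("3D Model"):
        output_model_obj, output_model_glb = model_generation_ui(processed_image)

if __name__ == "__main__":
    demo.launch()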
examples/_input/0.jpg ADDED
examples/_input/1.jpg ADDED
examples/_input/10.jpg ADDED
examples/_input/11.jpg ADDED
examples/_input/12.jpg ADDED
examples/_input/13.jpg ADDED
examples/_input/14.jpg ADDED
examples/_input/15.jpg ADDED
examples/_input/16.jpg ADDED
examples/_input/17.jpg ADDED
examples/_input/18.jpg ADDED
examples/_input/19.jpg ADDED
examples/_input/2.jpg ADDED
examples/_input/3.jpg ADDED
examples/_input/4.jpg ADDED
examples/_input/5.jpg ADDED
examples/_input/6.jpg ADDED
examples/_input/7.jpg ADDED
examples/_input/8.jpg ADDED
examples/_input/9.jpg ADDED
launch/image_edition.py ADDED
@@ -0,0 +1,186 @@
+import diffusers
+import transformers
+import gradio as gr
+from ml_mgie.mgie_llava import *
+from ml_mgie.conversation import conv_templates
+import torch as T
+import numpy as np
+from PIL import Image
+import huggingface_hub
+import spaces
+
+# Constants
+DEFAULT_IMAGE_TOKEN = '<image>'
+DEFAULT_IMAGE_PATCH_TOKEN = '<im_patch>'
+DEFAULT_IM_START_TOKEN = '<im_start>'
+DEFAULT_IM_END_TOKEN = '<im_end>'
+PATH_LLAVA = '_ckpt/LLaVA-7B-v1'
+
+# Download the model checkpoint
+huggingface_hub.snapshot_download(
+    repo_id='tsujuifu/ml-mgie', repo_type='model', local_dir='_ckpt', local_dir_use_symlinks=False)
+
+# Load the model and tokenizer
+tokenizer = transformers.AutoTokenizer.from_pretrained(PATH_LLAVA)
+model = LlavaLlamaForCausalLM.from_pretrained(
+    PATH_LLAVA, low_cpu_mem_usage=True, torch_dtype=T.float16, use_cache=True).cuda()
+image_processor = transformers.CLIPImageProcessor.from_pretrained(
+    model.config.mm_vision_tower, torch_dtype=T.float16)
+
+# Configure the tokenizer and model
+tokenizer.padding_side = 'left'
+tokenizer.add_tokens(['[IMG0]', '[IMG1]', '[IMG2]', '[IMG3]',
+                      '[IMG4]', '[IMG5]', '[IMG6]', '[IMG7]'], special_tokens=True)
+model.resize_token_embeddings(len(tokenizer))
+ckpt = T.load('_ckpt/mgie_7b/mllm.pt', map_location='cpu')
+model.load_state_dict(ckpt, strict=False)
+
+mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end', False)
+tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+if mm_use_im_start_end:
+    tokenizer.add_tokens(
+        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+
+vision_tower = model.get_model().vision_tower[0]
+vision_tower = transformers.CLIPVisionModel.from_pretrained(
+    vision_tower.config._name_or_path, torch_dtype=T.float16, low_cpu_mem_usage=True).cuda()
+model.get_model().vision_tower[0] = vision_tower
+vision_config = vision_tower.config
+vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
+    [DEFAULT_IMAGE_PATCH_TOKEN])[0]
+vision_config.use_im_start_end = mm_use_im_start_end
+if mm_use_im_start_end:
+    vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
+        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
+image_token_len = (vision_config.image_size//vision_config.patch_size)**2
+
+_ = model.eval()
+
+# Load the diffusion pipeline
+pipe = diffusers.StableDiffusionInstructPix2PixPipeline.from_pretrained(
+    'timbrooks/instruct-pix2pix', torch_dtype=T.float16).to('cuda')
+pipe.set_progress_bar_config(disable=True)
+pipe.unet.load_state_dict(T.load('_ckpt/mgie_7b/unet.pt', map_location='cpu'))
+print('--init MGIE--')
+
+
+def crop_resize(f, sz=512):
+    w, h = f.size
+    if w > h:
+        p = (w-h)//2
+        f = f.crop([p, 0, p+h, h])
+    elif h > w:
+        p = (h-w)//2
+        f = f.crop([0, p, w, p+w])
+    f = f.resize([sz, sz])
+    return f
+
+
+def remove_alter(s):
+    if 'ASSISTANT:' in s:
+        s = s[s.index('ASSISTANT:')+10:].strip()
+    if '</s>' in s:
+        s = s[:s.index('</s>')].strip()
+    if 'alternative' in s.lower():
+        s = s[:s.lower().index('alternative')]
+    if '[IMG0]' in s:
+        s = s[:s.index('[IMG0]')]
+    s = '.'.join([s.strip() for s in s.split('.')[:2]])
+    if s[-1] != '.':
+        s += '.'
+    return s.strip()
+
+
+# Main MGIE function
+@spaces.GPU(enable_queue=True)
+def go_mgie(img, txt, seed, cfg_txt, cfg_img):
+    EMB = ckpt['emb'].cuda()
+    with T.inference_mode():
+        NULL = model.edit_head(T.zeros(1, 8, 4096).half().to('cuda'), EMB)
+
+    img, seed = crop_resize(Image.fromarray(img).convert('RGB')), int(seed)
+    inp = img
+
+    img = image_processor.preprocess(img, return_tensors='pt')[
+        'pixel_values'][0]
+    txt = "what will this image be like if '%s'" % (txt)
+    txt = txt+'\n'+DEFAULT_IM_START_TOKEN + \
+        DEFAULT_IMAGE_PATCH_TOKEN*image_token_len+DEFAULT_IM_END_TOKEN
+    conv = conv_templates['vicuna_v1_1'].copy()
+    conv.append_message(conv.roles[0], txt), conv.append_message(
+        conv.roles[1], None)
+    txt = conv.get_prompt()
+    txt = tokenizer(txt)
+    txt, mask = T.as_tensor(txt['input_ids']), T.as_tensor(
+        txt['attention_mask'])
+
+    with T.inference_mode():
+        _ = model.cuda()
+        out = model.generate(txt.unsqueeze(dim=0).cuda(), images=img.half().unsqueeze(dim=0).cuda(), attention_mask=mask.unsqueeze(dim=0).cuda(),
+                             do_sample=False, max_new_tokens=96, num_beams=1, no_repeat_ngram_size=3,
+                             return_dict_in_generate=True, output_hidden_states=True)
+        out, hid = out['sequences'][0].tolist(), T.cat(
+            [x[-1] for x in out['hidden_states']], dim=1)[0]
+
+        if 32003 in out:
+            p = out.index(32003)-1
+        else:
+            p = len(hid)-9
+        p = min(p, len(hid)-9)
+        hid = hid[p:p+8]
+
+        out = remove_alter(tokenizer.decode(out))
+        _ = model.cuda()
+        emb = model.edit_head(hid.unsqueeze(dim=0), EMB)
+        res = pipe(image=inp, prompt_embeds=emb, negative_prompt_embeds=NULL,
+                   generator=T.Generator(device='cuda').manual_seed(seed), guidance_scale=cfg_txt, image_guidance_scale=cfg_img).images[0]
+
+    return res, out
+
+
+# Example function
+def go_example(seed, cfg_txt, cfg_img):
+    ins = ['make the frame red', 'turn the day into night', 'give him a beard', 'make cottage a mansion',
+           'remove yellow object from dogs paws', 'change the hair from red to blue', 'remove the text', 'increase the image contrast',
+           'remove the people in the background', 'please make this photo professional looking', 'darken the image, sharpen it', 'photoshop the girl out',
+           'make more brightness', 'take away the brown filter form the image', 'add more contrast to simulate more light', 'dark on rgb',
+           'make the face happy', 'change view as ocean', 'replace basketball with soccer ball', 'let the floor be made of wood']
+    i = T.randint(len(ins), (1, )).item()
+
+    return './examples/_input/%d.jpg' % (i), ins[i], seed, cfg_txt, cfg_img
+
+
+# Test MGIE
+go_mgie(np.array(Image.open('./examples/_input/0.jpg').convert('RGB')),
+        'make the frame red', 13331, 7.5, 1.5)
+print('--init GO--')
+
+
+def image_edition_ui():
+    with gr.Row():
+        inp, res = [gr.Image(height=384, width=384, label='Input Image', interactive=True),
+                    gr.Image(height=384, width=384, label='Goal Image', interactive=True)]
+    with gr.Row():
+        txt, out = [gr.Textbox(label='Instruction', interactive=True),
+                    gr.Textbox(label='Expressive Instruction', interactive=False)]
+    with gr.Row():
+        seed, cfg_txt, cfg_img = [gr.Number(value=13331, label='Seed', interactive=True),
+                                  gr.Number(value=7.5, label='Text CFG', interactive=True),
+                                  gr.Number(value=1.5, label='Image CFG', interactive=True)]
+    with gr.Row():
+        btn_exp, btn_sub = [gr.Button('More Example'), gr.Button('Submit')]
+    btn_exp.click(fn=go_example, inputs=[seed, cfg_txt, cfg_img], outputs=[
+                  inp, txt, seed, cfg_txt, cfg_img])
+    btn_sub.click(fn=go_mgie, inputs=[
+                  inp, txt, seed, cfg_txt, cfg_img], outputs=[res, out])
+
+    ins = ['make the frame red', 'turn the day into night', 'give him a beard', 'make cottage a mansion',
+           'remove yellow object from dogs paws', 'change the hair from red to blue', 'remove the text', 'increase the image contrast',
+           'remove the people in the background', 'please make this photo professional looking', 'darken the image, sharpen it', 'photoshop the girl out',
+           'make more brightness', 'take away the brown filter form the image', 'add more contrast to simulate more light', 'dark on rgb',
+           'make the face happy', 'change view as ocean', 'replace basketball with soccer ball', 'let the floor be made of wood']
+    gr.Examples(examples=[['./examples/_input/%d.jpg' % (i), ins[i]]
+                          for i in [1, 5, 8, 14, 16]], inputs=[inp, txt])
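Note that this module does all of its heavy lifting (checkpoint download, LLaVA and InstructPix2Pix loading, and one warm-up edit) at import time. A hypothetical standalone harness for testing just this tab, assuming the _ckpt download can complete and a CUDA GPU is available, would be:

import gradio as gr
from launch.image_edition import image_edition_ui  # importing triggers model setup

with gr.Blocks() as demo:
    with gr.Tab("Edit Image"):
        image_edition_ui()

demo.queue().launch()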
launch/model_generation.py CHANGED
@@ -143,30 +143,55 @@ def model_generation_ui(processed_image):
     with gr.Row():
         submit_mesh = gr.Button(
             "Generate 3D Model", elem_id="generate", variant="primary")
+
     with gr.Row():
         with gr.Column():
             mv_show_images = gr.Image(
-                label="Generated Multi-views",
-                type="pil",
-                interactive=False
-            )
+                label="Generated Multi-views", type="pil", interactive=False)
+
         with gr.Column():
             with gr.Tab("OBJ"):
                 output_model_obj = gr.Model3D(
-                    label="Output Model (OBJ Format)",
-                    interactive=False,
-                )
+                    label="Output Model (OBJ Format)", interactive=False)
+
             with gr.Tab("GLB"):
                 output_model_glb = gr.Model3D(
-                    label="Output Model (GLB Format)",
-                    interactive=False,
-                )
-
-    mv_images = gr.State()
-
-    submit_mesh.click(fn=generate_mvs, inputs=[processed_image], outputs=[mv_images, mv_show_images]).success(
-        fn=make3d, inputs=[mv_images], outputs=[
-            output_model_obj, output_model_glb]
-    )
+                    label="Output Model (GLB Format)", interactive=False)
+
+    mv_images = gr.State()
+
+    # Display a message if the processed image is empty
+    empty_image_message = gr.Markdown(
+        visible=False,
+        value="Please generate a 2D image before generating a 3D model."
+    )
+
+    def check_image(processed_image):
+        if processed_image is None:
+            return {
+                empty_image_message: gr.update(visible=True),
+                submit_mesh: gr.update(interactive=False)
+            }
+        else:
+            return {
+                empty_image_message: gr.update(visible=False),
+                submit_mesh: gr.update(interactive=True)
+            }
+
+    processed_image.change(
+        fn=check_image,
+        inputs=[processed_image],
+        outputs=[empty_image_message, submit_mesh]
+    )
+
+    submit_mesh.click(
+        fn=generate_mvs,
+        inputs=[processed_image],
+        outputs=[mv_images, mv_show_images]
+    ).success(
+        fn=make3d,
+        inputs=[mv_images],
+        outputs=[output_model_obj, output_model_glb]
+    )
 
     return output_model_obj, output_model_glb
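The gating added here is the standard Gradio pattern of returning gr.update(...) from a .change() handler. A minimal sketch of just that pattern, isolated from the InstantMesh components (component names here are illustrative, not from the repo):

import gradio as gr

with gr.Blocks() as demo:
    image = gr.Image(label="Processed image", type="pil")
    warning = gr.Markdown("Please generate a 2D image before generating a 3D model.",
                          visible=False)
    submit = gr.Button("Generate 3D Model", interactive=False)

    def toggle(img):
        # show the warning and disable the button whenever no image is present
        missing = img is None
        return gr.update(visible=missing), gr.update(interactive=not missing)

    image.change(fn=toggle, inputs=[image], outputs=[warning, submit])

demo.launch()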
launch/{story_generator.py → story_generation.py} RENAMED
File without changes
ml_mgie/__init__.py ADDED
File without changes
ml_mgie/conversation.py ADDED
@@ -0,0 +1,370 @@
+
+# modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/conversation.py
+
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+    MPT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        if self.sep_style == SeparatorStyle.MPT:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def get_images(self, return_pil=False):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    from PIL import Image
+                    msg, image, image_process_mode = msg
+                    if image_process_mode == "Pad":
+                        def expand2square(pil_img, background_color=(122, 116, 104)):
+                            width, height = pil_img.size
+                            if width == height:
+                                return pil_img
+                            elif width > height:
+                                result = Image.new(pil_img.mode, (width, width), background_color)
+                                result.paste(pil_img, (0, (width - height) // 2))
+                                return result
+                            else:
+                                result = Image.new(pil_img.mode, (height, height), background_color)
+                                result.paste(pil_img, ((height - width) // 2, 0))
+                                return result
+                        image = expand2square(image)
+                    elif image_process_mode == "Crop":
+                        pass
+                    elif image_process_mode == "Resize":
+                        image = image.resize((224, 224))
+                    else:
+                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    if return_pil:
+                        images.append(image)
+                    else:
+                        buffered = BytesIO()
+                        image.save(buffered, format="JPEG")
+                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                        images.append(img_b64_str)
+        return images
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    msg, image, image_process_mode = msg
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    # image = image.resize((224, 224))
+                    buffered = BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = msg.replace('<image>', img_str)
+                ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2)
+
+    def dict(self):
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_v1 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Give three tips for staying healthy."),
+        ("Assistant",
+            "Sure, here are three tips for staying healthy:\n"
+            "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
+            "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
+            "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
+            "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
+            "activities at least two days per week.\n"
+            "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
+            "vegetables, whole grains, lean proteins, and healthy fats can help support "
+            "your overall health. Try to limit your intake of processed and high-sugar foods, "
+            "and aim to drink plenty of water throughout the day.\n"
+            "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
+            "and mental health. Adults should aim for seven to nine hours of sleep per night. "
+            "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
+            "help improve the quality of your sleep.")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_v1_2 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+        ("Assistant",
+            "Renewable energy sources are those that can be replenished naturally in a relatively "
+            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+            "renewable and non-renewable energy sources:\n"
+            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+            "energy sources are finite and will eventually run out.\n"
+            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+            "and other negative effects.\n"
+            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+            "have lower operational costs than non-renewable sources.\n"
+            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+            "locations than non-renewable sources.\n"
+            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_vicuna_v1_1 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_mpt = Conversation(
+    system="""<|im_start|>system
+- You are a helpful language and vision assistant.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+conv_mpt_text = Conversation(
+    system="""<|im_start|>system
+- You are a helpful assistant chatbot trained by MosaicML.
+- You answer questions.
+- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+conv_bair_v1 = Conversation(
+    system="BEGINNING OF CONVERSATION:",
+    roles=("USER", "GPT"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+simple_conv = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Hi!"),
+        ("Assistant", "Hi there! How can I help you today?")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+simple_conv_multimodal = Conversation(
+    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+           "Follow the instructions carefully and explain your answers in detail.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Hi!"),
+        ("Assistant", "Hi there! How can I help you today?\n")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+simple_conv_mpt_multimodal = Conversation(
+    system="""<|im_start|>system
+- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+simple_conv_legacy = Conversation(
+    system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
+           "You are designed to assist human with a variety of tasks using natural language."
+           "Follow the instructions carefully.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Hi!\n\n### Response:"),
+        ("Assistant", "Hi there! How can I help you today?\n")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_llava_v1 = Conversation(
+    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+           "Follow the instructions carefully and explain your answers in detail.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+default_conversation = conv_v1_2
+conv_templates = {
+    "default": conv_v1_2,
+    "simple": simple_conv,
+    "simple_legacy": simple_conv_legacy,
+    "multimodal": simple_conv_multimodal,
+    "mpt_multimodal": simple_conv_mpt_multimodal,
+    "llava_v1": conv_llava_v1,
+
+    # fastchat
+    "v1": conv_v1_2,
+    "bair_v1": conv_bair_v1,
+    "vicuna_v1_1": conv_vicuna_v1_1,
+    "mpt": conv_mpt,
+    "mpt_text": conv_mpt_text,
+}
+
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
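For context, launch/image_edition.py consumes these templates to build the MGIE prompt. A short sketch of that call pattern (the instruction string is just an example; the commented output is approximate):

from ml_mgie.conversation import conv_templates

conv = conv_templates['vicuna_v1_1'].copy()
conv.append_message(conv.roles[0], "what will this image be like if 'make the frame red'")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# roughly: "<system prompt> USER: what will this image be like if 'make the frame red' ASSISTANT:"
print(prompt)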
ml_mgie/mgie_llava.py ADDED
@@ -0,0 +1,407 @@
+#
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All Rights Reserved.
+#
+# modified from https://github.com/haotian-liu/LLaVA/blob/7ace501183c4bdec6052ec1a30039cdc3242a67c/llava/model/llava.py
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+                         LlamaConfig, LlamaModel, LlamaForCausalLM, \
+                         CLIPVisionModel, CLIPImageProcessor
+
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+import os, diffusers
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+
+
+class LlavaConfig(LlamaConfig):
+    model_type = "llava"
+
+
+class LlavaLlamaModel(LlamaModel):
+    config_class = LlavaConfig
+
+    def __init__(self, config: LlamaConfig):
+        super(LlavaLlamaModel, self).__init__(config)
+
+        if hasattr(config, "mm_vision_tower"):
+            # HACK: for FSDP
+            self.vision_tower = [CLIPVisionModel.from_pretrained(config.mm_vision_tower)]
+            # self.vision_tower = CLIPVisionModel.from_pretrained(config.mm_vision_tower)
+
+        if hasattr(config, "use_mm_proj"):
+            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
+
+    def get_vision_tower(self):
+        vision_tower = getattr(self, 'vision_tower', None)
+        if type(vision_tower) is list:
+            vision_tower = vision_tower[0]
+        return vision_tower
+
+    def initialize_vision_modules(self, vision_tower, mm_vision_select_layer,
+                                  pretrain_mm_mlp_adapter=None, fsdp=None):
+        self.config.mm_vision_tower = vision_tower
+
+        image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
+
+        if not hasattr(self, 'vision_tower'):
+            vision_tower = CLIPVisionModel.from_pretrained(vision_tower)
+        else:
+            vision_tower = self.vision_tower[0]
+        vision_tower.requires_grad_(False)
+
+        if fsdp is not None and len(fsdp) > 0:
+            self.vision_tower = [vision_tower]
+        else:
+            self.vision_tower = vision_tower
+
+        vision_config = vision_tower.config
+        num_patches = (vision_config.image_size // vision_config.patch_size) ** 2
+
+        self.config.use_mm_proj = True
+        self.config.mm_hidden_size = vision_config.hidden_size
+        self.config.mm_vision_select_layer = mm_vision_select_layer
+
+        if not hasattr(self, 'mm_projector'):
+            self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size)
+
+        if pretrain_mm_mlp_adapter is not None:
+            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
+            self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()})
+
+        return dict(
+            image_processor=image_processor,
+            image_token_len=num_patches,
+            vision_config=vision_config
+        )
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+        # HACK: replace back original embeddings for LLaVA pretraining
+        orig_embeds_params = getattr(self, 'orig_embeds_params', None)
+        # if orig_embeds_params is not None:
+        #     orig_embeds_params = orig_embeds_params[0]
+        #     with torch.no_grad():
+        #         self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        vision_tower = self.get_vision_tower()
+        if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
+            # TODO: this is a modified multimodal LLM -- Haotian Liu
+            with torch.no_grad():
+                if type(images) is list:
+                    # variable length images
+                    image_features = []
+                    for image in images:
+                        image_forward_out = vision_tower(image.unsqueeze(0), output_hidden_states=True)
+                        select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
+                        select_hidden_state = image_forward_out.hidden_states[select_hidden_state_layer]
+                        image_feature = select_hidden_state[:, 1:]
+                        image_features.append(image_feature)
+                else:
+                    image_forward_outs = vision_tower(images.to(vision_tower.dtype), output_hidden_states=True)
+                    select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1)
+                    select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
+                    image_features = select_hidden_state[:, 1:].to(images.dtype)
+            if type(images) is list:
+                image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features]
+            else:
+                image_features = self.mm_projector(image_features)
+            dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
+            dummy_image_features = self.mm_projector(dummy_image_features)
+
+            new_input_embeds = []
+            cur_image_idx = 0
+            for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
+                if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0:
+                    # multimodal LLM, but the current sample is not multimodal
+                    cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
+                    new_input_embeds.append(cur_input_embeds)
+                    cur_image_idx += 1
+                    continue
+                if vision_tower.config.use_im_start_end:
+                    cur_image_features = image_features[cur_image_idx]
+                    num_patches = cur_image_features.shape[0]
+                    if (cur_input_ids == vision_tower.config.im_start_token).sum() != (cur_input_ids == vision_tower.config.im_end_token).sum():
+                        raise ValueError("The number of image start tokens and image end tokens should be the same.")
+                    image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0]
+                    for image_start_token_pos in image_start_tokens:
+                        cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device)
+                        num_patches = cur_image_features.shape[0]
+                        if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token:
+                            raise ValueError("The image end token should follow the image start token.")
+                        if orig_embeds_params is not None:
+                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
+                        else:
+                            cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
+                        cur_image_idx += 1
+                    new_input_embeds.append(cur_new_input_embeds)
+                else:
+                    cur_image_features = image_features[cur_image_idx]
+                    num_patches = cur_image_features.shape[0]
+                    if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches:
+                        raise ValueError("The number of image patch tokens should be the same as the number of image patches.")
+                    masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0]
+                    mask_index_start = masked_indices[0]
+                    if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
+                        raise ValueError("The image patch tokens should be consecutive.")
+                    if orig_embeds_params is not None:
+                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
+                    else:
+                        cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
+                    new_input_embeds.append(cur_new_input_embeds)
+                    cur_image_idx += 1
+            inputs_embeds = torch.stack(new_input_embeds, dim=0)
+
+        return super(LlavaLlamaModel, self).forward(
+            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds, use_cache=use_cache,
+            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+
+class EditMapper(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.llm2hid = nn.Linear(4096, 512)
+        self.query = nn.Parameter(torch.randn(1, 77, 512))
+        self.mapper = nn.Transformer(batch_first=True, norm_first=True,
+                                     d_model=512, nhead=4, num_encoder_layers=4, num_decoder_layers=4,
+                                     dim_feedforward=2048, dropout=0.0)
+        self.hid2feat = nn.Linear(512, 768)
+
+    def forward(self, llm, emb):
+        hid = self.llm2hid(llm+emb)
+        hid = self.mapper(hid, self.query.repeat(llm.shape[0], 1, 1))
+        feat = self.hid2feat(hid)
+
+        return feat
+
+
+class LlavaLlamaForCausalLM(LlamaForCausalLM):
+    config_class = LlavaConfig
+
+    def __init__(self, config):
+        super(LlamaForCausalLM, self).__init__(config)
+        self.model = LlavaLlamaModel(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.edit_head = EditMapper()
+
+        '''self.scheduler, self.vae, self.unet = [diffusers.DDPMScheduler.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='scheduler'),
+                                               diffusers.AutoencoderKL.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='vae'),
+                                               diffusers.UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder='unet')]
+        self.vae.requires_grad_(False)
+        self.unet.register_to_config(in_channels=8)
+        with torch.no_grad():
+            conv = torch.nn.Conv2d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding)
+            conv.weight.zero_()
+            conv.weight[:, :4, :, :].copy_(self.unet.conv_in.weight)
+            self.unet.conv_in = conv'''
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_model(self):
+        return self.model
+
+    def get_vision_tower(self):
+        return self.get_model().get_vision_tower()
+
+    def get_vision_tower(self):
+        model = self.get_model()
+        vision_tower = model.vision_tower
+        if type(vision_tower) is list:
+            vision_tower = vision_tower[0]
+        return vision_tower
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        images: Optional[torch.FloatTensor] = None,
+        return_dict: Optional[bool] = None,
+        p2p_inp=None, p2p_ans=None
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            images=images
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model/pipeline parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if labels is not None:
+            llm = []
+            for i in range(labels.shape[0]):
+                try: p = labels[i].data.cpu().tolist().index(32003)-1
+                except: p = len(labels[i])-9
+                p = min(len(hidden_states[i])-9, p)
+                llm.append(hidden_states[i][p:p+8].unsqueeze(0))
+            llm = torch.cat(llm, dim=0)
+            hid_edit = self.edit_head(llm, self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
+
+            B, DROP = labels.shape[0], 0.05
+
+            hid_null = self.edit_head(torch.zeros(B, 8, 4096, device=labels.device),
+                                      self.model.embed_tokens.weight[-8:].unsqueeze(dim=0).repeat(labels.shape[0], 1, 1))
+
+            with torch.no_grad():
+                lat_ans, lat_inp = self.vae.encode(p2p_ans).latent_dist.sample()*self.vae.config.scaling_factor, self.vae.encode(p2p_inp).latent_dist.mode()
+                lat_ans, lat_inp = [torch.from_numpy(lat_ans.data.cpu().float().numpy()).to(lat_ans.device),
+                                    torch.from_numpy(lat_inp.data.cpu().float().numpy()).to(lat_inp.device)]
+
+            noise = torch.randn_like(lat_ans)
+            ts = torch.randint(0, self.scheduler.config.num_train_timesteps, (B, ), device=noise.device).long()
+            lat_noise = self.scheduler.add_noise(lat_ans, noise, ts)
+
+            prob = torch.rand(B, device=lat_ans.device)
+            mask = (prob<(DROP*2)).reshape(B, 1, 1)
+            hid_edit = torch.where(mask, hid_null, hid_edit)
+            mask = (1.0-((prob>=DROP).to(lat_inp.dtype)*(prob<(DROP*3)).to(lat_inp.dtype))).reshape(B, 1, 1, 1)
+            lat_inp *= mask
+
+            out = self.unet(torch.cat([lat_noise, lat_inp], dim=1), ts, hid_edit).sample
+
+            loss_ce, loss_edit = loss, nn.functional.mse_loss(out, noise, reduction='mean')
+            if int(os.environ['LOCAL_RANK'])==0: print('loss_ce:', loss_ce, '/', 'loss_edit:', loss_edit)
+            loss = loss_ce+loss_edit*0.5
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "images": kwargs.get("images", None),
+            }
+        )
+        return model_inputs
+
+    def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
+                                    tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None):
+        vision_config = self.get_vision_tower().config
+        vision_config.use_im_start_end = mm_use_im_start_end
+        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+        self.resize_token_embeddings(len(tokenizer))
+
+        if mm_use_im_start_end:
+            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+            self.resize_token_embeddings(len(tokenizer))
+            vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
+
+            if num_new_tokens > 0:
+                input_embeddings = self.get_input_embeddings().weight.data
+                output_embeddings = self.get_output_embeddings().weight.data
+
+                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+                    dim=0, keepdim=True)
+                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+                    dim=0, keepdim=True)
+
+                input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+            if tune_mm_mlp_adapter:
+                self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)]
+                for p in self.get_input_embeddings().parameters():
+                    p.requires_grad = True
+                for p in self.get_output_embeddings().parameters():
+                    p.requires_grad = False
+
+            if pretrain_mm_mlp_adapter:
+                mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
+                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
+                assert num_new_tokens == 2
+                if input_embeddings.shape == embed_tokens_weight.shape:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+                elif embed_tokens_weight.shape[0] == num_new_tokens:
+                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
+                else:
+                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+
+        vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
+
+AutoConfig.register("llava", LlavaConfig)
+AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
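The MGIE-specific addition over upstream LLaVA is the EditMapper head, which turns the eight [IMG0]..[IMG7] hidden states into 77 CLIP-sized embeddings that condition the InstructPix2Pix UNet. A quick, illustrative shape check (not part of the commit; random tensors stand in for real hidden states and embedding rows):

import torch
from ml_mgie.mgie_llava import EditMapper

head = EditMapper()
llm_hidden = torch.randn(2, 8, 4096)   # last-layer states at the [IMG0..IMG7] positions
img_embeds = torch.randn(2, 8, 4096)   # embedding-table rows for the same tokens
out = head(llm_hidden, img_embeds)
print(out.shape)  # torch.Size([2, 77, 768]), i.e. prompt_embeds for the diffusion pipeline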
requirements.txt CHANGED
@@ -27,3 +27,9 @@ flask
 pillow==9.5.0
 safetensors
 peft
+
+
+sentencepiece
+tokenizers==0.12.1
+datasets
+evaluate