Kunpeng Song committed
Commit 73c6f92 • 1 Parent(s): 7d79282
.DS_Store ADDED
Binary file (8.2 kB)
 
README.md CHANGED
@@ -1,12 +1,14 @@
---
- title: MoMA ZeroGPU
- emoji: 📉
- colorFrom: green
- colorTo: red
+ title: MoMA
+ emoji: 🌍
+ colorFrom: yellow
+ colorTo: green
sdk: gradio
- sdk_version: 4.32.1
+ sdk_version: 4.31.4
app_file: app.py
pinned: false
+ license: apache-2.0
+ short_description: Multi-modal LLM for image personalization
---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,50 @@
+ import gradio as gr
+ import cv2
+ import torch
+ import numpy as np
+ from torchvision import transforms
+ from pytorch_lightning import seed_everything
+ from torchvision.utils import save_image
+ from model_lib.modules import MoMA_main_modal
+ from model_lib.utils import parse_args
+ import os
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ title = "MoMA"
+ description = "This model has to run on GPU. By default, we load the model with 4-bit quantization to make it fit on smaller hardware."
+
+ def MoMA_demo(rgb, subject, prompt, strength, seed):
+     # Treat an empty, non-numeric, or zero seed as "pick a random seed".
+     try:
+         seed = int(seed)
+     except (TypeError, ValueError):
+         seed = 0
+     seed = seed if seed != 0 else np.random.randint(0, 1000)
+     print(f"Seed: {seed}")
+
+     with torch.no_grad():
+         generated_image = model.generate_images(rgb, subject, prompt, strength=strength, seed=seed)
+     return generated_image
+
+ def inference(rgb, subject, prompt, strength, seed):
+     result = MoMA_demo(rgb, subject, prompt, strength, seed)
+     return result
+
+ seed_everything(0)
+ args = parse_args()
+ # load MoMA from HuggingFace (auto-download)
+ model = MoMA_main_modal(args).to(args.device, dtype=torch.float16)
+
+ gr.Interface(
+     inference,
+     [gr.Image(type="pil", label="Input RGB"),
+      gr.Textbox(lines=1, label="subject"),
+      gr.Textbox(lines=1, label="Prompt"),
+      gr.Slider(minimum=0.2, maximum=1.2, step=0.1, label="Strength. Recommend: 1.0 for context editing; 0.4 for texture editing", value=1.0),
+      gr.Textbox(lines=1, label="Seed. Use 0 for a random seed")],
+     gr.Image(type="pil", label="Output"),
+     title=title,
+     description=description,
+     examples=[["example_images/newImages/3.jpg", 'car', 'A car in autumn with falling leaves.', 1.0, "6"], ["example_images/newImages/3.jpg", 'car', 'A wooden sculpture of a car on a table.', 0.4, "4"], ["example_images/newImages/2.jpg", 'car', 'A car on a city road with green trees and buildings.', 1.0, "4"], ["example_images/newImages/03.jpg", 'cat', 'A cat at the Grand Canyon.', 1.0, "2"], ["example_images/newImages/02.jpg", 'dog', 'A dog in a spring garden with flowers.', 1.0, "6"], ["example_images/newImages/1.jpeg", 'bird', 'A bird in spring with flowers.', 1.0, "1"], ["example_images/newImages/17.jpg", 'robot', 'A robot in autumn mountain and lake.', 1, "5"]],
+     allow_flagging='never'
+ ).launch(debug=False)
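For readers who want to exercise app.py without the Gradio UI, here is a minimal headless sketch. It reuses the `parse_args`, `MoMA_main_modal`, and `generate_images` interfaces added in this commit; the example image path comes from the examples list above, and the output filename is illustrative only:

    import torch
    from PIL import Image
    from pytorch_lightning import seed_everything
    from model_lib.modules import MoMA_main_modal
    from model_lib.utils import parse_args

    seed_everything(0)
    args = parse_args()
    model = MoMA_main_modal(args).to(args.device, dtype=torch.float16)

    rgb = Image.open("example_images/newImages/3.jpg").convert("RGB")
    with torch.no_grad():
        # Per the UI hint: strength 1.0 for context editing, 0.4 for texture editing.
        out = model.generate_images(rgb, "car", "A car in autumn with falling leaves.", strength=1.0, seed=6)
    out.save("car_autumn.jpg")  # hypothetical output path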
backupApp_version1.py ADDED
@@ -0,0 +1,52 @@
+ import gradio as gr
+ import cv2
+ import torch
+ import numpy as np
+ from torchvision import transforms
+ import torch
+ from pytorch_lightning import seed_everything
+ from torchvision.utils import save_image
+ from model_lib.modules import MoMA_main_modal
+ from model_lib.utils import parse_args
+ import os
+ os.environ["CUDA_VISIBLE_DEVICES"]="0"
+
+ title = "MoMA"
+ description = "This model has to run on GPU"
+ article = "<p style='text-align: center'><a href='https://news.machinelearning.sg/posts/beautiful_profile_pics_remove_background_image_with_deeplabv3/'>Blog</a> | <a href='https://github.com/eugenesiow/practical-ml'>Github Repo</a></p>"
+
+ def MoMA_demo(rgb, mask, subject, prompt):
+     # move the input and model to GPU for speed if available
+     with torch.no_grad():
+         generated_image = model.generate_images(rgb, mask, subject, prompt, strength=1.0, seed=2)
+     return generated_image
+
+ def inference(rgb, mask, subject, prompt):
+     result = MoMA_demo(rgb, mask, subject, prompt)
+     return result
+
+ seed_everything(0)
+ args = parse_args()
+ #load MoMA from HuggingFace. Auto download
+ model = MoMA_main_modal(args).to(args.device, dtype=torch.float16)
+
+
+ ################ change texture ##################
+ # prompt = "A wooden sculpture of a car on the table."
+ # generated_image = model.generate_images(rgb_path, mask_path, subject, prompt, strength=0.4, seed=4, return_mask=True) # set strength to 0.4 for better prompt fidelity
+ # save_image(generated_image,f"{args.output_path}/{subject}_{prompt}.jpg")
+
+
+ gr.Interface(
+     inference,
+     [gr.Image(type="pil", label="Input RGB"),
+      gr.Image(type="pil", label="Input Mask"),
+      gr.Textbox(lines=1, label="subject"),
+      gr.Textbox(lines=5, label="Prompt")],
+     gr.Image(type="pil", label="Output"),
+     title=title,
+     description=description,
+     article=article,
+     examples=[["example_images/newImages/3.jpg",'example_images/newImages/3_mask.jpg','car','A car in autumn with falling leaves.']],
+     # enable_queue=True
+ ).launch(debug=False)
checkpoints/.DS_Store ADDED
Binary file (6.15 kB)
 
checkpoints/attn_adapters_projectors.th ADDED
@@ -0,0 +1 @@
+ ../../../../../../../home/ks1418/.cache/huggingface/hub/models--KunpengSong--MoMA_llava_7b/blobs/0b432a39e46f01cd9cdb4794b8ef13b9bb0aff2ad6da6800d67fd2ca4af21fa6
checkpoints/ckpt_saving_path.txt ADDED
File without changes
dataset_lib/__pycache__/dataset_eval_MoMA.cpython-310.pyc ADDED
Binary file (1.76 kB)
 
dataset_lib/dataset_eval_MoMA.py ADDED
@@ -0,0 +1,44 @@
+ from PIL import Image
+ import numpy as np
+ import torch
+ from torchvision import transforms
+ from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+ from rembg import remove
+
+ def create_binary_mask(image):
+     grayscale = image.convert("L")
+     mask = grayscale.point(lambda x: 255 if x > 1 else 0, '1')
+     return mask
+
+ def Dataset_evaluate_MoMA(image_pil, prompt, subject, moMA_main_modal):
+
+     LLaVa_processor = moMA_main_modal.image_processor_llava
+     llava_config = moMA_main_modal.model_llava.config
+
+     transform = transforms.Compose([
+         transforms.Resize((512, 512)),
+     ])
+
+     mask_pil = create_binary_mask(remove(image_pil)) # Image.open(mask_path)
+     blip2_opt = prompt
+
+     if transform is not None:
+         image_pil = transform(image_pil)
+         mask_pil = transform(mask_pil)
+
+     mask_pil = np.array(mask_pil)
+     mask_pil = mask_pil[:,:,0] if len(mask_pil.shape)==3 else mask_pil
+     image = torch.from_numpy(np.array(image_pil)).permute(2,0,1)
+     mask = (torch.clamp((torch.from_numpy(mask_pil).unsqueeze(0)).float(),min=0.0,max=1.0)>0).float()
+
+     res = {'image': (image/127.5-1).unsqueeze(0),\
+            'mask': mask.unsqueeze(0), \
+            'text': [blip2_opt]}
+
+     image_wb = image * mask + torch.ones_like(image)* (1-mask)*255
+     image_pil = Image.fromarray(image_wb.permute(1,2,0).numpy().astype(np.uint8))
+
+     res['llava_processed'] = process_images([image_pil], LLaVa_processor, llava_config)
+     res['label'] = [subject]
+     return res
+
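The loader above resizes inputs to 512x512, derives a foreground mask with rembg, maps pixel values from [0, 255] to [-1, 1] via x/127.5 - 1, and composites the subject onto a white background for the LLaVA branch. A small self-contained sketch of those two tensor operations on dummy data (no model dependencies):

    import torch

    image = torch.randint(0, 256, (3, 4, 4)).float()       # dummy RGB image in [0, 255]
    mask = (torch.rand(1, 4, 4) > 0.5).float()              # dummy binary foreground mask

    normalized = image / 127.5 - 1                          # [0, 255] -> [-1, 1], as in res['image'] above
    assert normalized.min().item() >= -1.0 and normalized.max().item() <= 1.0

    # White-background composite fed to LLaVA: keep subject pixels, fill background with 255.
    image_wb = image * mask + torch.ones_like(image) * (1 - mask) * 255
    print(normalized.shape, image_wb.shape)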
example_images/newImages/.DS_Store ADDED
Binary file (6.15 kB)
 
example_images/newImages/02.jpg ADDED
example_images/newImages/03.jpg ADDED
example_images/newImages/1.jpeg ADDED
example_images/newImages/17.jpg ADDED
example_images/newImages/2.jpg ADDED
example_images/newImages/3.jpg ADDED
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
+ Input RGB,subject,Prompt,Strength. Recommend: 1.0 for context editing; 0.4 for texture editing,Seed. Use 0 for a random seed,Output,flag,username,timestamp
+ ,,,1,,,,,2024-05-21 19:36:27.802622
model_lib/__init__.py ADDED
File without changes
model_lib/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (191 Bytes)

model_lib/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (189 Bytes)

model_lib/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (7.06 kB)

model_lib/__pycache__/moMA_generator.cpython-310.pyc ADDED
Binary file (10 kB)

model_lib/__pycache__/moMA_generator.cpython-39.pyc ADDED
Binary file (10 kB)

model_lib/__pycache__/modules.cpython-310.pyc ADDED
Binary file (6.96 kB)

model_lib/__pycache__/modules.cpython-39.pyc ADDED
Binary file (6.94 kB)

model_lib/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.45 kB)
model_lib/attention_processor.py ADDED
@@ -0,0 +1,245 @@
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ import math
+ from torchvision.utils import save_image
+ import torchvision.transforms as T
+
+ def get_mask_from_cross(attn_processors):
+     reference_masks = []
+     for attn_processor in attn_processors.values():
+         if isinstance(attn_processor, IPAttnProcessor):
+             reference_masks.append(attn_processor.mask_i)
+     mask = torch.cat(reference_masks,dim=1).mean(dim=1)
+     mask = (mask-mask.min())/(mask.max()-mask.min())
+     mask = (mask>0.2).to(torch.float32)*mask
+     mask = (mask-mask.min())/(mask.max()-mask.min())
+     return mask.unsqueeze(1)
+
+ class IPAttnProcessor(nn.Module):
+     r"""
+     Attention processor for IP-Adapter.
+     Args:
+         hidden_size (`int`):
+             The hidden size of the attention layer.
+         cross_attention_dim (`int`):
+             The number of channels in the `encoder_hidden_states`.
+         scale (`float`, defaults to 1.0):
+             the weight scale of image prompt.
+         num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+             The context length of the image features.
+     """
+
+     def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+         super().__init__()
+
+         self.hidden_size = hidden_size
+         self.cross_attention_dim = cross_attention_dim
+         self.scale = scale
+         self.num_tokens = num_tokens
+
+         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+         self.store_attn = None
+         self.enabled = True
+         self.mode = 'inject'
+
+         self.subject_idxs = None
+         self.mask_i = None
+         self.mask_ig_prev = None
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         residual = hidden_states
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         else:
+             # get encoder_hidden_states, ip_hidden_states
+             end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+             encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
+             if attn.norm_cross:
+                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # for ip-adapter
+         if self.enabled:
+             if self.mode == 'inject' or self.mode == 'masked_generation':
+                 ip_key = self.to_k_ip(ip_hidden_states.to(torch.float16))
+                 ip_value = self.to_v_ip(ip_hidden_states.to(torch.float16))
+                 ip_key = attn.head_to_batch_dim(ip_key)
+                 ip_value = attn.head_to_batch_dim(ip_value)
+                 ip_attention_probs = attn.get_attention_scores(query, ip_key.to(torch.float32), None)
+                 ip_hidden_states = torch.bmm(ip_attention_probs, ip_value.to(torch.float32))
+                 ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+                 if (self.mask_ig_prev is not None) and self.mode == 'masked_generation':
+                     mask_ig_prev = rearrange(F.interpolate(self.mask_ig_prev,size=int(math.sqrt(query.shape[1]))),"b c h w -> b (h w) c")
+                     if not mask_ig_prev.shape[0]==ip_hidden_states.shape[0]: mask_ig_prev = mask_ig_prev.repeat(2,1,1)
+                     ip_hidden_states = ip_hidden_states * mask_ig_prev
+                 hidden_states = hidden_states + self.scale * ip_hidden_states
+             if self.mode == 'extract' or self.mode == 'masked_generation':
+                 subject_idxs = self.subject_idxs*2 if not (hidden_states.shape[0] == len(self.subject_idxs)) else self.subject_idxs
+                 assert (hidden_states.shape[0] == len(subject_idxs))
+                 attentions = rearrange(attention_probs, '(b h) n d -> b h n d', h=8).mean(1)
+                 attn_extracted = [attentions[i, :, subject_idxs[i]].sum(-1) for i in range(hidden_states.shape[0])]
+                 attn_extracted = [(atn-atn.min())/(atn.max()-atn.min()) for atn in attn_extracted]
+                 attn_extracted = torch.stack(attn_extracted, dim=0)
+                 attn_extracted = rearrange(attn_extracted, 'b (h w) -> b h w', h=int(math.sqrt(attention_probs.shape[1])))
+                 attn_extracted = torch.clamp(F.interpolate(attn_extracted.unsqueeze(1),size=512),min=0,max=1)
+                 self.mask_i = attn_extracted
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         return hidden_states
+
+ ### added for self attention
+ class IPAttnProcessor_Self(nn.Module):
+     r"""
+     Attention processor for IP-Adapter. (But for self attention)
+     Args:
+         hidden_size (`int`):
+             The hidden size of the attention layer.
+         cross_attention_dim (`int`):
+             The number of channels in the `encoder_hidden_states`.
+         scale (`float`, defaults to 1.0):
+             the weight scale of image prompt.
+         num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+             The context length of the image features.
+     """
+
+     def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+         super().__init__()
+
+         self.hidden_size = hidden_size
+         self.cross_attention_dim = cross_attention_dim
+         self.scale = scale
+         self.num_tokens = num_tokens
+
+         self.to_k_ip = nn.Linear(hidden_size, hidden_size, bias=False)
+         self.to_v_ip = nn.Linear(hidden_size, hidden_size, bias=False)
+
+         self.scale_learnable = torch.nn.Parameter(torch.zeros(1),requires_grad=True)
+
+         self.enabled = True
+         self.mode = 'extract'
+
+         self.store_ks, self.store_vs = [], []
+         self.mask_id, self.mask_ig = None, None
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+     ):
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         else:
+             end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+             encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
+             if attn.norm_cross:
+                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key_0 = attn.to_k(encoder_hidden_states)
+         value_0 = attn.to_v(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key_0)
+         value = attn.head_to_batch_dim(value_0)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         if self.enabled:
+             if self.mode == 'extract':
+                 ks, vs = attn.head_to_batch_dim(self.to_k_ip(key_0)), attn.head_to_batch_dim(self.to_v_ip(value_0))
+                 self.store_ks, self.store_vs = self.store_ks+[ks], self.store_vs+[vs]
+                 self.store_ks, self.store_vs = torch.cat(self.store_ks,dim=0), torch.cat(self.store_vs,dim=0)
+
+             if self.mode == 'masked_generation':
+                 if not self.store_ks.shape[0]==query.shape[0]: self.store_ks,self.store_vs = self.store_ks.repeat(2,1,1), self.store_vs.repeat(2,1,1)
+                 mask_id = self.mask_id.clone()
+                 mask_id.masked_fill_(self.mask_id==False, -torch.finfo(mask_id.dtype).max)
+                 mask_id = rearrange(F.interpolate(mask_id,size=int(math.sqrt(query.shape[1]))),"b c h w -> b c (h w)").repeat(1,query.shape[1],1)
+                 mask_id = mask_id.repeat(8,1,1) # 8 is head dim
+                 if not mask_id.shape[0]==int(query.shape[0]): mask_id = mask_id.repeat(2,1,1)
+                 attention_probs_ref = attn.get_attention_scores(query, self.store_ks, mask_id.to(query.dtype))
+                 hidden_states_ref = torch.bmm(attention_probs_ref, self.store_vs)
+                 hidden_states_ref = attn.batch_to_head_dim(hidden_states_ref)
+                 scale = self.scale.repeat(int(batch_size/self.scale.shape[0])).unsqueeze(-1).unsqueeze(-1) if type(self.scale)==torch.Tensor else self.scale
+                 if self.mask_ig == None:
+                     hidden_states = hidden_states + scale * hidden_states_ref * self.scale_learnable
+                 else:
+                     mask_ig = rearrange(F.interpolate(self.mask_ig,size=int(math.sqrt(query.shape[1]))),"b c h w -> b (h w) c")
+                     if not mask_ig.shape[0]==hidden_states_ref.shape[0]: mask_ig = mask_ig.repeat(2,1,1)
+                     hidden_states = hidden_states + scale * hidden_states_ref * mask_ig * self.scale_learnable
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         return hidden_states
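get_mask_from_cross above averages the stored subject cross-attention maps across processors, min-max normalizes them, zeroes out responses below 0.2, and renormalizes before adding a channel dimension. A dummy-tensor sketch of that post-processing (the random tensor stands in for the collected maps):

    import torch

    maps = torch.rand(2, 5, 16, 16)                          # stand-in for masks from 5 cross-attention layers
    mask = maps.mean(dim=1)                                  # average over layers -> (B, H, W)

    mask = (mask - mask.min()) / (mask.max() - mask.min())   # min-max normalize to [0, 1]
    mask = (mask > 0.2).to(torch.float32) * mask             # suppress weak responses
    mask = (mask - mask.min()) / (mask.max() - mask.min())   # renormalize after thresholding
    mask = mask.unsqueeze(1)                                 # -> (B, 1, H, W)
    print(mask.shape)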
model_lib/moMA_generator.py ADDED
@@ -0,0 +1,285 @@
+ from typing import List
+ import torch
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+ from PIL import Image
+ from model_lib.attention_processor import IPAttnProcessor, IPAttnProcessor_Self, get_mask_from_cross
+ from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
+ import tqdm
+
+
+ def get_subject_idx(model,prompt,src_subject,device):
+     tokenized_prompt = model.tokenizer(prompt,padding="max_length",max_length=model.tokenizer.model_max_length,truncation=True,return_tensors="pt",).to(device)
+     input_ids = tokenized_prompt['input_ids']
+     src_subject_idxs = []
+     for subject,input_id in zip(src_subject,input_ids):
+         src_subject_token_id = [model.tokenizer.encode(i, add_special_tokens=False)[0] for i in subject.split(' ')]
+         src_subject_idxs = [i for i, x in enumerate(input_id.tolist()) if x in src_subject_token_id]
+     return [src_subject_idxs]
+
+
+ def add_function(model):
+     @torch.no_grad()
+     def generate_with_adapters(
+         model,
+         prompt_embeds,
+         num_inference_steps,
+         generator,
+         t_range=list(range(0,950)),
+     ):
+
+         latents = model.prepare_latents(prompt_embeds.shape[0]//2,4,512,512,prompt_embeds.dtype,prompt_embeds.device,generator)
+
+         model.scheduler.set_timesteps(num_inference_steps)
+
+         iterator = tqdm.tqdm(model.scheduler.timesteps)
+         mask_ig_prev = None
+         for i, t in enumerate(iterator):
+             if not t in t_range:
+                 model.moMA_generator.toggle_enable_flag('cross')
+             else:
+                 model.moMA_generator.toggle_enable_flag('all')
+
+             latent_model_input = torch.cat([latents] * 2)
+             noise_pred = model.unet(
+                 latent_model_input,
+                 t,
+                 encoder_hidden_states=prompt_embeds,
+                 return_dict=False,
+             )[0]
+
+             # perform guidance
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
+
+             latents = model.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+             mask_ig_prev = (get_mask_from_cross(model.unet.attn_processors))[latents.shape[0]:]
+
+             model.moMA_generator.set_self_mask('self','ig',mask_ig_prev)
+             model.moMA_generator.set_self_mask('cross',mask=mask_ig_prev.clone().detach())
+
+         image = model.vae.decode(latents / model.vae.config.scaling_factor, return_dict=False)[0]
+         return image ,mask_ig_prev.repeat(1,3,1,1) if (not mask_ig_prev==None) else None
+     model.generate_with_adapters = generate_with_adapters
+
+
+ class ImageProjModel(torch.nn.Module):
+     """Projection Model"""
+     def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+         super().__init__()
+
+         self.cross_attention_dim = cross_attention_dim
+         self.clip_extra_context_tokens = clip_extra_context_tokens
+         self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+         self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+     def forward(self, image_embeds):
+         embeds = image_embeds
+         clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
+         clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+         return clip_extra_context_tokens
+
+
+ class MoMA_generator:
+     def __init__(self, device,args):
+         self.args = args
+         self.device = device
+
+         noise_scheduler = DDIMScheduler(num_train_timesteps=1000,beta_start=0.00085,beta_end=0.012,beta_schedule="scaled_linear",clip_sample=False,set_alpha_to_one=False,steps_offset=1,)
+
+         print('Loading VAE: stabilityai--sd-vae-ft-mse...')
+         vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
+
+         print('Loading StableDiffusion: Realistic_Vision...')
+         self.pipe = StableDiffusionPipeline.from_pretrained(
+             "SG161222/Realistic_Vision_V4.0_noVAE",
+             torch_dtype=torch.float16,
+             scheduler=noise_scheduler,
+             vae=vae,
+             feature_extractor=None,
+             safety_checker=None,
+         ).to(self.device)
+
+         self.unet = self.pipe.unet
+         add_function(self.pipe)
+         self.pipe.moMA_generator = self
+
+         self.set_ip_adapter()
+         self.image_proj_model = self.init_proj()
+
+     def init_proj(self):
+         image_proj_model = ImageProjModel(
+             cross_attention_dim=768,
+             clip_embeddings_dim=1024,
+             clip_extra_context_tokens=4,
+         ).to(self.device, dtype=torch.float16)
+         return image_proj_model
+
+     def set_ip_adapter(self):
+         unet = self.unet
+         attn_procs = {}
+         for name in unet.attn_processors.keys():
+             cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+             if name.startswith("mid_block"):
+                 hidden_size = unet.config.block_out_channels[-1]
+             elif name.startswith("up_blocks"):
+                 block_id = int(name[len("up_blocks.")])
+                 hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+             elif name.startswith("down_blocks"):
+                 block_id = int(name[len("down_blocks.")])
+                 hidden_size = unet.config.block_out_channels[block_id]
+             if cross_attention_dim is None:
+                 attn_procs[name] = IPAttnProcessor_Self(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
+             else:
+                 attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
+         unet.set_attn_processor(attn_procs)
+
+     @torch.inference_mode()
+     def get_image_embeds_CFG(self, llava_emb):
+         clip_image_embeds = llava_emb
+         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
+         uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
+         return image_prompt_embeds, uncond_image_prompt_embeds
+
+     def get_image_crossAttn_feature(
+         self,
+         llava_emb,
+         num_samples=1,
+     ):
+         image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds_CFG(llava_emb)
+         bs_embed, seq_len, _ = image_prompt_embeds.shape
+         image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
+         image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+         uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
+         uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
+         return image_prompt_embeds, uncond_image_prompt_embeds
+
+     # features are from self-attention layers of Unet: feed reference image to Unet with t=0
+     def get_image_selfAttn_feature(
+         self,
+         pil_image,
+         prompt,
+     ):
+         self.toggle_enable_flag('self')
+         self.toggle_extract_inject_flag('self', 'extract')
+         tokenized_prompt = self.pipe.tokenizer(prompt,padding="max_length",truncation=True,return_tensors="pt",).to(self.device)
+         text_embeddings = self.pipe.text_encoder(input_ids=tokenized_prompt.input_ids)[0]
+
+         ref_image = pil_image
+         ref_image.to(self.device)
+
+         with torch.no_grad(): latents = self.pipe.vae.encode(ref_image).latent_dist.sample()
+         latents = latents * self.pipe.vae.config.scaling_factor
+
+         noise = torch.randn_like(latents)
+         timesteps = torch.tensor([0],device=latents.device).long() # fixed to 0
+         noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timesteps)
+
+         _ = self.unet(noisy_latents,timestep=timesteps,encoder_hidden_states=text_embeddings)["sample"]
+         # features are stored in attn_processors
+
+         return None
+
+     @torch.no_grad()
+     def generate_with_MoMA(
+         self,
+         batch,
+         llava_emb=None,
+         seed=None,
+         device='cuda',
+     ):
+         self.reset_all()
+         img_ig,mask_id,subject,prompt = batch['image'].half().to(device),batch['mask'].half().to(device),batch['label'][0],batch['text'][0]
+
+         prompt = [f"photo of a {subject}. "+ prompt]
+         subject_idx = get_subject_idx(self.pipe,prompt,[subject],self.device)
+         negative_prompt = None
+
+         # get context-cross-attention feature (from MLLM decoder)
+         cond_llava_embeds, uncond_llava_embeds = self.get_image_crossAttn_feature(llava_emb,num_samples=1)
+         # get subject-cross-attention feature (from Unet)
+         self.get_image_selfAttn_feature(img_ig,subject) # features are stored in attn_processors
+
+         with torch.inference_mode():
+             prompt_embeds = self.pipe._encode_prompt(
+                 prompt, device=self.device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=negative_prompt)
+             negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
+             prompt_embeds = torch.cat([prompt_embeds_, cond_llava_embeds], dim=1)
+             negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_llava_embeds], dim=1)
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+         generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
+
+         self.set_self_mask('eraseAll')
+         self.toggle_enable_flag('all')
+         self.toggle_extract_inject_flag('all','masked_generation')
+         self.set_self_mask('self','id',mask_id)
+         self.set_cross_subject_idxs(subject_idx)
+
+         images, mask = self.pipe.generate_with_adapters(
+             self.pipe,
+             prompt_embeds,
+             50,
+             generator,
+         )
+         images = torch.clip((images+1)/2.0,min=0.0,max=1.0)
+
+         return images.cpu(), mask.cpu()
+
+     def set_selfAttn_strength(self, strength):
+         for attn_processor in self.unet.attn_processors.values():
+             if isinstance(attn_processor, IPAttnProcessor):
+                 attn_processor.scale = 1.0
+             if isinstance(attn_processor, IPAttnProcessor_Self):
+                 attn_processor.scale = strength
+
+     def set_cross_subject_idxs(self, subject_idxs):
+         for attn_processor in self.unet.attn_processors.values():
+             if isinstance(attn_processor, IPAttnProcessor):
+                 attn_processor.subject_idxs = subject_idxs
+
+     def set_self_mask(self,mode,id_ig='', mask=None): #only have effect on self attn of the generation process
+         for attn_processor in self.unet.attn_processors.values():
+             if mode == 'eraseAll':
+                 if isinstance(attn_processor, IPAttnProcessor_Self):
+                     attn_processor.mask_id,attn_processor.mask_ig = None,None
+                 if isinstance(attn_processor, IPAttnProcessor):
+                     attn_processor.mask_i, attn_processor.mask_ig_prev = None, None
+             if mode == 'self':
+                 if isinstance(attn_processor, IPAttnProcessor_Self):
+                     if id_ig == 'id':attn_processor.mask_id = mask
+                     if id_ig == 'ig':attn_processor.mask_ig = mask
+             if mode == 'cross':
+                 if isinstance(attn_processor, IPAttnProcessor):
+                     attn_processor.mask_ig_prev = mask
+
+     def toggle_enable_flag(self, processor_enable_mode):
+         for attn_processor in self.unet.attn_processors.values():
+             if processor_enable_mode == 'cross':
+                 if isinstance(attn_processor, IPAttnProcessor):attn_processor.enabled = True
+                 if isinstance(attn_processor, IPAttnProcessor_Self):attn_processor.enabled = False
+             if processor_enable_mode == 'self':
+                 if isinstance(attn_processor, IPAttnProcessor):attn_processor.enabled = False
+                 if isinstance(attn_processor, IPAttnProcessor_Self):attn_processor.enabled = True
+             if processor_enable_mode == 'all':
+                 attn_processor.enabled = True
+             if processor_enable_mode == 'none':
+                 attn_processor.enabled = False
+
+     def toggle_extract_inject_flag(self, processor_name, mode): # mode: str, 'extract' or 'inject' or 'both'(cross only)
+         for attn_processor in self.unet.attn_processors.values():
+             if processor_name == 'cross':
+                 if isinstance(attn_processor, IPAttnProcessor):attn_processor.mode = mode
+             if processor_name == 'self':
+                 if isinstance(attn_processor, IPAttnProcessor_Self):attn_processor.mode = mode
+             if processor_name == 'all':
+                 attn_processor.mode = mode
+
+     def reset_all(self,keep_self=False):
+         for attn_processor in self.unet.attn_processors.values():
+             if isinstance(attn_processor, IPAttnProcessor):
+                 attn_processor.store_attn, attn_processor.subject_idxs, attn_processor.mask_i, attn_processor.mask_ig_prev, self.subject_idxs = None, None, None, None, None
+
+             if isinstance(attn_processor, IPAttnProcessor_Self):
+                 attn_processor.mask_id, attn_processor.mask_ig = None, None
+                 if not keep_self: attn_processor.store_ks, attn_processor.store_vs = [], []
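generate_with_adapters above runs the UNet on a doubled batch (unconditional and conditional prompt embeddings) and combines the two noise predictions with a fixed classifier-free guidance scale of 7.5. A dummy-tensor sketch of just that combination step:

    import torch

    guidance_scale = 7.5                                  # fixed in generate_with_adapters above

    noise_pred = torch.randn(2, 4, 64, 64)                # stand-in for the UNet output on [uncond, cond]
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

    # Classifier-free guidance: push the prediction away from the unconditional branch.
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(guided.shape)  # torch.Size([1, 4, 64, 64])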
model_lib/modules.py ADDED
@@ -0,0 +1,151 @@
+ import os
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ from typing import List, Optional
+ import torch.utils.checkpoint
+ from torchvision.transforms import ToPILImage
+ from model_lib.moMA_generator import MoMA_generator
+ from transformers.activations import ACT2FN
+ from huggingface_hub import hf_hub_download
+
+ from dataset_lib.dataset_eval_MoMA import Dataset_evaluate_MoMA
+
+ from llava.model.builder import load_pretrained_model
+ from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
+ from llava.constants import IMAGE_TOKEN_INDEX
+
+ def add_function(model):
+     def my_llava_forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         (_,position_ids,attention_mask,_,inputs_embeds,_) = self.prepare_inputs_labels_for_multimodal(input_ids,position_ids,attention_mask,None,None,images)
+
+         outputs = self.model(
+             input_ids=None,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=None,
+             inputs_embeds=inputs_embeds,
+             use_cache=True,
+             output_attentions=False,
+             output_hidden_states=False,
+             return_dict=True,
+         )
+         return outputs[0]
+
+     model.my_llava_forward = my_llava_forward
+
+
+ class LlamaMLP_mapping(nn.Module):
+     def __init__(self, hidden_size,hidden_size_out):
+         super().__init__()
+         self.hidden_size, self.hidden_size_out = hidden_size,hidden_size_out
+         self.gate_proj = nn.Linear(self.hidden_size, self.hidden_size_out, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.hidden_size_out, bias=False)
+         self.down_proj = nn.Linear(self.hidden_size_out, self.hidden_size_out, bias=False)
+         self.act_fn = ACT2FN["silu"]
+         self.act_fn_output = ACT2FN["tanh"]
+         self.init_linear()
+
+     def forward(self, x):
+         down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+         return down_proj
+
+     def init_linear(self):
+         torch.nn.init.xavier_normal_(self.gate_proj.weight)
+         self.gate_proj.weight.data=self.gate_proj.weight.data/4.0
+         torch.nn.init.xavier_normal_(self.up_proj.weight)
+         self.up_proj.weight.data=self.up_proj.weight.data/4.0
+         torch.nn.init.xavier_normal_(self.down_proj.weight)
+         self.down_proj.weight.data=self.down_proj.weight.data/4.0
+
+ class MoMA_main_modal(nn.Module):
+     def __init__(self,args):
+         super().__init__()
+         self.args = args
+         self.device = args.device
+
+         self.moMA_generator = MoMA_generator(self.device,args)
+         self.unet = self.moMA_generator.pipe.unet
+         self.vae = self.moMA_generator.pipe.vae
+
+         print('Loading MoMA: its Multi-modal LLM...')
+         model_name = get_model_name_from_path(args.model_path)
+         self.tokenizer_llava, self.model_llava, self.image_processor_llava, self.context_len_llava = load_pretrained_model(args.model_path, None, model_name, load_8bit=self.args.load_8bit, load_4bit=self.args.load_4bit, device=args.device)
+
+         add_function(self.model_llava)
+
+         self.mapping = LlamaMLP_mapping(4096,1024).to(self.device, dtype=torch.float16)
+         self.load_saved_components()
+         self.freeze_modules()
+
+     def load_saved_components(self):
+         if not os.path.exists(self.args.load_attn_adapters):
+             print('Loading Attentions and LLM mappings...')
+             hf_hub_download(repo_id=self.args.model_path, filename="attn_adapters_projectors.th",local_dir='/'.join(self.args.load_attn_adapters.split('/')[:-1]))
+
+         #load attention adapters and self cross attentions
+         state_dict = torch.load(self.args.load_attn_adapters, map_location="cpu")
+         self.moMA_generator.image_proj_model.load_state_dict(state_dict["projectors"])
+         attn_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
+         attn_layers.load_state_dict(state_dict["self_cross_attentions"],strict=False)
+
+         #load LLM projectors
+         self.load_state_dict(state_dict['llm_mapping'],strict=False)
+
+     def freeze_modules(self):
+         all_modules = [self.moMA_generator.pipe.vae,self.moMA_generator.pipe.text_encoder,self.unet,self.model_llava,self.mapping]
+         for module in all_modules:
+             module.train = False
+             module.requires_grad_(False)
+
+     def forward_MLLM(self,batch):
+         llava_processeds,subjects,prompts = batch['llava_processed'].half().to(self.device),batch['label'],batch['text']
+
+         input_ids,attention_masks,position_ids = [],[],[]
+         for subject,prompt in zip(subjects,prompts):
+             prompt_construct = f"USER: <image>\n A photo of a {subject}. Describe a new image of the same {subject} in: {prompt}. ASSISTANT: *"
+             input_id = tokenizer_image_token(prompt_construct, self.tokenizer_llava, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+             attention_mask = torch.ones(input_id.shape, dtype=torch.long, device=self.device)
+             position_id = torch.tensor(list(range(input_id.shape[-1])), device=self.device)
+
+             position_ids += [position_id]
+             attention_masks += [attention_mask[0]]
+             input_ids += [input_id[0]]
+
+         input_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in input_ids],batch_first=True,padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
+         position_ids = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in position_ids],batch_first=True,padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
+         attention_masks = torch.nn.utils.rnn.pad_sequence([i.flip(dims=[-1]) for i in attention_masks],batch_first=True,padding_value=self.tokenizer_llava.pad_token_id).flip(dims=[1])
+
+         output = self.model_llava.my_llava_forward(self.model_llava,input_ids=input_ids,attention_mask=attention_masks,position_ids=position_ids,images=llava_processeds)
+         output = self.mapping(output)
+         return output[:,-1,:]
+
+     def reset(self):
+         self.moMA_generator.reset_all()
+
+     def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
+         batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject,self)
+         self.moMA_generator.set_selfAttn_strength(strength)
+
+         with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
+             with torch.no_grad():
+                 ### key steps
+                 llava_emb = self.forward_MLLM(batch).clone().detach()
+                 img,mask = self.moMA_generator.generate_with_MoMA(batch,llava_emb=llava_emb,seed=seed,device=self.args.device)
+                 self.reset()
+
+         result = ToPILImage()(img[0])
+         return result
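forward_MLLM above left-pads the variable-length LLaVA inputs by flipping each sequence, right-padding with pad_sequence, and flipping back, so the final "*" token (whose hidden state is read out as output[:, -1, :]) always sits in the last position. A tiny sketch of that flip-pad-flip trick on dummy token ids:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    pad_id = 0
    seqs = [torch.tensor([11, 12, 13]), torch.tensor([21, 22])]

    # Flip, right-pad, flip back: the padding ends up on the left.
    left_padded = pad_sequence([s.flip(dims=[-1]) for s in seqs],
                               batch_first=True, padding_value=pad_id).flip(dims=[1])
    print(left_padded)
    # tensor([[11, 12, 13],
    #         [ 0, 21, 22]])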
model_lib/utils.py ADDED
@@ -0,0 +1,27 @@
+ import argparse
+ import torch
+ from torchvision.transforms import ToPILImage
+ from PIL import Image
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Simple example of MoMA.")
+     parser.add_argument("--load_attn_adapters",type=str,default="checkpoints/attn_adapters_projectors.th",help="self_cross attentions and LLM projectors.")
+     parser.add_argument("--output_path",type=str,default="output",help="output directory.")
+     parser.add_argument("--model_path",type=str,default="KunpengSong/MoMA_llava_7b",help="fine tuned llava (Multi-modal LLM decoder)")
+     args = parser.parse_known_args()[0]
+     args.device = torch.device("cuda", 0)
+     args.load_8bit, args.load_4bit = False, True
+     return args
+
+ def show_PIL_image(tensor):
+     # tensor of shape [3, 3, 512, 512]
+     to_pil = ToPILImage()
+     images = [to_pil(tensor[i]) for i in range(tensor.shape[0])]
+
+     concatenated_image = Image.new('RGB', (images[0].width * 3, images[0].height))
+     x_offset = 0
+     for img in images:
+         concatenated_image.paste(img, (x_offset, 0))
+         x_offset += img.width
+
+     return concatenated_image
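show_PIL_image above tiles a batch of three image tensors into one horizontal strip for quick inspection. A minimal usage sketch, assuming the module is imported from the repo root; the output filename is illustrative only:

    import torch
    from model_lib.utils import show_PIL_image

    batch = torch.rand(3, 3, 512, 512)    # three RGB images with values in [0, 1]
    strip = show_PIL_image(batch)         # one 1536x512 PIL image
    strip.save("preview_strip.jpg")       # hypothetical output filename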
output/car_A car in autumn with falling leaves..jpg ADDED
output/car_A wooden sculpture of a car on the table..jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,32 @@
+ pip
+ einops
+ fastapi
+ gradio
+ numpy
+ requests
+ sentencepiece
+ tokenizers>=0.12.1
+ torch==2.0.1
+ torchvision==0.15.2
+ uvicorn
+ wandb
+ shortuuid
+ httpx==0.24.0
+ deepspeed
+ peft==0.4.0
+ transformers==4.36.2
+ accelerate==0.21.0
+ bitsandbytes==0.41.0
+ scikit-learn==1.2.2
+ sentencepiece==0.1.99
+ einops==0.6.1
+ einops-exts==0.0.4
+ timm==0.6.13
+ gradio_client
+ opencv-python
+ diffusers
+ torchaudio
+ torchmetrics
+ llava-torch
+ rembg
+ pytorch_lightning