Kunpeng Song committed
Commit 4fdb6a1 · 1 Parent(s): 57f8019
Files changed (4)
  1. .DS_Store +0 -0
  2. app.py +1 -0
  3. model_lib/moMA_generator.py +7 -3
  4. model_lib/modules.py +3 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -19,6 +19,7 @@ args = parse_args()
 #load MoMA from HuggingFace. Auto download
 model = MoMA_main_modal(args).to(device, dtype=torch.float16)
 
+@spaces.GPU
 def MoMA_demo(rgb, subject, prompt, strength, seed):
     with torch.no_grad():
         generated_image = model.generate_images(rgb, subject, prompt, strength=strength, seed=seed)
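
For reference, `@spaces.GPU` is the decorator from Hugging Face's `spaces` package used on ZeroGPU Spaces: a GPU is attached only while a decorated function runs. A minimal, self-contained sketch of that pattern follows; the `nn.Linear` stand-in and the `infer` name are illustrative, not this repository's code.

import spaces
import torch

model = torch.nn.Linear(4, 4)   # illustrative stand-in for MoMA_main_modal(args); created on CPU at startup

@spaces.GPU                     # ZeroGPU attaches a GPU only for the duration of this call
def infer(x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return model.to("cuda")(x.to("cuda")).cpu()

Keeping construction on the CPU and moving work to the GPU only inside decorated functions appears consistent with the removal of the eager `.to(self.device, dtype=torch.float16)` calls in the next file.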
model_lib/moMA_generator.py CHANGED
@@ -1,3 +1,6 @@
+import spaces
+
+
 import torch
 from model_lib.attention_processor import IPAttnProcessor, IPAttnProcessor_Self, get_mask_from_cross
 from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
@@ -109,7 +112,7 @@ class MoMA_generator:
             cross_attention_dim=768,
             clip_embeddings_dim=1024,
             clip_extra_context_tokens=4,
-        ).to(self.device, dtype=torch.float16)
+        )
         return image_proj_model
 
     def set_ip_adapter(self):
@@ -126,9 +129,9 @@ class MoMA_generator:
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = unet.config.block_out_channels[block_id]
             if cross_attention_dim is None:
-                attn_procs[name] = IPAttnProcessor_Self(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
+                attn_procs[name] = IPAttnProcessor_Self(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4)
             else:
-                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
+                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4)
         unet.set_attn_processor(attn_procs)
 
     @torch.inference_mode()
@@ -152,6 +155,7 @@ class MoMA_generator:
         return image_prompt_embeds, uncond_image_prompt_embeds
 
     # feature are from self-attention layers of Unet: feed reference image to Unet with t=0
+    @spaces.GPU
     def get_image_selfAttn_feature(
         self,
         pil_image,
model_lib/modules.py CHANGED
@@ -1,3 +1,5 @@
+import spaces
+
 import os
 import torch
 import torch.nn as nn
@@ -110,6 +112,7 @@ class MoMA_main_modal(nn.Module):
             module.train = False
             module.requires_grad_(False)
 
+    @spaces.GPU
     def forward_MLLM(self,batch):
         llava_processeds,subjects,prompts = batch['llava_processed'].half().to(self.device),batch['label'],batch['text']
 