Kunpeng Song commited on
Commit
eefa462
1 Parent(s): e997668
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. app.py +4 -10
  3. model_lib/moMA_generator.py +0 -1
  4. model_lib/modules.py +1 -1
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -5,35 +5,29 @@ import torch
5
  import numpy as np
6
  import torch
7
  from pytorch_lightning import seed_everything
8
- from model_lib.modules import MoMA_main_modal
9
  from model_lib.utils import parse_args
10
  import os
11
  os.environ["CUDA_VISIBLE_DEVICES"]="0"
12
 
13
  title = "MoMA"
14
- description = "This model has to run on GPU. By default, we load the model with 4-bit quantization to make it fit in smaller hardware."
15
  device = torch.device('cuda')
16
-
17
  seed_everything(0)
18
  args = parse_args()
19
- #load MoMA from HuggingFace. Auto download
20
- model = MoMA_main_modal(args).to(device, dtype=torch.float16)
21
 
22
- @spaces.GPU
23
  def MoMA_demo(rgb, subject, prompt, strength, seed):
24
- with torch.no_grad():
25
- generated_image = model.generate_images(rgb, subject, prompt, strength=strength, seed=seed)
 
26
  return generated_image
27
 
28
  @spaces.GPU
29
  def inference(rgb, subject, prompt, strength, seed):
30
  seed = int(seed) if seed else 0
31
  seed = seed if not seed == 0 else np.random.randint(0,1000)
32
-
33
  result = MoMA_demo(rgb, subject, prompt, strength, seed)
34
  return result
35
 
36
-
37
  gr.Interface(
38
  inference,
39
  [gr.Image(type="pil", label="Input RGB"),
 
5
  import numpy as np
6
  import torch
7
  from pytorch_lightning import seed_everything
 
8
  from model_lib.utils import parse_args
9
  import os
10
  os.environ["CUDA_VISIBLE_DEVICES"]="0"
11
 
12
  title = "MoMA"
13
+ description = "This model has to run on GPU. Please find our project page at https://moma-adapter.github.io/."
14
  device = torch.device('cuda')
 
15
  seed_everything(0)
16
  args = parse_args()
 
 
17
 
 
18
  def MoMA_demo(rgb, subject, prompt, strength, seed):
19
+ from model_lib.modules import MoMA_main_modal
20
+ model = MoMA_main_modal(args).to(device, dtype=torch.float16)
21
+ generated_image = model.generate_images(rgb, subject, prompt, strength=strength, seed=seed)
22
  return generated_image
23
 
24
  @spaces.GPU
25
  def inference(rgb, subject, prompt, strength, seed):
26
  seed = int(seed) if seed else 0
27
  seed = seed if not seed == 0 else np.random.randint(0,1000)
 
28
  result = MoMA_demo(rgb, subject, prompt, strength, seed)
29
  return result
30
 
 
31
  gr.Interface(
32
  inference,
33
  [gr.Image(type="pil", label="Input RGB"),
model_lib/moMA_generator.py CHANGED
@@ -155,7 +155,6 @@ class MoMA_generator:
155
  return image_prompt_embeds, uncond_image_prompt_embeds
156
 
157
  # feature are from self-attention layers of Unet: feed reference image to Unet with t=0
158
- @spaces.GPU
159
  def get_image_selfAttn_feature(
160
  self,
161
  pil_image,
 
155
  return image_prompt_embeds, uncond_image_prompt_embeds
156
 
157
  # feature are from self-attention layers of Unet: feed reference image to Unet with t=0
 
158
  def get_image_selfAttn_feature(
159
  self,
160
  pil_image,
model_lib/modules.py CHANGED
@@ -112,7 +112,6 @@ class MoMA_main_modal(nn.Module):
112
  module.train = False
113
  module.requires_grad_(False)
114
 
115
- @spaces.GPU
116
  def forward_MLLM(self,batch):
117
  llava_processeds,subjects,prompts = batch['llava_processed'].half().to(self.device),batch['label'],batch['text']
118
 
@@ -138,6 +137,7 @@ class MoMA_main_modal(nn.Module):
138
  def reset(self):
139
  self.moMA_generator.reset_all()
140
 
 
141
  def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
142
  batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject,self)
143
  self.moMA_generator.set_selfAttn_strength(strength)
 
112
  module.train = False
113
  module.requires_grad_(False)
114
 
 
115
  def forward_MLLM(self,batch):
116
  llava_processeds,subjects,prompts = batch['llava_processed'].half().to(self.device),batch['label'],batch['text']
117
 
 
137
  def reset(self):
138
  self.moMA_generator.reset_all()
139
 
140
+ @torch.no_grad()
141
  def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
142
  batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject,self)
143
  self.moMA_generator.set_selfAttn_strength(strength)