Kunpeng Song committed on
Commit
18976e3
1 Parent(s): eefa462
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -15,18 +15,15 @@ device = torch.device('cuda')
 seed_everything(0)
 args = parse_args()
 
-def MoMA_demo(rgb, subject, prompt, strength, seed):
-    from model_lib.modules import MoMA_main_modal
-    model = MoMA_main_modal(args).to(device, dtype=torch.float16)
-    generated_image = model.generate_images(rgb, subject, prompt, strength=strength, seed=seed)
-    return generated_image
+from model_lib.modules import MoMA_main_modal
+model = MoMA_main_modal(args).to(device, dtype=torch.float16)
 
 @spaces.GPU
 def inference(rgb, subject, prompt, strength, seed):
     seed = int(seed) if seed else 0
     seed = seed if not seed == 0 else np.random.randint(0,1000)
-    result = MoMA_demo(rgb, subject, prompt, strength, seed)
-    return result
+    generated_image = model.generate_images(rgb, subject, prompt, strength=strength, seed=seed)
+    return generated_image
 
 gr.Interface(
     inference,
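
Note on app.py: the model is now built once at import time, and only the generation call runs inside the `@spaces.GPU`-decorated function. That is the usual layout for ZeroGPU Spaces, where a GPU is attached only for the duration of the decorated call. A minimal sketch of the same layout, with a toy `torch.nn.Linear` standing in for `MoMA_main_modal` (hypothetical stand-in, not MoMA code):

import spaces
import gradio as gr
import torch

# Heavy initialization happens once, at import / startup time.
model = torch.nn.Linear(4, 1)  # stand-in for MoMA_main_modal(args)

@spaces.GPU  # a GPU is attached only while this function runs
def inference(x: float) -> float:
    with torch.no_grad():
        return float(model(torch.full((4,), float(x))).item())

gr.Interface(inference, gr.Number(), gr.Number()).launch()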
model_lib/moMA_generator.py CHANGED
@@ -1,6 +1,3 @@
-import spaces
-
-
 import torch
 from model_lib.attention_processor import IPAttnProcessor, IPAttnProcessor_Self, get_mask_from_cross
 from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
@@ -98,7 +95,7 @@ class MoMA_generator:
             vae=vae,
             feature_extractor=None,
             safety_checker=None,
-        )
+        ).to(self.device)
 
         self.unet = self.pipe.unet
         add_function(self.pipe)
@@ -112,7 +109,7 @@ class MoMA_generator:
             cross_attention_dim=768,
             clip_embeddings_dim=1024,
             clip_extra_context_tokens=4,
-        )
+        ).to(self.device, dtype=torch.float16)
        return image_proj_model
 
     def set_ip_adapter(self):
@@ -129,9 +126,9 @@ class MoMA_generator:
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = unet.config.block_out_channels[block_id]
             if cross_attention_dim is None:
-                attn_procs[name] = IPAttnProcessor_Self(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4)
+                attn_procs[name] = IPAttnProcessor_Self(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
             else:
-                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4)
+                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,scale=1.0,num_tokens=4).to(self.device, dtype=torch.float16)
         unet.set_attn_processor(attn_procs)
 
     @torch.inference_mode()
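
Note on moMA_generator.py: with `import spaces` removed, device and dtype placement is handled explicitly; each sub-module is constructed first and then moved with `.to(self.device, dtype=torch.float16)`. A standalone sketch of that pattern (generic `nn.Linear`, not the actual IP-Adapter processors):

import torch
import torch.nn as nn

device = torch.device("cuda", 0)
proj = nn.Linear(1024, 768)                  # built on CPU in float32
proj = proj.to(device, dtype=torch.float16)  # then moved and cast explicitly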
model_lib/modules.py CHANGED
@@ -1,5 +1,3 @@
-import spaces
-
 import os
 import torch
 import torch.nn as nn
@@ -84,11 +82,11 @@ class MoMA_main_modal(nn.Module):
 
         print('Loading MoMA: its Multi-modal LLM...')
         model_name = get_model_name_from_path(args.model_path)
-        self.tokenizer_llava, self.model_llava, self.image_processor_llava, self.context_len_llava = load_pretrained_model(args.model_path, None, model_name, load_8bit=self.args.load_8bit, load_4bit=self.args.load_4bit)
+        self.tokenizer_llava, self.model_llava, self.image_processor_llava, self.context_len_llava = load_pretrained_model(args.model_path, None, model_name, load_8bit=self.args.load_8bit, load_4bit=self.args.load_4bit, device=args.device)
 
         add_function(self.model_llava)
 
-        self.mapping = LlamaMLP_mapping(4096,1024)
+        self.mapping = LlamaMLP_mapping(4096,1024).to(self.device, dtype=torch.float16)
         self.load_saved_components()
         self.freeze_modules()
 
@@ -137,7 +135,6 @@
     def reset(self):
         self.moMA_generator.reset_all()
 
-    @torch.no_grad()
     def generate_images(self, rgb_path, subject, prompt, strength=1.0, num=1, seed=0):
         batch = Dataset_evaluate_MoMA(rgb_path, prompt, subject,self)
         self.moMA_generator.set_selfAttn_strength(strength)
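
Note on modules.py: the LLaVA checkpoint now receives an explicit `device=args.device`, and the `LlamaMLP_mapping` head is moved to the same device in float16, mirroring the explicit-placement pattern above. The `@torch.no_grad()` decorator on `generate_images` is dropped; the generator methods shown earlier still carry `@torch.inference_mode()`, so gradient tracking appears to remain disabled on the generation path, though the commit itself does not state this.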
model_lib/utils.py CHANGED
@@ -10,7 +10,7 @@ def parse_args():
     parser.add_argument("--model_path",type=str,default="KunpengSong/MoMA_llava_7b",help="fine tuned llava (Multi-modal LLM decoder)")
     args = parser.parse_known_args()[0]
     args.device = torch.device("cuda", 0)
-    args.load_8bit, args.load_4bit = False, False
+    args.load_8bit, args.load_4bit = False, True
     return args
 
 def show_PIL_image(tensor):
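
Note on utils.py: setting `args.load_4bit = True` asks LLaVA's `load_pretrained_model` to load the 7B multi-modal decoder in 4-bit (bitsandbytes quantization), trading a little precision for a much smaller GPU-memory footprint, presumably so the LLM fits alongside the diffusion pipeline on the Space's GPU.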