Kunpeng Song commited on
Commit
e997668
1 Parent(s): 4fdb6a1
Files changed (3) hide show
  1. .DS_Store +0 -0
  2. model_lib/modules.py +2 -2
  3. model_lib/utils.py +1 -1
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
model_lib/modules.py CHANGED
@@ -84,11 +84,11 @@ class MoMA_main_modal(nn.Module):
84
 
85
  print('Loading MoMA: its Multi-modal LLM...')
86
  model_name = get_model_name_from_path(args.model_path)
87
- self.tokenizer_llava, self.model_llava, self.image_processor_llava, self.context_len_llava = load_pretrained_model(args.model_path, None, model_name, load_8bit=self.args.load_8bit, load_4bit=self.args.load_4bit, device=args.device)
88
 
89
  add_function(self.model_llava)
90
 
91
- self.mapping = LlamaMLP_mapping(4096,1024).to(self.device, dtype=torch.float16)
92
  self.load_saved_components()
93
  self.freeze_modules()
94
 
 
84
 
85
  print('Loading MoMA: its Multi-modal LLM...')
86
  model_name = get_model_name_from_path(args.model_path)
87
+ self.tokenizer_llava, self.model_llava, self.image_processor_llava, self.context_len_llava = load_pretrained_model(args.model_path, None, model_name, load_8bit=self.args.load_8bit, load_4bit=self.args.load_4bit)
88
 
89
  add_function(self.model_llava)
90
 
91
+ self.mapping = LlamaMLP_mapping(4096,1024)
92
  self.load_saved_components()
93
  self.freeze_modules()
94
 
model_lib/utils.py CHANGED
@@ -10,7 +10,7 @@ def parse_args():
10
  parser.add_argument("--model_path",type=str,default="KunpengSong/MoMA_llava_7b",help="fine tuned llava (Multi-modal LLM decoder)")
11
  args = parser.parse_known_args()[0]
12
  args.device = torch.device("cuda", 0)
13
- args.load_8bit, args.load_4bit = False, True
14
  return args
15
 
16
  def show_PIL_image(tensor):
 
10
  parser.add_argument("--model_path",type=str,default="KunpengSong/MoMA_llava_7b",help="fine tuned llava (Multi-modal LLM decoder)")
11
  args = parser.parse_known_args()[0]
12
  args.device = torch.device("cuda", 0)
13
+ args.load_8bit, args.load_4bit = False, False
14
  return args
15
 
16
  def show_PIL_image(tensor):