LanguageBind committed on
Commit
cbfb9b8
1 Parent(s): 4cee86a

Update llava/model/builder.py

Browse files
Files changed (1) hide show
  1. llava/model/builder.py +16 -14
llava/model/builder.py CHANGED
@@ -139,6 +139,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
139
  if 'llava' in model_name.lower():
140
  mm_use_x_start_end = getattr(model.config, "mm_use_x_start_end", False)
141
  mm_use_x_patch_token = getattr(model.config, "mm_use_x_patch_token", True)
 
142
  X = model.config.X
143
  if mm_use_x_patch_token:
144
  for x in X:
@@ -146,23 +147,24 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
146
  if mm_use_x_start_end:
147
  for x in X:
148
  tokenizer.add_tokens([DEFAULT_X_START_TOKEN[x.upper()], DEFAULT_X_END_TOKEN[x.upper()]], special_tokens=True)
 
149
  model.resize_token_embeddings(len(tokenizer))
150
  print(X)
151
- if 'Image' in X:
152
- image_tower = model.get_image_tower()
153
- if not image_tower.is_loaded:
154
- image_tower.load_model()
155
- image_tower.to(device=device, dtype=torch.float16)
156
- image_processor = image_tower.image_processor
157
- processor['image'] = image_processor
158
 
159
- if 'Video' in X:
160
- video_tower = model.get_video_tower()
161
- if not video_tower.is_loaded:
162
- video_tower.load_model()
163
- video_tower.to(device=device, dtype=torch.float16)
164
- video_processor = video_tower.video_processor
165
- processor['video'] = video_processor
166
 
167
  if hasattr(model.config, "max_sequence_length"):
168
  context_len = model.config.max_sequence_length
 
139
  if 'llava' in model_name.lower():
140
  mm_use_x_start_end = getattr(model.config, "mm_use_x_start_end", False)
141
  mm_use_x_patch_token = getattr(model.config, "mm_use_x_patch_token", True)
142
+ '''
143
  X = model.config.X
144
  if mm_use_x_patch_token:
145
  for x in X:
 
147
  if mm_use_x_start_end:
148
  for x in X:
149
  tokenizer.add_tokens([DEFAULT_X_START_TOKEN[x.upper()], DEFAULT_X_END_TOKEN[x.upper()]], special_tokens=True)
150
+ '''
151
  model.resize_token_embeddings(len(tokenizer))
152
  print(X)
153
+ #if 'Image' in X:
154
+ image_tower = model.get_image_tower()
155
+ if not image_tower.is_loaded:
156
+ image_tower.load_model()
157
+ image_tower.to(device=device, dtype=torch.float16)
158
+ image_processor = image_tower.image_processor
159
+ processor['image'] = image_processor
160
 
161
+ #if 'Video' in X:
162
+ video_tower = model.get_video_tower()
163
+ if not video_tower.is_loaded:
164
+ video_tower.load_model()
165
+ video_tower.to(device=device, dtype=torch.float16)
166
+ video_processor = video_tower.video_processor
167
+ processor['video'] = video_processor
168
 
169
  if hasattr(model.config, "max_sequence_length"):
170
  context_len = model.config.max_sequence_length