Update model/openlamm.py
Browse files
- model/openlamm.py +2 -2
model/openlamm.py
CHANGED
@@ -160,7 +160,7 @@ class LAMMPEFTModel(nn.Module):
|
|
160 |
encoder_pretrain = args['encoder_pretrain'] if 'encoder_pretrain' in args else 'clip'
|
161 |
self.encoder_pretrain = encoder_pretrain
|
162 |
assert encoder_pretrain in ['imagebind', 'clip', 'epcl'], f'Encoder_pretrain: {encoder_pretrain} Not Implemented'
|
163 |
-
encoder_ckpt_path = args['encoder_ckpt_path'] if not encoder_pretrain == 'clip' else '~/.cache/clip/ViT-L-14.pt'
|
164 |
vicuna_ckpt_path = args['vicuna_ckpt_path']
|
165 |
|
166 |
system_header = args['system_header'] if 'system_header' in args else False
|
@@ -176,7 +176,7 @@ class LAMMPEFTModel(nn.Module):
|
|
176 |
|
177 |
# TODO: Make sure the number of vision tokens is correct
|
178 |
if args['encoder_pretrain'].lower() == 'clip':
|
179 |
-
clip_encoder, self.visual_preprocess = load_clip(
|
180 |
self.visual_encoder = clip_encoder.visual
|
181 |
if self.vision_feature_type == 'global': # global feature from CLIP
|
182 |
self.vision_hidden_size = 768
|
|
|
160 |
encoder_pretrain = args['encoder_pretrain'] if 'encoder_pretrain' in args else 'clip'
|
161 |
self.encoder_pretrain = encoder_pretrain
|
162 |
assert encoder_pretrain in ['imagebind', 'clip', 'epcl'], f'Encoder_pretrain: {encoder_pretrain} Not Implemented'
|
163 |
+
encoder_ckpt_path = args['encoder_ckpt_path'] if not encoder_pretrain == 'clip' and not os.path.isfile(args['encoder_ckpt_path']) else '~/.cache/clip/ViT-L-14.pt'
|
164 |
vicuna_ckpt_path = args['vicuna_ckpt_path']
|
165 |
|
166 |
system_header = args['system_header'] if 'system_header' in args else False
|
|
|
176 |
|
177 |
# TODO: Make sure the number of vision tokens is correct
|
178 |
if args['encoder_pretrain'].lower() == 'clip':
|
179 |
+
clip_encoder, self.visual_preprocess = load_clip(encoder_ckpt_path, device=device)
|
180 |
self.visual_encoder = clip_encoder.visual
|
181 |
if self.vision_feature_type == 'global': # global feature from CLIP
|
182 |
self.vision_hidden_size = 768
|