# NOTE(review): removed non-Python extraction artifacts ("Spaces:" /
# "Runtime error" lines) that made the module unimportable.
import logging

import torch
from omegaconf import OmegaConf

from lavis.models import registry
from lavis.models import load_preprocess
from ldm.util import instantiate_from_config
def load_blip2_model(cfg, is_eval=False, device="cpu"):
    """Build a BLIP-2 model and its preprocessors from an OmegaConf config.

    Args:
        cfg: config node with ``model_name``, ``model_type``, ``pretrained``
            and ``params.img_size`` fields.
        is_eval: when True, switch the model to eval mode before returning.
        device: target device — the string ``"cpu"``, a ``torch.device``, or
            any device accepted by ``Tensor.to``.

    Returns:
        Tuple ``(model, vis_processors, txt_processors)``; the processors are
        ``None`` when no default preprocess config is available.
    """
    model_cls = registry.get_model_class(cfg.model_name)

    # Start from the model's shipped default config, then apply caller
    # overrides (checkpoint path and image size).
    default_cfg = OmegaConf.load(model_cls.default_config_path(cfg.model_type))
    default_cfg.model.pretrained = cfg.pretrained
    if default_cfg.model.image_size != cfg.params.img_size:
        default_cfg.model.image_size = cfg.params.img_size

    model = model_cls.from_config(default_cfg.model)
    model.cfg = default_cfg.model
    if is_eval:
        model.eval()

    # load preprocess
    if default_cfg is not None:
        preprocess_cfg = default_cfg.preprocess
        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
    else:
        vis_processors, txt_processors = None, None
        # BUG FIX: original f-string referenced undefined names `name` and
        # `model_type`, which would raise NameError if this branch ran.
        logging.info(
            f"""No default preprocess for model {cfg.model_name} ({cfg.model_type}).
                This can happen if the model is not finetuned on downstream datasets,
                or it is not intended for direct use without finetuning.
            """
        )

    # Checkpoints may be stored in half precision; CPU inference needs fp32.
    if device == "cpu" or device == torch.device("cpu"):
        model = model.float()

    return model.to(device), vis_processors, txt_processors
def load_qformer_model(cfg):
    """Instantiate a Q-Former and seed it from BLIP-2 pretrained weights.

    Loads the BLIP-2 reference model for its state dict and processors,
    builds the Q-Former from ``cfg``, and (when the text backbone is
    ``bert-base-uncased``) copies the pretrained weights over non-strictly.

    Returns:
        Tuple ``(q_former, (vis_processor, txt_processor))``.
    """
    pretrained, vis_proc, txt_proc = load_blip2_model(cfg)
    qformer = instantiate_from_config(cfg)

    # Align query-token tensors before transferring the state dict.
    if pretrained.query_tokens.shape != qformer.query_tokens.shape:
        pretrained.query_tokens = qformer.query_tokens

    # Only copy weights when the text backbone matches the checkpoint's.
    backbone = cfg.params.get('model_name', 'bert-base-uncased')
    if backbone == 'bert-base-uncased':
        qformer.load_state_dict(pretrained.state_dict(), strict=False)

    return qformer, (vis_proc, txt_proc)