Spaces:
Build error
Build error
import re | |
from transformers import AutoModel | |
# Registry of supported vision architectures.
# Keys are regexes searched in the lowercased model name; values are callables that
# pull the vision backbone out of the loaded model. Some architectures nest it under
# an attribute (`model.vision_model` for clip), others (vit) are returned unchanged.
# Entries are tried in insertion order and the first matching regex wins, so "clip"
# is tried before "vit" for names that contain both.
vision_model_name2model = {
    r"clip": lambda model: model.vision_model,
    r"vit": lambda model: model,
}
def vision_model_name_to_model(model_name_or_path, model):
    """Return the vision backbone extracted from `model`, selected by its name.

    The lowercased `model_name_or_path` is searched against each regex key of
    `vision_model_name2model`, in the mapping's iteration order. The extractor
    stored under the first matching key is applied to `model` and its result
    is returned.

    Raises:
        ValueError: if none of the registered regexes matches the name.
    """
    lowered_name = model_name_or_path.lower()
    for pattern, extract in vision_model_name2model.items():
        if re.search(pattern, lowered_name) is None:
            continue
        # First match wins; later patterns are never consulted.
        return extract(model)

    # Falling out of the loop means no registered architecture matched.
    raise ValueError(
        f"Unknown type of backbone vision model. Got {model_name_or_path}, supported regexes:"
        f" {list(vision_model_name2model.keys())}."
    )
def get_vision_model(config):
    """Load the pretrained vision backbone described by `config`.

    Reads `config.vision_model_name` and `config.vision_model_params`, loads
    the checkpoint with `AutoModel.from_pretrained(name, **params)`, and hands
    the loaded model to `vision_model_name_to_model` to pick out the vision
    sub-module for that architecture.

    Args:
        config: object exposing two attributes:
            `vision_model_name` - passed to `AutoModel.from_pretrained` and
            used for the name-based architecture lookup.
            `vision_model_params` - extra keyword arguments for
            `from_pretrained`, given either as a string containing a Python
            expression that evaluates to them, or as an already-built dict.

    Returns:
        Whatever `vision_model_name_to_model` returns for the loaded model.

    Raises:
        ValueError: propagated from `vision_model_name_to_model` when the
            model name matches no supported architecture.
    """
    vision_model_name = config.vision_model_name
    raw_params = config.vision_model_params

    if isinstance(raw_params, dict):
        # Already parsed (e.g. a config built in code): use it directly rather
        # than failing inside eval(), which rejects non-string input.
        vision_model_params = raw_params
    else:
        # SECURITY NOTE(review): eval() executes arbitrary Python taken from the
        # config. This is only acceptable if configs are trusted. If the params
        # are always plain literals, ast.literal_eval would be a safer parser -
        # confirm against the existing configs before switching.
        vision_model_params = eval(raw_params)

    model = AutoModel.from_pretrained(vision_model_name, **vision_model_params)
    return vision_model_name_to_model(vision_model_name, model)