# SOURCE: https://github.com/Sulam-Group/IBYDMT/blob/main/ibydmt/multimodal.py
from abc import abstractmethod
from typing import Mapping, Optional

import clip
import open_clip
from transformers import (
    AlignModel,
    AlignProcessor,
    BlipForImageTextRetrieval,
    BlipProcessor,
    FlavaModel,
    FlavaProcessor,
)

from app_lib.config import Config
from app_lib.config import Constants as c

class VisionLanguageModel:
    """Common interface shared by the vision-language models below."""

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        pass

    @abstractmethod
    def encode_text(self, text):
        pass

    @abstractmethod
    def encode_image(self, image):
        pass

# Registry mapping a model name to the class that implements it.
models: Mapping[str, VisionLanguageModel] = {}


def register_model(name):
    def register(cls: VisionLanguageModel):
        if name in models:
            raise ValueError(f"Model {name} is already registered")
        models[name] = cls
        return cls  # return the class so this works as a class decorator

    return register

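# NOTE (illustrative, not part of the original module): register_model is intended
# to be used as a class decorator on the concrete models defined below, roughly:
#
#     @register_model("clip")
#     class CLIPModel(VisionLanguageModel):
#         ...
#
# The registration calls do not appear in this excerpt, and the registry keys
# ("clip", "open_clip", "flava", "align", "blip") are assumptions. get_model()
# below simply looks the class up in `models` by the name parsed from
# config.data.backbone.
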
def get_model_name_and_backbone(config: Config):
    # config.data.backbone is either "<model_name>" or "<model_name>:<backbone>";
    # when no backbone is given, it defaults to None.
    backbone = config.data.backbone.split(":")
    if len(backbone) == 1:
        backbone.append(None)
    return backbone

def get_model(config: Config, device=c.DEVICE) -> VisionLanguageModel:
    model_name, backbone = get_model_name_and_backbone(config)
    return models[model_name](backbone, device=device)


def get_text_encoder(config: Config, device=c.DEVICE):
    model = get_model(config, device=device)
    return model.encode_text


def get_image_encoder(config: Config, device=c.DEVICE):
    model = get_model(config, device=device)
    return model.encode_image

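# Illustrative usage sketch (hypothetical; it assumes CLIPModel has been registered
# under the key "clip", which is not shown in this excerpt):
#
#     config = ...  # an app_lib.config.Config with config.data.backbone = "clip:ViT-B/32"
#     encode_text = get_text_encoder(config)
#     encode_image = get_image_encoder(config)
#     text_features = encode_text(["a photo of a dog"])      # torch.Tensor, (1, d)
#     image_features = encode_image(Image.open("dog.jpg"))   # torch.Tensor, (1, d)
#
# Each encoder returns features in the model's joint image-text embedding space,
# so the two outputs can be compared directly (e.g. with cosine similarity).
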
class CLIPModel(VisionLanguageModel):
    def __init__(self, backbone: str, device=c.DEVICE):
        self.model, self.preprocess = clip.load(backbone, device=device)
        self.tokenize = clip.tokenize
        self.device = device

    def encode_text(self, text):
        text = self.tokenize(text).to(self.device)
        return self.model.encode_text(text)

    def encode_image(self, image):
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.model.encode_image(image)

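# Minimal sketch of using CLIPModel directly, outside the registry (hypothetical;
# "ViT-B/32" is one of the standard OpenAI CLIP backbones, and "dog.jpg" is a
# placeholder image path):
#
#     import torch
#     from PIL import Image
#
#     model = CLIPModel("ViT-B/32", device="cpu")
#     with torch.no_grad():
#         t = model.encode_text(["a photo of a dog"])
#         i = model.encode_image(Image.open("dog.jpg"))
#     sim = torch.nn.functional.cosine_similarity(t, i)
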
class OpenClipModel(VisionLanguageModel):
    OPENCLIP_WEIGHTS = {
        "ViT-B-32": "laion2b_s34b_b79k",
        "ViT-L-14": "laion2b_s32b_b82k",
    }

    def __init__(self, backbone: str, device=c.DEVICE):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            backbone, pretrained=self.OPENCLIP_WEIGHTS[backbone], device=device
        )
        self.tokenize = open_clip.get_tokenizer(backbone)
        self.device = device

    def encode_text(self, text):
        text = self.tokenize(text).to(self.device)
        return self.model.encode_text(text)

    def encode_image(self, image):
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.model.encode_image(image)

class FLAVAModel(VisionLanguageModel):
    HF_MODEL = "facebook/flava-full"

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        if backbone is None:
            backbone = self.HF_MODEL
        self.model = FlavaModel.from_pretrained(backbone).to(device)
        self.processor = FlavaProcessor.from_pretrained(backbone)
        self.device = device

    def encode_text(self, text):
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding="max_length", max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        # FLAVA returns per-token features; keep the [CLS] token as the text embedding
        return self.model.get_text_features(**text_inputs)[:, 0, :]

    def encode_image(self, image):
        image_inputs = self.processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        # likewise, keep the [CLS] token of the per-patch image features
        return self.model.get_image_features(**image_inputs)[:, 0, :]

class ALIGNModel(VisionLanguageModel):
    HF_MODEL = "kakaobrain/align-base"

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        if backbone is None:
            backbone = self.HF_MODEL
        self.model = AlignModel.from_pretrained(backbone).to(device)
        self.processor = AlignProcessor.from_pretrained(backbone)
        self.device = device

    def encode_text(self, text):
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding="max_length", max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        return self.model.get_text_features(**text_inputs)

    def encode_image(self, image):
        image_inputs = self.processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        return self.model.get_image_features(**image_inputs)

class BLIPModel(VisionLanguageModel):
    HF_MODEL = "Salesforce/blip-itm-base-coco"

    def __init__(self, backbone: Optional[str] = None, device=c.DEVICE):
        if backbone is None:
            backbone = self.HF_MODEL
        self.model = BlipForImageTextRetrieval.from_pretrained(backbone).to(device)
        self.processor = BlipProcessor.from_pretrained(backbone)
        self.device = device

    def encode_text(self, text):
        text_inputs = self.processor(
            text=text, return_tensors="pt", padding="max_length", max_length=77
        )
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
        # run the unimodal text encoder and project its [CLS] embedding into the
        # shared image-text retrieval space
        question_embeds = self.model.text_encoder(**text_inputs)[0]
        return self.model.text_proj(question_embeds[:, 0, :])

    def encode_image(self, image):
        image_inputs = self.processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}
        # same for the vision encoder: project the [CLS] patch embedding
        image_embeds = self.model.vision_model(**image_inputs)[0]
        return self.model.vision_proj(image_embeds[:, 0, :])
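

# Hypothetical smoke test (added for illustration; not part of the original file).
# It instantiates ALIGNModel directly on CPU with its default HF checkpoint and
# checks that the text and image embeddings live in a comparable space. The blank
# test image and the caption below are placeholders.
if __name__ == "__main__":
    import torch
    from PIL import Image

    model = ALIGNModel(device="cpu")  # downloads "kakaobrain/align-base" on first use
    with torch.no_grad():
        text_features = model.encode_text(["a photo of a dog"])
        image_features = model.encode_image(Image.new("RGB", (224, 224), color="white"))

    similarity = torch.nn.functional.cosine_similarity(text_features, image_features)
    print(f"text features: {tuple(text_features.shape)}")
    print(f"image features: {tuple(image_features.shape)}")
    print(f"cosine similarity: {similarity.item():.3f}")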