|
--- |
|
license: mit |
|
tags: |
|
- vision |
|
- vision-language-model |
|
- contrastive-learning
|
--- |
|
|
|
**FLAIR Model** |
|
|
|
Authors: [Rui Xiao](https://www.eml-munich.de/people/rui-xiao), [Sanghwan Kim](https://kim-sanghwan.github.io/), [Mariana-Iuliana Georgescu](https://lilygeorgescu.github.io/), [Zeynep Akata](https://www.eml-munich.de/people/zeynep-akata), [Stephan Alaniz](https://www.eml-munich.de/people/stephan-alaniz) |
|
|
|
FLAIR was introduced in the paper [FLAIR: VLM with Fine-grained Language-informed Image Representations](https://arxiv.org/abs/2412.03561). Built on the ViT-B-16 model from [OpenCLIP](https://github.com/mlfoundations/open_clip), FLAIR adds text-conditioned attention pooling at the end of its vision transformer. Pre-trained on MLLM-recaptioned datasets from [DreamLIP](https://huggingface.co/datasets/qidouxiong619/dreamlip_long_captions), FLAIR achieves strong performance on tasks such as zero-shot image-text retrieval and zero-shot segmentation.
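
To give an intuition for the text-conditioned attention pooling described above, here is a minimal, self-contained sketch (illustrative only, not FLAIR's actual implementation; the shapes, head count, and use of `torch.nn.MultiheadAttention` are assumptions): the global text token serves as the query in a cross-attention layer over the local image tokens, so the pooled image representation depends on the caption it is paired with.

```python
import torch
import torch.nn as nn

# Illustrative shapes only: B image-text pairs, L local (patch) tokens, D embedding dim.
B, L, D = 2, 196, 512

text_query = torch.randn(B, 1, D)          # one global text token per pair
local_image_tokens = torch.randn(B, L, D)  # patch-level tokens from the vision transformer

# Cross-attention pooling: the text token attends over the local image tokens,
# producing an image representation conditioned on that specific caption.
attn_pool = nn.MultiheadAttention(embed_dim=D, num_heads=8, batch_first=True)
pooled, _ = attn_pool(query=text_query, key=local_image_tokens, value=local_image_tokens)
print(pooled.shape)  # torch.Size([2, 1, 512])
```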
|
|
|
**Usage** |
|
|
|
We provide detailed usage instructions in our [GitHub repo](https://github.com/ExplainableML/flair). Example usage:
|
|
|
```python |
|
import flair |
|
from PIL import Image |
|
import torch |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
print(f"Using device: {device}") |
|
|
|
pretrained = flair.download_weights_from_hf(model_repo='xiaorui638/flair', filename='flair-cc3m-recap.pt') |
|
model, _, preprocess = flair.create_model_and_transforms('ViT-B-16-FLAIR', pretrained=pretrained) |
|
|
|
model.to(device) |
|
model.eval() |
|
|
|
tokenizer = flair.get_tokenizer('ViT-B-16-FLAIR') |
|
|
|
image = preprocess(Image.open("../assets/puppy.jpg")).unsqueeze(0).to(device) |
|
|
|
text = tokenizer(["In the image, a small white puppy with black ears and eyes is the main subject", # ground-truth caption |
|
"The white door behind the puppy is closed, and there's a window on the right side of the door", # ground-truth caption |
|
"A red ladybug is surrounded by green glass beads", # non-ground-truth caption |
|
"Dominating the scene is a white desk, positioned against a white brick wall"]).to(device) # non-ground-truth caption |
|
|
|
with torch.no_grad(), torch.cuda.amp.autocast(): |
|
flair_logits = model.get_logits(image=image, text=text) |
|
clip_logits = model.get_logits_as_clip(image=image, text=text) |
|
|
|
print("logits get using flair's way:", flair_logits) # [4.4062, 6.9531, -20.5000, -18.1719] |
|
print("logits get using clip's way:", clip_logits) # [12.4609, 15.6797, -3.8535, -0.2281] |
|
``` |
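
The returned logits can be post-processed as usual; for example, a softmax over the caption axis ranks the candidate captions for the image. This is a usage sketch added for illustration (it assumes `get_logits` returns the `(image_logits, text_logits)` pair shown further below):

```python
# Post-processing sketch (not part of the repo's example): rank the candidate captions.
with torch.no_grad():
    image_logits, _ = model.get_logits(image=image, text=text)

probs = image_logits.softmax(dim=-1)  # (1, 4): relative matching probability per caption
ranking = probs[0].argsort(descending=True)
print("captions ranked best-to-worst:", ranking.tolist())
```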
|
|
|
As its primary way of producing logits, FLAIR uses text-conditioned attention pooling to pool the local image tokens into language-informed image representations. The logits are then obtained by multiplying these pooled image features with the text features:
|
|
|
```python |
|
def get_logits(self, image, text): |
|
""" |
|
    FLAIR's way to get the logits. Only used as a minimal example of computing the logits; not used in training or inference at this stage.
|
""" |
|
global_image_token, local_image_tokens = self.encode_image(image) |
|
global_text_token, _ = self.encode_text(text) |
|
global_text_token = self.text_post(global_text_token) # (B*K, D) |
|
global_image_token, local_image_tokens = self.image_post(global_image_token), self.image_post( |
|
local_image_tokens) # (B, D), (B, L, D) |
|
batch_size = global_image_token.shape[0] |
|
|
|
    # Broadcast the global text tokens to (B, B*K, D); this is too costly in large-scale training, so we downsample to (B, B+K-1, D) during training
|
global_text_token = global_text_token.unsqueeze(0).expand(batch_size, -1, -1) |
|
|
|
local_image_features = self.visual_proj(global_text_token, local_image_tokens, local_image_tokens) # (B, B*K, D) |
|
|
|
text_features, image_features = F.normalize(global_text_token, dim=-1), F.normalize(local_image_features, dim=-1) |
|
|
|
image_logits = self.logit_scale.exp() * torch.einsum('bij,bij->bi', image_features, text_features) # (B, B*K) |
|
image_logits += self.logit_bias |
|
|
|
text_logits = image_logits.T |
|
|
|
return image_logits, text_logits |
|
``` |
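
To make the shapes concrete, the following dummy-tensor sketch (with arbitrary, illustrative sizes) shows what the broadcast and the einsum above compute:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: B = 2 images, B*K = 4 captions overall, D = 8 dims.
B, BK, D = 2, 4, 8

text_features = F.normalize(torch.randn(B, BK, D), dim=-1)         # broadcast text tokens
local_image_features = F.normalize(torch.randn(B, BK, D), dim=-1)  # one pooled image feature per (image, caption) pair

# 'bij,bij->bi' is a per-pair dot product: entry [b, i] is the similarity between
# image b, pooled with caption i as the attention query, and caption i itself.
image_logits = torch.einsum('bij,bij->bi', local_image_features, text_features)
print(image_logits.shape)  # torch.Size([2, 4])
```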
|
|
|
Thanks to its global loss term, FLAIR also enforces matching between global-level image and text features. Therefore, just like the original CLIP, FLAIR can also produce logits from the global image and text features alone.
|
|
|
```python |
|
def get_logits_as_clip(self, image, text): |
|
""" |
|
    FLAIR can also generate global-to-global logits, as the original CLIP does.
|
""" |
|
global_image_token, _ = self.encode_image(image) |
|
global_text_token, _ = self.encode_text(text) |
|
|
|
|
|
global_image_token = self.image_post(global_image_token) # (B, D) |
|
global_text_token = self.text_post(global_text_token) # (B*K, D) |
|
|
|
image_features, text_features = F.normalize(global_image_token, dim=-1), F.normalize(global_text_token, dim=-1) |
|
|
|
image_logits = self.logit_scale.exp() * image_features @ text_features.t() |
|
text_logits = image_logits.T |
|
|
|
return image_logits, text_logits |
|
``` |
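
Because the global features behave like CLIP's, they can be used directly for standard zero-shot classification. Here is a hedged usage sketch (the class names and prompt template are illustrative assumptions; it reuses `model`, `tokenizer`, `device`, and `image` from the usage example above):

```python
# Zero-shot classification sketch using the CLIP-style global features.
# The label set and prompt template below are illustrative only.
class_names = ["puppy", "ladybug", "desk"]
prompts = tokenizer([f"a photo of a {name}" for name in class_names]).to(device)

with torch.no_grad():
    image_logits, _ = model.get_logits_as_clip(image=image, text=prompts)
    probs = image_logits.softmax(dim=-1)  # (1, num_classes)

for name, p in zip(class_names, probs[0].tolist()):
    print(f"{name}: {p:.3f}")
```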
|
|
|
**Citation** |
|
|
|
If you find our work useful, please consider citing: |
|
|
|
```bibtex |
|
@article{xiao2024flair, |
|
title={FLAIR: VLM with Fine-grained Language-informed Image Representations}, |
|
author={Xiao, Rui and Kim, Sanghwan and Georgescu, Mariana-Iuliana and Akata, Zeynep and Alaniz, Stephan}, |
|
journal={arXiv preprint arXiv:2412.03561}, |
|
year={2024} |
|
} |
|
``` |
|
|
|
|