---
license: apache-2.0
---

# OmniFusion

**OmniFusion** is an advanced multimodal AI model designed to extend the capabilities of traditional language processing systems by integrating additional data modalities such as images, and potentially audio, 3D and video content.
### Architecture

<p align="left">
<img src="https://raw.githubusercontent.com/AIRI-Institute/OmniFusion/main/content/architecture.png" width="100%">
</p>
The open-source version of OmniFusion uses Mistral-7B as its core. Initially focusing on images, we selected CLIP-ViT-L as the visual encoder for its efficient information transfer capabilities. The most important component of OmniFusion is its adapter, the mechanism that allows the language model to interpret and incorporate information from other modalities. The adapter is a single-layer, four-headed transformer, which has shown superior performance compared to simpler linear layers or MLP structures.

This adapter takes embeddings from the visual encoder (excluding the CLS token) and maps them into textual embeddings compatible with the language model.

To further enhance the model's multimodal capabilities, we employ trainable special tokens that mark the beginning and end of visual data within the text sequence.
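To make the data flow concrete, here is a minimal PyTorch sketch of such an adapter together with trainable boundary embeddings. It is an illustration under stated assumptions rather than the released implementation: the class name `VisualToTextAdapter`, the use of `nn.TransformerEncoderLayer` for the single four-headed layer, and the dimensions (1024 for CLIP-ViT-L patch features, 4096 for Mistral-7B embeddings) are ours.

```python
import torch
import torch.nn as nn

CLIP_DIM, LLM_DIM = 1024, 4096  # assumed hidden sizes for CLIP-ViT-L and Mistral-7B


class VisualToTextAdapter(nn.Module):
    """Single-layer, four-headed transformer that maps CLIP patch embeddings
    (with the CLS token already removed) into the language model's embedding space."""

    def __init__(self, vision_dim=CLIP_DIM, text_dim=LLM_DIM, num_heads=4):
        super().__init__()
        self.block = nn.TransformerEncoderLayer(d_model=vision_dim, nhead=num_heads, batch_first=True)
        self.proj = nn.Linear(vision_dim, text_dim)

    def forward(self, patch_embeddings):                  # (B, num_patches, vision_dim)
        return self.proj(self.block(patch_embeddings))    # (B, num_patches, text_dim)


# Trainable special embeddings that mark where visual data starts and ends.
special_tokens = nn.ParameterDict({
    "SOI": nn.Parameter(torch.randn(LLM_DIM) * 0.02),
    "EOI": nn.Parameter(torch.randn(LLM_DIM) * 0.02),
})

adapter = VisualToTextAdapter()
visual_tokens = adapter(torch.randn(1, 576, CLIP_DIM))        # 576 patches from 336px CLIP-ViT-L/14
sequence = torch.cat([special_tokens["SOI"][None, None, :],   # <SOI>
                      visual_tokens,                          # projected image tokens
                      special_tokens["EOI"][None, None, :]],  # <EOI>
                     dim=1)                                   # (1, 578, LLM_DIM)
```

The inference code in the How to Use section splices the projected visual tokens between the `SOI` and `EOI` embeddings in exactly this way.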
### Training Process

Training consists of two stages:

1. Pre-training the adapter on image-captioning tasks (LAION, CC-4M).
2. Once the adapter has learned to map ViT's visual embeddings to the language model's textual space, we unfreeze Mistral to improve its understanding of dialog formats and complex queries (a minimal sketch of both stages follows the figure below).

<p align="left">
<img src="https://raw.githubusercontent.com/AIRI-Institute/OmniFusion/main/content/datasets.png" width="70%">
</p>
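For concreteness, here is a hypothetical PyTorch sketch of how the two stages could be organized. The batch layout (`patch_embeddings`, `text_ids`), the loaders, the learning rates, and the step counts are illustrative assumptions, not the released training code.

```python
import torch


def run_stage(adapter, language_model, loader, train_llm, lr, steps):
    """One training stage over a hypothetical loader that yields CLIP patch
    embeddings and target token ids. Stage 1: train_llm=False, so only the
    adapter is updated (the special-token embeddings would be added to
    `params` the same way). Stage 2: train_llm=True, so Mistral is unfrozen
    and tuned together with the adapter."""
    language_model.requires_grad_(train_llm)
    params = list(adapter.parameters())
    if train_llm:
        params += list(language_model.parameters())
    optimizer = torch.optim.AdamW(params, lr=lr)
    for _, batch in zip(range(steps), loader):
        text_ids = batch["text_ids"]                                # (B, T)
        text = language_model.model.embed_tokens(text_ids)          # (B, T, hidden)
        visual = adapter(batch["patch_embeddings"]).to(text.dtype)  # (B, V, hidden)
        ignore = torch.full(visual.shape[:2], -100,                 # no loss on the visual prefix
                            dtype=torch.long, device=text_ids.device)
        loss = language_model(inputs_embeds=torch.cat([visual, text], dim=1),
                              labels=torch.cat([ignore, text_ids], dim=1)).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


# Hypothetical schedule: image-captioning batches first, then dialog batches.
# run_stage(adapter, model, caption_loader, train_llm=False, lr=1e-4, steps=10_000)
# run_stage(adapter, model, dialog_loader,  train_llm=True,  lr=1e-5, steps=10_000)
```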
### Results

OmniFusion was benchmarked against the latest multimodal SOTA models. It excelled on generative metrics and on classification benchmarks such as VisualDialog.

<p align="left">
<img src="https://raw.githubusercontent.com/AIRI-Institute/OmniFusion/main/content/radar.png" width="70%">
</p>

Model performance on the Visual Dialog benchmark:

| Model      | NDCG  | MRR   | Recall@1 | Recall@5 | Recall@10 |
| ---------- | ----- | ----- | -------- | -------- | --------- |
| OmniFusion | 25.91 | 10.78 | 4.74     | 13.80    | 20.53     |
| LLaVA-13B  | 24.74 | 8.91  | 2.98     | 10.80    | 18.02     |
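As a quick reference for reading the table, MRR and Recall@k can be computed from the rank that the ground-truth answer receives among the candidate answers; the helper below is an illustrative sketch (not the official evaluation script), with values scaled to the 0-100 range used above. NDCG additionally relies on the benchmark's dense relevance annotations and is omitted here.

```python
import numpy as np

def mrr_and_recall(gt_ranks, ks=(1, 5, 10)):
    """Compute MRR and Recall@k from the 1-based rank assigned to the
    ground-truth answer for each question (illustrative helper only)."""
    ranks = np.asarray(gt_ranks, dtype=float)
    metrics = {"MRR": 100.0 * float(np.mean(1.0 / ranks))}
    for k in ks:
        metrics[f"Recall@{k}"] = 100.0 * float(np.mean(ranks <= k))
    return metrics

# e.g. mrr_and_recall([1, 3, 12, 7]) -> MRR ~= 39.0, Recall@1 25.0, Recall@5 50.0, Recall@10 75.0
```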
### Examples

<p align="left">
<img src="https://raw.githubusercontent.com/AIRI-Institute/OmniFusion/main/content/examples.png" width="100%">
</p>
### How to Use

```python
import torch
import torch.nn as nn
from urllib.request import urlopen
from PIL import Image
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig)

DEVICE = "cuda:0"
PROMPT = "This is a dialog with AI assistant.\n"

# Language backbone (Mistral-based) and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("OmniMistral-tokenizer", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("OmniMistral-model", torch_dtype=torch.bfloat16, device_map=DEVICE)

# Pre-trained visual-to-text adapter and the trainable special embeddings (SOI, EOI, USER, BOT).
projection = torch.load("projection", map_location=DEVICE)
special_embs = torch.load("special_embeddings.pt", map_location=DEVICE)


class CLIPVisionTower(nn.Module):
    """Wrapper around the CLIP vision encoder that exposes the patch features
    used by the adapter (hidden states from the selected layer, CLS dropped)."""

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        self.vision_tower.requires_grad_(False)  # the visual encoder stays frozen

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]  # drop the CLS token
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size


class ClipTowerCfg:
    def __init__(self):
        self.mm_vision_select_feature = 'patch'
        self.mm_vision_select_layer = -2  # take features from the penultimate layer


clip = CLIPVisionTower("openai/clip-vit-large-patch14-336", ClipTowerCfg())
clip.load_model()
clip = clip.to(device=DEVICE, dtype=torch.bfloat16)


def gen_answer(model, tokenizer, clip, projection, query, special_embs, image=None):
    # Tokens the model is not allowed to produce (newlines, EOS markers, colons).
    bad_words_ids = tokenizer(["\n", "</s>", ":"], add_special_tokens=False).input_ids + [[13]]
    gen_params = {
        "do_sample": False,
        "max_new_tokens": 50,
        "early_stopping": True,
        "num_beams": 3,
        "repetition_penalty": 1.0,
        "remove_invalid_values": True,
        "eos_token_id": 2,
        "pad_token_id": 2,
        "forced_eos_token_id": 2,
        "use_cache": True,
        "no_repeat_ngram_size": 4,
        "bad_words_ids": bad_words_ids,
        "num_return_sequences": 1,
    }
    with torch.no_grad():
        # Encode the image with CLIP and project its patch features into the
        # language model's embedding space.
        image_features = clip.image_processor(image, return_tensors='pt')
        image_embedding = clip(image_features['pixel_values']).to(device=DEVICE, dtype=torch.bfloat16)

        projected_vision_embeddings = projection(image_embedding).to(device=DEVICE, dtype=torch.bfloat16)
        prompt_ids = tokenizer.encode(f"{PROMPT}", add_special_tokens=False, return_tensors="pt").to(device=DEVICE)
        question_ids = tokenizer.encode(query, add_special_tokens=False, return_tensors="pt").to(device=DEVICE)

        prompt_embeddings = model.model.embed_tokens(prompt_ids).to(torch.bfloat16)
        question_embeddings = model.model.embed_tokens(question_ids).to(torch.bfloat16)

        # Build the input sequence: prompt, <SOI>, visual tokens, <EOI>, USER, question, BOT.
        embeddings = torch.cat(
            [
                prompt_embeddings,
                special_embs['SOI'][None, None, ...],
                projected_vision_embeddings,
                special_embs['EOI'][None, None, ...],
                special_embs['USER'][None, None, ...],
                question_embeddings,
                special_embs['BOT'][None, None, ...]
            ],
            dim=1,
        ).to(dtype=torch.bfloat16, device=DEVICE)
        out = model.generate(inputs_embeds=embeddings, **gen_params)
    out = out[:, 1:]
    generated_texts = tokenizer.batch_decode(out)[0]
    return generated_texts

img_url = "https://i.pinimg.com/originals/32/c7/81/32c78115cb47fd4825e6907a83b7afff.jpg"
question = "who is the author?"
img = Image.open(urlopen(img_url))

answer = gen_answer(
    model,
    tokenizer,
    clip,
    projection,
    query=question,
    special_embs=special_embs,
    image=img
)

img.show()
print(question)
print(answer)
```

### Future Plans

Work is underway on a version that understands Russian, uses ImageBind encoders, and accepts more modalities (sound, 3D, video). Stay tuned for updates on GitHub!
### Authors

The model was developed by the FusionBrain scientific group of the AIRI Institute in collaboration with scientists from Sber AI.

Main contributors:

+ Anton Razzhigaev [Blog](https://t.me/abstractDL)
+ Elizaveta Goncharova
+ Matvey Mikhalchuk
+ Maxim Kurkin
+ Irina Abdullaeva
+ Denis Dimitrov [Blog](https://t.me/dendi_math_ai)
+ Andrey Kuznetsov [Blog](https://t.me/complete_ai)