Introduce a custom Sentence Transformer module for smooth multi-modality (#1)
- Introduce custom Sentence Transformer module (9862f98edfbc3c5f1a56b1a00ef87ad1b9af3b76)
- Use self.max_seq_length to inform the maximum tokenize length (c0c6d64415a1e25865af6dbb702ac5ba5a1645e4)
- Merge branch 'main' into pr/1, resolve merge conflict (008f2574a5989c788b0fa395d4342a0e1c40f250)
- README.md +11 -74
- custom_st.py +87 -0
- modules.json +12 -6
- sentence_bert_config.json +4 -1
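
The second commit ties truncation to the module's configured sequence length instead of a hard-coded 1024. As a rough sketch of the effect (assuming `model` has been loaded as in the Usage section below), lowering `model.max_seq_length` now directly bounds the tokenized batch:

```python
# Sketch only: assumes `model` is loaded as in the README Usage section below.
# SentenceTransformer.max_seq_length proxies to the first module, and the custom
# tokenize() uses self.max_seq_length as its truncation limit.
model.max_seq_length = 512
batch = model.tokenize(["a very long document " * 2000])
print(batch["input_ids"].shape[1])  # expected to be <= 512 after truncation
```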
README.md
CHANGED
The `## Usage` example is simplified: the image handling that used to be monkey-patched onto the model now lives in the new `custom_st.py` module, so the README only needs the plain Sentence Transformers API. The old version opened with a long prelude of imports and two stand-alone helpers, `jasper_vl_forward` and `jasper_vl_tokenize`, that were later patched onto the model:

```python
import functools
import PIL
import numpy as np
import torch
from typing import Dict
from io import BytesIO
from transformers import SiglipImageProcessor
from sentence_transformers import SentenceTransformer


def jasper_vl_forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
    trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
    if "pixel_values" in features:
        trans_features["pixel_values"] = features["pixel_values"]
    sentence_embedding = self.auto_model(**trans_features, **kwargs)["sentence_embedding"]
    features.update({"sentence_embedding": sentence_embedding})
    return features


def jasper_vl_tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
    img_start_token = "<|jasper_img_start|>"
    img_token = "<|jasper_img_token|>"
    img_end_token = "<|jasper_img_end|>"
    num_img_tokens = 300

    def process_text_item(item):
        if isinstance(item, str):
            return item, []
        text, images = "", []
        for sub_item in item:
            if sub_item["type"] == "text":
                text += sub_item["content"]
            elif sub_item["type"] == "image_bytes":
                text += img_start_token + img_token * num_img_tokens + img_end_token
                images.append(PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB"))
            elif sub_item["type"] == "image_path":
                text += img_start_token + img_token * num_img_tokens + img_end_token
                images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
            else:
                raise ValueError(f"unknown data type {sub_item['type']}")
        return text, images

    all_texts, all_images = [], []
    for item in texts:
        text, images = process_text_item(item)
        all_texts.append(text)
        all_images.extend(images)
    ipt = self.tokenizer(all_texts, padding="longest", truncation=True, max_length=1024, return_tensors="pt")
    if all_images:
        ipt["pixel_values"] = self.processor(images=all_images, return_tensors="pt")["pixel_values"]
        # For demonstration, the external variable `use_gpu` is referenced here; adapt this to your environment.
        if use_gpu:
            ipt["pixel_values"] = ipt["pixel_values"].bfloat16()
    return ipt
```
It also defined its own `prompt_dict`; the new example instead selects the query prompt by name via `prompt_name="s2p_query"`, so the dictionary is removed:

```python
prompt_dict = {
    "s2p_query": "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: ",
    "s2s_query": "Instruct: Retrieve semantically similar text.\nQuery: "
}
```
Inside the `__main__` block, the hard-coded `device="cpu"` argument (now `device="cpu" if not use_gpu else "cuda"`), the `tokenizer_kwargs={"padding_side": "right"}` argument, and the manual wiring of the image processor, tokenizer, and forward pass are all dropped:

```python
model.processor = SiglipImageProcessor.from_pretrained(model_name)
model.tokenize = functools.partial(jasper_vl_tokenize, model)
model._first_module().forward = functools.partial(jasper_vl_forward, model._first_module())
```
Finally, the `q_vecs` / `doc_vecs` encoding calls are updated and the NumPy-based similarity computation is replaced by the built-in `model.similarity` API:

```python
print(np.matmul(q_vecs, doc_vecs.T))
# the output is:
# [[0.777521 0.75944513 0.24291277 0.2187205]
#  [0.32261407 0.30536035 0.74208796 0.5484469]]
```
After the change, the `## Usage` section reduces to the standard Sentence Transformers workflow (unchanged context between the edited hunks is elided below):

```python
import torch
from sentence_transformers import SentenceTransformer

DOC1 = """
Blue light is scattered in all directions by the tiny molecules of air in Earth's atmosphere.
Blue is scattered more than other colors because it travels as shorter, smaller waves. This is why we see a blue sky most of the time.
...
"""
DOC2 = """
...
Color palette: Limit your color palette to a main color and one or two additional colors.
60-30-10 rule: Use a primary color 60% of the time, a secondary color 30% of the time, and an accent color 10% of the time
"""

if __name__ == "__main__":
    # load model
    use_gpu = False
    # ...
    model = SentenceTransformer(
        model_name,
        trust_remote_code=True,
        device="cpu" if not use_gpu else "cuda",
        model_kwargs={
            "torch_dtype": torch.bfloat16 if use_gpu else torch.float32,
            "attn_implementation": "sdpa"
        },
        ## 1024 is recommended
        # set is_text_encoder to 'True' if you do not encode images
        config_kwargs={"is_text_encoder": False, "vector_dim": 1024},
    )
    # We can reduce the max_seq_length from the default of 2048 for faster encoding
    model.max_seq_length = 1024

    # data
    q_list = [
        "Why the sky is blue?",
        # ...
    ]
    doc_list = [
        DOC1,
        [{"type": "image_path", "content": "./assets/img1.png"}, {"type": "text", "content": "Hope this image helps!"}],
        DOC2,
        [{"type": "image_path", "content": "./assets/img2.png"}],
    ]
    q_vecs = model.encode(q_list, prompt_name="s2p_query")
    doc_vecs = model.encode(doc_list)

    # calculate similarity
    similarities = model.similarity(q_vecs, doc_vecs)
    print(similarities)
    # the output is:
    # tensor([[0.7775, 0.7594, 0.2429, 0.2187],
    #         [0.3226, 0.3054, 0.7421, 0.5484]])
```
|
custom_st.py
ADDED
The new module subclasses `sentence_transformers.models.Transformer` and moves the image handling from the README into the model pipeline:

```python
from io import BytesIO
from typing import Any, Dict, Optional

import PIL
import torch
from transformers import SiglipImageProcessor
from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        self.processor = SiglipImageProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
        }
        if "pixel_values" in features:
            trans_features["pixel_values"] = features["pixel_values"].to(
                self.auto_model.dtype
            )

        sentence_embedding = self.auto_model(**trans_features, **kwargs)[
            "sentence_embedding"
        ]
        features.update({"sentence_embedding": sentence_embedding})
        return features

    def tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
        img_start_token = "<|jasper_img_start|>"
        img_token = "<|jasper_img_token|>"
        img_end_token = "<|jasper_img_end|>"
        num_img_tokens = 300

        def process_text_item(item):
            if isinstance(item, str):
                return item, []
            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] == "image_bytes":
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(
                        PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    )
                elif sub_item["type"] == "image_path":
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
                else:
                    raise ValueError(f"unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)
        ipt = self.tokenizer(
            all_texts,
            padding="longest",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        if all_images:
            ipt["pixel_values"] = self.processor(
                images=all_images, return_tensors="pt"
            )["pixel_values"]
        return ipt
```
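
Compared with the snippet that used to live in the README, `tokenize` now truncates to `self.max_seq_length` instead of a hard-coded 1024, and `forward` casts `pixel_values` to the backbone's dtype, so callers no longer need the manual `.bfloat16()` step. The way each image is expanded into placeholder tokens is unchanged; as a plain-string illustration (no model required, the token strings are taken from the module above):

```python
# Illustration of the placeholder expansion performed by tokenize() above:
# each image contributes an img-start token, 300 img tokens, and an img-end token
# to the text stream, while the PIL image itself goes to SiglipImageProcessor.
img_start = "<|jasper_img_start|>"
img = "<|jasper_img_token|>"
img_end = "<|jasper_img_end|>"

doc = [
    {"type": "text", "content": "Hope this image helps!"},
    {"type": "image_path", "content": "./assets/img1.png"},
]
flattened = ""
for part in doc:
    if part["type"] == "text":
        flattened += part["content"]
    else:  # image_path / image_bytes
        flattened += img_start + img * 300 + img_end

assert flattened.count(img) == 300
print(flattened[:80])
```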
modules.json
CHANGED
The module list now names two modules: the custom transformer, loaded from `custom_st.py` at the repository root, followed by an L2 normalization step:

```json
[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "custom_st.MultiModalTransformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]
```
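
Because module 0 has an empty `path` and a `custom_st.` type prefix, Sentence Transformers imports the class from `custom_st.py` in the repository, which is why the README passes `trust_remote_code=True`. A rough sketch of what the resulting pipeline looks like at runtime (assuming `model` is loaded as in the README Usage section):

```python
# Sketch: the module order defined by modules.json, assuming `model` is loaded
# with trust_remote_code=True as in the README Usage section.
first, last = model._first_module(), model._last_module()
print(type(first).__name__)  # MultiModalTransformer (from custom_st.py)
print(type(last).__name__)   # Normalize

emb = model.encode(["Why the sky is blue?"])
print(float((emb ** 2).sum() ** 0.5))  # ~1.0, since the Normalize module L2-normalizes
```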
sentence_bert_config.json
CHANGED
The tokenizer's right-side padding, previously supplied at load time via `tokenizer_kwargs` in the README, is now part of the saved module configuration:

```json
{
  "max_seq_length": 2048,
  "do_lower_case": false,
  "tokenizer_args": {
    "padding_side": "right"
  }
}
```
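
These keys are read when the module is loaded and forwarded to the `MultiModalTransformer` constructor above, which is why the README no longer sets `tokenizer_kwargs` itself. A quick sanity check of the resulting default (assuming `model` is loaded as in the README Usage section):

```python
# Sketch: max_seq_length defaults to 2048 from sentence_bert_config.json;
# the README then lowers it to 1024 for faster encoding.
print(model.max_seq_length)  # 2048 right after loading, before any override
```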