visheratin committed
Commit
f4f9035
1 Parent(s): 582edb1

Upload folder using huggingface_hub

Files changed (3)
  1. config.json +16 -0
  2. nllb_mrl.py +155 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "auto_map": {
+     "AutoConfig": "nllb_mrl.MatryoshkaNllbClipConfig",
+     "AutoModel": "nllb_mrl.MatryoshkaNllbClip"
+   },
+   "clip_model_name": "nllb-clip-large-siglip",
+   "clip_model_version": "v1",
+   "mrl_resolutions": [
+     32,
+     64,
+     128,
+     256,
+     512
+   ],
+   "target_resolution": 1152
+ }
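
The `auto_map` block registers the custom classes with the `transformers` Auto* loaders, so the model can be pulled in with `trust_remote_code=True`. `target_resolution` is the full embedding width (1152) produced by the SigLIP-based encoder, and each entry in `mrl_resolutions` gets its own Matryoshka projection head in `nllb_mrl.py` below. A minimal loading sketch, with a placeholder repo id:

from transformers import AutoModel

# "user/repo" stands in for this model repository. trust_remote_code lets
# transformers import nllb_mrl.py from the repo; forwarding the extra
# `device` kwarg to MatryoshkaNllbClip.__init__ is an assumption based on
# the constructor signature below.
model = AutoModel.from_pretrained("user/repo", trust_remote_code=True, device="cpu")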
nllb_mrl.py ADDED
@@ -0,0 +1,155 @@
+ from typing import List, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from open_clip import create_model, get_tokenizer
+ from open_clip.transform import PreprocessCfg, image_transform_v2
+ from PIL import Image
+ from transformers import PretrainedConfig, PreTrainedModel
+
+
+ class MatryoshkaNllbClipConfig(PretrainedConfig):
+     def __init__(
+         self,
+         clip_model_name: str = "",
+         target_resolution: int = -1,
+         mrl_resolutions: List[int] = [],
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.clip_model_name = clip_model_name
+         self.target_resolution = target_resolution
+         self.mrl_resolutions = mrl_resolutions
+
+
+ class MatryoshkaLayer(nn.Module):
+     def __init__(self, resolutions: List[int], target_resolution: int = 768):
+         super().__init__()
+         self.resolutions = resolutions
+         self.layers = nn.ModuleDict()
+         for resolution in resolutions:
+             self.layers[str(resolution)] = nn.Linear(target_resolution, resolution)
+
+     def forward(self, x, resolution: Union[int, None] = None):
+         if resolution is not None:
+             if resolution not in self.resolutions:
+                 raise ValueError(f"Resolution {resolution} not in {self.resolutions}")
+             return self.layers[str(resolution)](x)
+         outputs = []
+         for resolution in self.resolutions:
+             outputs.append(self.layers[str(resolution)](x))
+         return outputs
+
+
+ class MatryoshkaNllbClip(PreTrainedModel):
+     config_class = MatryoshkaNllbClipConfig
+
+     def __init__(self, config: MatryoshkaNllbClipConfig, device):
+         super().__init__(config)
+         if isinstance(device, str):
+             device = torch.device(device)
+         self.config = config
+         self.model = create_model(
+             config.clip_model_name, output_dict=True
+         )
+         pp_cfg = PreprocessCfg(**self.model.visual.preprocess_cfg)
+         self.transform = image_transform_v2(
+             pp_cfg,
+             is_train=False,
+         )
+         self._device = device
+         self.model.to(device)
+         self.matryoshka_layer = MatryoshkaLayer(
+             config.mrl_resolutions, config.target_resolution
+         )
+         self.matryoshka_layer.to(device)
+         self.tokenizer = get_tokenizer(config.clip_model_name)
+
+     def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
+         image_inputs = image_inputs.to(self._device)
+         input_ids = input_ids.to(self._device)
+         outputs = self.model(
+             image=image_inputs,
+             text=input_ids,
+         )
+         mrl_image_features = None
+         mrl_text_features = None
+         if resolution is not None:
+             mrl_image_features = self.matryoshka_layer.forward(
+                 outputs["image_features"], resolution
+             )
+             mrl_text_features = self.matryoshka_layer.forward(
+                 outputs["text_features"], resolution
+             )
+         return {
+             "image_features": outputs["image_features"],
+             "text_features": outputs["text_features"],
+             "mrl_image_features": mrl_image_features,
+             "mrl_text_features": mrl_text_features,
+             "logit_scale": outputs["logit_scale"],
+             "logit_bias": outputs["logit_bias"],
+         }
+
+     def encode_images(
+         self,
+         images: List[Image.Image],
+         normalize=False,
+         resolution: Union[int, None] = None,
+     ):
+         image_inputs = [self.transform(image) for image in images]
+         image_inputs = torch.stack(image_inputs, dim=0).to(self._device)
+         with torch.inference_mode():
+             features = self.model.visual(image_inputs)
+             if resolution is not None:
+                 if resolution not in self.matryoshka_layer.resolutions:
+                     raise ValueError(
+                         f"Resolution {resolution} not in {self.matryoshka_layer.resolutions}"
+                     )
+                 features = self.matryoshka_layer.layers[str(resolution)](features)
+         return F.normalize(features, dim=-1) if normalize else features
+
+     def encode_texts(
+         self,
+         texts: List[str],
+         langs: Union[List[str], None] = None,
+         normalize=False,
+         resolution: Union[int, None] = None,
+     ):
+         if langs is None:
+             langs = ["eng_Latn"] * len(texts)
+         texts = [f"{lang}{text}" for lang, text in zip(langs, texts)]
+         input_ids = self.tokenizer.tokenizer.batch_encode_plus(
+             texts, return_tensors="pt", padding="longest", add_special_tokens=False
+         )["input_ids"].to(self._device)
+         with torch.inference_mode():
+             features = self.model.text(input_ids)
+             if resolution is not None:
+                 if resolution not in self.matryoshka_layer.resolutions:
+                     raise ValueError(
+                         f"Resolution {resolution} not in {self.matryoshka_layer.resolutions}"
+                     )
+                 features = self.matryoshka_layer.layers[str(resolution)](features)
+         return F.normalize(features, dim=-1) if normalize else features
+
+     def get_logits(
+         self,
+         images: List[Image.Image],
+         texts: List[str],
+         langs: Union[List[str], None] = None,
+         resolution: Union[int, None] = None,
+     ):
+         image_features = self.encode_images(
+             images, normalize=True, resolution=resolution
+         )
+         text_features = self.encode_texts(
+             texts, langs, normalize=True, resolution=resolution
+         )
+         with torch.inference_mode():
+             image_logits = (
+                 self.model.logit_scale.exp() * image_features @ text_features.T
+             )
+             if self.model.logit_bias is not None:
+                 image_logits += self.model.logit_bias
+             text_logits = image_logits.T
+         return image_logits, text_logits
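
For reference (not part of the commit), a minimal usage sketch. It assumes this folder is importable, that the checkpoint in pytorch_model.bin matches the module names above, and that open_clip ships the nllb-clip-large-siglip architecture; the image path and captions are placeholders:

import torch
from PIL import Image

from nllb_mrl import MatryoshkaNllbClip, MatryoshkaNllbClipConfig

device = "cuda" if torch.cuda.is_available() else "cpu"

# Values mirror config.json above.
config = MatryoshkaNllbClipConfig(
    clip_model_name="nllb-clip-large-siglip",
    target_resolution=1152,
    mrl_resolutions=[32, 64, 128, 256, 512],
)
model = MatryoshkaNllbClip(config, device)
# Load the trained weights uploaded in this commit (hypothetical local path).
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))

images = [Image.open("cat.jpg")]  # placeholder image
texts = ["a photo of a cat", "una foto de un perro"]
langs = ["eng_Latn", "spa_Latn"]  # NLLB language tags

# Full 1152-d embeddings vs. truncated 256-d Matryoshka embeddings.
full_emb = model.encode_images(images, normalize=True)                   # (1, 1152)
small_emb = model.encode_images(images, normalize=True, resolution=256)  # (1, 256)

# SigLIP-style scores: pairwise sigmoid rather than a softmax over texts.
image_logits, _ = model.get_logits(images, texts, langs=langs, resolution=256)
print(torch.sigmoid(image_logits))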
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91996767126a42dfaa685555d35a64aa019688c54e07bfa429df7977cc7cd8d1
+ size 4786871742