visheratin committed on
Commit 955949b
1 Parent(s): 4e5b95c

Upload nllb_mrl.py

Files changed (1)
  1. nllb_mrl.py +150 -0
nllb_mrl.py ADDED
@@ -0,0 +1,150 @@
from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from open_clip import create_model_and_transforms, get_tokenizer
from PIL import Image
from transformers import PretrainedConfig, PreTrainedModel

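# Config holding the open_clip model spec plus the Matryoshka settings:
# `target_resolution` is the full embedding size produced by the CLIP towers,
# and `mrl_resolutions` lists the smaller sizes that get their own linear head.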
class MatryoshkaNllbClipConfig(PretrainedConfig):
    def __init__(
        self,
        clip_model_name: str = "",
        clip_model_version: str = "",
        target_resolution: int = -1,
        mrl_resolutions: Union[List[int], None] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.clip_model_version = clip_model_version
        self.target_resolution = target_resolution
        self.mrl_resolutions = mrl_resolutions if mrl_resolutions is not None else []

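# Projects the full-size embedding down to each requested Matryoshka
# resolution via a dedicated nn.Linear head (one per resolution).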
class MatryoshkaLayer(nn.Module):
    def __init__(self, resolutions: List[int], target_resolution: int = 768):
        super().__init__()
        self.resolutions = resolutions
        self.layers = nn.ModuleDict()
        for resolution in resolutions:
            self.layers[str(resolution)] = nn.Linear(target_resolution, resolution)

    def forward(self, x, resolution: Union[int, None] = None):
        if resolution is not None:
            if resolution not in self.resolutions:
                raise ValueError(f"Resolution {resolution} not in {self.resolutions}")
            return self.layers[str(resolution)](x)
        outputs = []
        for resolution in self.resolutions:
            outputs.append(self.layers[str(resolution)](x))
        return outputs

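# NLLB-CLIP wrapped as a transformers PreTrainedModel: an open_clip
# image/text tower pair plus the Matryoshka projection heads.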
class MatryoshkaNllbClip(PreTrainedModel):
    def __init__(self, config: MatryoshkaNllbClipConfig, device):
        super().__init__(config)
        if isinstance(device, str):
            device = torch.device(device)
        self.config = config
        self.model, _, self.transform = create_model_and_transforms(
            config.clip_model_name, config.clip_model_version, output_dict=True
        )
        self._device = device
        self.model.to(device)
        self.matryoshka_layer = MatryoshkaLayer(
            config.mrl_resolutions, config.target_resolution
        )
        self.matryoshka_layer.to(device)
        self.tokenizer = get_tokenizer(config.clip_model_name)

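    # Always returns the full-size features; MRL features are filled in only
    # when a resolution is requested. logit_scale/logit_bias come from the
    # underlying (SigLIP-style) open_clip model.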
    def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
        image_inputs = image_inputs.to(self._device)
        input_ids = input_ids.to(self._device)
        outputs = self.model(
            image=image_inputs,
            text=input_ids,
        )
        mrl_image_features = None
        mrl_text_features = None
        if resolution is not None:
            mrl_image_features = self.matryoshka_layer(
                outputs["image_features"], resolution
            )
            mrl_text_features = self.matryoshka_layer(
                outputs["text_features"], resolution
            )
        return {
            "image_features": outputs["image_features"],
            "text_features": outputs["text_features"],
            "mrl_image_features": mrl_image_features,
            "mrl_text_features": mrl_text_features,
            "logit_scale": outputs["logit_scale"],
            "logit_bias": outputs["logit_bias"],
        }

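    # Preprocesses PIL images, runs the visual tower, and optionally projects
    # the features to one of the configured MRL resolutions.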
    def encode_images(
        self,
        images: List[Image.Image],
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        image_inputs = [self.transform(image) for image in images]
        image_inputs = torch.stack(image_inputs, dim=0).to(self._device)
        with torch.inference_mode():
            features = self.model.visual(image_inputs)
            if resolution is not None:
                if resolution not in self.matryoshka_layer.resolutions:
                    raise ValueError(
                        f"Resolution {resolution} not in {self.matryoshka_layer.resolutions}"
                    )
                features = self.matryoshka_layer.layers[str(resolution)](features)
            return F.normalize(features, dim=-1) if normalize else features

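    # NLLB tokenization expects the language tag (e.g. "eng_Latn") prepended
    # to each text; defaults to English when no languages are given.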
    def encode_texts(
        self,
        texts: List[str],
        langs: Union[List[str], None] = None,
        normalize=False,
        resolution: Union[int, None] = None,
    ):
        if langs is None:
            langs = ["eng_Latn"] * len(texts)
        texts = [f"{lang}{text}" for lang, text in zip(langs, texts)]
        input_ids = self.tokenizer.tokenizer.batch_encode_plus(
            texts, return_tensors="pt", padding="longest", add_special_tokens=False
        )["input_ids"].to(self._device)
        with torch.inference_mode():
            features = self.model.text(input_ids)
            if resolution is not None:
                if resolution not in self.matryoshka_layer.resolutions:
                    raise ValueError(
                        f"Resolution {resolution} not in {self.matryoshka_layer.resolutions}"
                    )
                features = self.matryoshka_layer.layers[str(resolution)](features)
            return F.normalize(features, dim=-1) if normalize else features

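    # Image-text similarity logits, SigLIP-style: scaled dot products of the
    # normalized embeddings plus an optional learned bias.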
    def get_logits(
        self,
        images: List[Image.Image],
        texts: List[str],
        langs: Union[List[str], None] = None,
        resolution: Union[int, None] = None,
    ):
        image_features = self.encode_images(
            images, normalize=True, resolution=resolution
        )
        text_features = self.encode_texts(
            texts, langs, normalize=True, resolution=resolution
        )
        with torch.inference_mode():
            image_logits = (
                self.model.logit_scale.exp() * image_features @ text_features.T
            )
            if self.model.logit_bias is not None:
                image_logits += self.model.logit_bias
            text_logits = image_logits.T
        return image_logits, text_logits
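
For context, here is a minimal usage sketch. It is an illustration under stated assumptions, not part of the commit: it assumes the file is importable as nllb_mrl and that suitable weights are available; the model name, pretrained tag, resolutions, and image path below are hypothetical placeholders.

from PIL import Image

from nllb_mrl import MatryoshkaNllbClip, MatryoshkaNllbClipConfig

# Hypothetical config values; the real ones ship with the checkpoint's
# config.json. Without loading trained weights, the MRL heads are randomly
# initialized.
config = MatryoshkaNllbClipConfig(
    clip_model_name="nllb-clip-base-siglip",
    clip_model_version="v1",
    target_resolution=768,
    mrl_resolutions=[256, 512],
)
model = MatryoshkaNllbClip(config, device="cpu")

# Score two captions against one image at a reduced embedding size.
image_logits, text_logits = model.get_logits(
    images=[Image.open("example.jpg")],
    texts=["a photo of a cat", "a photo of a dog"],
    langs=["eng_Latn", "eng_Latn"],
    resolution=256,  # one of config.mrl_resolutions, or None for full-size features
)
print(image_logits)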