sachin committed
Commit bcbc05a
1 Parent(s): bf3f3c5

Upload 4 files
Files changed (5):
  1. .gitattributes +2 -0
  2. clip_config.json +1 -0
  3. models.py +141 -0
  4. text.ckpt +3 -0
  5. vision.ckpt +3 -0
.gitattributes CHANGED
@@ -29,3 +29,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ text.ckpt filter=lfs diff=lfs merge=lfs -text
+ vision.ckpt filter=lfs diff=lfs merge=lfs -text
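Because text.ckpt and vision.ckpt are now tracked by Git LFS, a clone made without git-lfs installed leaves small text pointers on disk instead of the real weights (the pointer contents are visible in the text.ckpt and vision.ckpt diffs below), and torch.load would then fail on them. A minimal sanity-check sketch, assuming the checkpoints sit in the working directory:

    # Detect whether a checkpoint is still an un-fetched LFS pointer.
    from pathlib import Path

    def is_lfs_pointer(path: str) -> bool:
        head = Path(path).read_bytes()[:64]
        return head.startswith(b"version https://git-lfs.github.com/spec/v1")

    for ckpt in ("text.ckpt", "vision.ckpt"):
        if is_lfs_pointer(ckpt):
            raise RuntimeError(f"{ckpt} is an LFS pointer; run `git lfs pull` first")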
clip_config.json ADDED
@@ -0,0 +1 @@
+ {"cls_token": true, "n_projection_layers": 3, "embed_dims": 512, "vision_model": "edgenext_small", "text_model": "microsoft/xtremedistil-l6-h256-uncased"}
models.py ADDED
@@ -0,0 +1,141 @@
+ import dataclasses
+ import json
+
+ import timm
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import transformers
+
+
+ class Projection(nn.Module):
+     def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
+         super().__init__()
+         self.linear1 = nn.Linear(d_in, d_out, bias=False)
+         self.linear2 = nn.Linear(d_out, d_out, bias=False)
+         self.layer_norm = nn.LayerNorm(d_out)
+         self.drop = nn.Dropout(p)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         embed1 = self.linear1(x)
+         embed2 = self.drop(self.linear2(F.gelu(embed1)))
+         embeds = self.layer_norm(embed1 + embed2)
+         return embeds
+
+
+ def projection_layers(d_in: int, d_out: int, num_layers: int) -> nn.Module:
+     layers = []
+     for _ in range(num_layers - 1):
+         layers.extend([Projection(d_in, d_in), nn.GELU()])
+     layers += [Projection(d_in, d_out)]
+     return nn.Sequential(*layers)
+
+
+ def mean_pooling(
+     text_representation: torch.FloatTensor, attention_mask: torch.LongTensor
+ ) -> torch.FloatTensor:
+     input_mask_expanded = attention_mask.unsqueeze(-1).expand(text_representation.size()).float()
+     return torch.sum(text_representation * input_mask_expanded, 1) / torch.clamp(
+         input_mask_expanded.sum(1), min=1e-9
+     )
+
+
+ class TextEncoder(nn.Module):
+     def __init__(
+         self,
+         base: nn.Module,
+         d_in: int,
+         d_out: int,
+         n_projection_layers: int,
+         cls_token: bool = False,
+     ):
+         super().__init__()
+         self.base = base
+         self.cls_token = cls_token
+         self.projection = projection_layers(d_in, d_out, n_projection_layers)
+         self.base.eval()
+         for p in self.base.parameters():
+             p.requires_grad = False
+
+     def forward(self, x):
+         out = self.base(**x).last_hidden_state
+         if self.cls_token:
+             out = out[:, 0]  # get CLS token output
+         else:
+             out = mean_pooling(out, x["attention_mask"])
+
+         projected_vec = self.projection(out)
+         return F.normalize(projected_vec, dim=-1)
+
+
+ class VisionEncoder(nn.Module):
+     def __init__(self, base: nn.Module, d_in: int, d_out: int, n_projection_layers: int):
+         super().__init__()
+         self.base = base
+         self.projection = projection_layers(d_in, d_out, n_projection_layers)
+
+         self.base.eval()
+         for p in self.base.parameters():
+             p.requires_grad = False
+
+     def forward(self, x):
+         projected_vec = self.projection(self.base(x))
+         return F.normalize(projected_vec, dim=-1)
+
+
+ class Tokenizer:
+     def __init__(self, tokenizer, max_len: int) -> None:
+         self.tokenizer = tokenizer
+         self.max_len = max_len
+
+     def __call__(self, x: str) -> transformers.BatchEncoding:
+         return self.tokenizer(
+             x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
+         )
+
+     def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
+         return [
+             self.tokenizer.decode(sentence[:sentence_len])
+             for sentence, sentence_len in zip(x["input_ids"], x["attention_mask"].sum(axis=-1))
+         ]
+
+
+ @dataclasses.dataclass(frozen=True)
+ class CLIPConfig:
+     cls_token: bool = True
+     n_projection_layers: int = 3
+     embed_dims: int = 512
+     vision_model: str = "edgenext_small"
+     text_model: str = "microsoft/xtremedistil-l6-h256-uncased"
+     max_len: int = 128
+
+
+ def get_model():
+     with open("./clip_config.json", "r") as f:
+         config = CLIPConfig(**json.load(f))
+
+     # load text model and tokenizer
+     text_config = transformers.AutoConfig.from_pretrained("./text_model_config/")
+     text_base = transformers.AutoModel.from_config(text_config)
+     tokenizer = Tokenizer(
+         transformers.AutoTokenizer.from_pretrained("./tokenizer/"), config.max_len
+     )
+     text_encoder = TextEncoder(
+         text_base,
+         text_base.config.hidden_size,
+         config.embed_dims,
+         config.n_projection_layers,
+         config.cls_token,
+     )
+     text_encoder.load_state_dict(torch.load("./text.ckpt", map_location=torch.device("cpu")))
+
+     # load vision model and image transform
+     image_base = timm.create_model(config.vision_model, num_classes=0)
+     timm_config = timm.data.resolve_data_config({}, model=image_base)
+     transform = timm.data.transforms_factory.create_transform(**timm_config)
+     vision_encoder = VisionEncoder(
+         image_base, image_base.num_features, config.embed_dims, config.n_projection_layers
+     )
+     vision_encoder.load_state_dict(torch.load("./vision.ckpt", map_location=torch.device("cpu")))
+
+     return text_encoder, tokenizer, vision_encoder, transform
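A hedged usage sketch of get_model(): it assumes this commit's checkpoints plus the ./clip_config.json, ./text_model_config/ and ./tokenizer/ paths it reads are all present in the working directory, and the image path is a hypothetical local file. Both encoders emit L2-normalized 512-d vectors, so a plain dot product is already the cosine similarity:

    import torch
    from PIL import Image

    from models import get_model

    text_encoder, tokenizer, vision_encoder, transform = get_model()

    with torch.no_grad():
        # embed a caption into the shared 512-d space
        tokens = tokenizer("a photo of a dog")
        text_emb = text_encoder(tokens)  # (1, 512), L2-normalized

        # embed an image with the timm preprocessing transform
        image = transform(Image.open("dog.jpg").convert("RGB")).unsqueeze(0)  # hypothetical file
        image_emb = vision_encoder(image)  # (1, 512), L2-normalized

    similarity = (text_emb @ image_emb.T).item()  # cosine similarity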
text.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:da62e56a0ef10ef6f2d6be37d954da55444043ace6ed545567857800cc5b0a00
+ size 53679833
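Per the Git LFS spec, the oid above is the SHA-256 of the actual file contents, so a fetched checkpoint can be verified against this pointer. A minimal sketch:

    import hashlib

    # hash the fetched file in chunks and compare to the pointer's oid
    h = hashlib.sha256()
    with open("text.ckpt", "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    assert h.hexdigest() == "da62e56a0ef10ef6f2d6be37d954da55444043ace6ed545567857800cc5b0a00"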
vision.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:79e39bb255e386f5f8fd60702d8db32535bf625261787e05b8a75806b936e4f4
+ size 24370369