sachin committed
Commit 9d7268a
Parent: c7a14ad

Converted models to transformers standard

Files changed (1):
  1. src/models.py  +67 -87
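
Since the encoders now subclass `transformers.PreTrainedModel` and are built entirely from config objects, the standard `save_pretrained` / `from_pretrained` round-trip should apply to them. A minimal usage sketch under that assumption follows; the checkpoint path is a placeholder, the tokenizer name simply mirrors the `text_model` default of the removed `CLIPConfig`, and `TinyCLIPConfig()` is assumed to provide sensible defaults:

```python
# Illustrative sketch only: assumes src/models.py (this commit) and src/config.py
# are importable; "./tinyclip-checkpoint" and the tokenizer name are placeholders.
import transformers
from PIL import Image

from src.config import TinyCLIPConfig
from src.models import TinyCLIP

config = TinyCLIPConfig()  # assumed to carry usable defaults
model = TinyCLIP(config)

# The point of the conversion: weights and config round-trip through the
# standard transformers API instead of bare torch.load / load_state_dict.
model.save_pretrained("./tinyclip-checkpoint")
model = TinyCLIP.from_pretrained("./tinyclip-checkpoint")

# Forward pass mirrors TinyCLIP.forward: a tokenized text batch plus PIL images.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "microsoft/xtremedistil-l6-h256-uncased"
)
text_input = dict(tokenizer(["a photo of a dog"], padding=True, return_tensors="pt"))
images = [Image.new("RGB", (224, 224))]

out = model(text_input=text_input, vision_input=images, return_loss=True)
print(out["text_output"].shape, out["vision_output"].shape, out["loss"])
```
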
src/models.py CHANGED
@@ -1,11 +1,14 @@
-import dataclasses
-import json
-
+from PIL import Image
 import timm
+from timm import data
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import transformers
+from transformers import PreTrainedModel
+
+from src.config import TinyCLIPConfig, TinyCLIPTextConfig, TinyCLIPVisionConfig
+from src import loss


 class Projection(nn.Module):
@@ -37,105 +40,82 @@ def mean_pooling(
     input_mask_expanded = attention_mask.unsqueeze(-1).expand(text_representation.size()).float()
     return torch.sum(text_representation * input_mask_expanded, 1) / torch.clamp(
         input_mask_expanded.sum(1), min=1e-9
-    )
+    )  # type: ignore


-class TextEncoder(nn.Module):
-    def __init__(
-        self,
-        base: nn.Module,
-        d_in: int,
-        d_out: int,
-        n_projection_layers: int,
-        cls_token: bool = False,
-    ):
-        super().__init__()
-        self.base = base
-        self.cls_token = cls_token
-        self.projection = projection_layers(d_in, d_out, n_projection_layers)
-        self.base.eval()
-        for p in self.base.parameters():
-            p.requires_grad = False
-
-    def forward(self, x):
+class TinyCLIPTextEncoder(PreTrainedModel):
+    config_class = TinyCLIPTextConfig
+
+    def __init__(self, config: TinyCLIPTextConfig):
+        super().__init__(config)
+        self.base = transformers.AutoModel.from_pretrained(config.text_model)
+        self.cls_type = config.cls_type
+        self.projection = projection_layers(
+            self.base.config.hidden_size, config.embed_dims, config.projection_layers
+        )
+
+    def forward(self, x: dict[str, torch.Tensor]):
         out = self.base(**x).last_hidden_state
-        if self.cls_token:
+        if self.cls_type:
             out = out[:, 0]  # get CLS token output
         else:
-            out = mean_pooling(out, x["attention_mask"])
+            out = mean_pooling(out, x["attention_mask"])  # type: ignore

         projected_vec = self.projection(out)
         return F.normalize(projected_vec, dim=-1)


-class VisionEncoder(nn.Module):
-    def __init__(self, base: nn.Module, d_in: int, d_out: int, n_projection_layers: int):
-        super().__init__()
-        self.base = base
-        self.projection = projection_layers(d_in, d_out, n_projection_layers)
+class TinyCLIPVisionEncoder(PreTrainedModel):
+    config_class = TinyCLIPVisionConfig
+
+    def __init__(self, config: TinyCLIPVisionConfig):
+        super().__init__(config)
+        self.base = timm.create_model(config.vision_model, num_classes=0)
+        timm_config = data.resolve_data_config({}, model=self.base)
+        self.transform = data.transforms_factory.create_transform(**timm_config)
+        self.projection = projection_layers(
+            self.base.num_features, config.embed_dims, config.projection_layers
+        )

-        self.base.eval()
-        for p in self.base.parameters():
-            p.requires_grad = False
+    def forward(self, images: list[Image.Image]):
+        x: torch.Tensor = torch.stack([self.transform(image) for image in images])  # type: ignore

-    def forward(self, x):
         projected_vec = self.projection(self.base(x))
         return F.normalize(projected_vec, dim=-1)


-class Tokenizer:
-    def __init__(self, tokenizer, max_len: int) -> None:
-        self.tokenizer = tokenizer
-        self.max_len = max_len
+class TinyCLIP(PreTrainedModel):
+    config_class = TinyCLIPConfig

-    def __call__(self, x: str) -> transformers.AutoTokenizer:
-        return self.tokenizer(
-            x, max_length=self.max_len, truncation=True, padding=True, return_tensors="pt"
-        )
+    def __init__(self, config: TinyCLIPConfig):
+        super().__init__(config)
+        self.text_encoder = TinyCLIPTextEncoder(config.text_config)
+        self.vision_encoder = TinyCLIPVisionEncoder(config.vision_config)
+
+        if config.freeze_text_base:
+            self.text_encoder.base.eval()
+            for param in self.text_encoder.parameters():
+                param.requires_grad = False
+
+        if config.freeze_vision_base:
+            self.vision_encoder.base.eval()
+            for param in self.vision_encoder.parameters():
+                param.requires_grad = False
+
+        self.loss_fn = loss.get_loss(config.loss_type)
+
+    def forward(
+        self,
+        text_input: dict[str, torch.Tensor],
+        vision_input: list[Image.Image],
+        return_loss: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        text_output = self.text_encoder(text_input)
+        vision_output = self.vision_encoder(vision_input)
+
+        out = {"text_output": text_output, "vision_output": vision_output}
+
+        if return_loss:
+            out["loss"] = self.loss_fn(vision_output, text_output)

-    def decode(self, x: dict[str, torch.LongTensor]) -> list[str]:
-        return [
-            self.tokenizer.decode(sentence[:sentence_len])
-            for sentence, sentence_len in zip(x["input_ids"], x["attention_mask"].sum(axis=-1))
-        ]
-
-
-@dataclasses.dataclass(frozen=True)
-class CLIPConfig:
-    cls_token: bool = True
-    n_projection_layers: int = 3
-    embed_dims: int = 512
-    vision_model: str = "edgenext_small"
-    text_model: str = "microsoft/xtremedistil-l6-h256-uncased"
-    max_len: int = 128
-
-
-def get_model():
-    with open("./clip_config.json", "r") as f:
-        config = CLIPConfig(**json.load(f))
-
-    # load text model and tokenizer
-    text_config = transformers.AutoConfig.from_pretrained("./text_model_config/")
-    text_base = transformers.AutoModel.from_config(text_config)
-    tokenizer = Tokenizer(
-        transformers.AutoTokenizer.from_pretrained("./tokenizer/"), config.max_len
-    )
-    text_encoder = TextEncoder(
-        text_base,
-        text_base.config.hidden_size,
-        config.embed_dims,
-        config.n_projection_layers,
-        config.cls_token,
-    )
-    text_encoder.load_state_dict(torch.load("./text.ckpt", map_location=torch.device("cpu")))
-
-    # load vision model and image transform
-    image_base = timm.create_model(config.vision_model, num_classes=0)
-    timm_config = timm.data.resolve_data_config({}, model=image_base)
-    transform = timm.data.transforms_factory.create_transform(**timm_config)
-    vision_encoder = VisionEncoder(
-        image_base, image_base.num_features, config.embed_dims, config.n_projection_layers
-    )
-    vision_encoder.load_state_dict(torch.load("./vision.ckpt", map_location=torch.device("cpu")))
-
-    return text_encoder, tokenizer, vision_encoder, transform
+        return out
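
The new module imports `TinyCLIPConfig`, `TinyCLIPTextConfig` and `TinyCLIPVisionConfig` from `src.config`, and a loss factory from `src.loss`, neither of which is part of this commit. For orientation only, here is a rough sketch of what the config side might look like as `transformers.PretrainedConfig` subclasses: the field names are the ones `src/models.py` actually reads (`text_model`, `cls_type`, `embed_dims`, `projection_layers`, `vision_model`, `freeze_text_base`, `freeze_vision_base`, `loss_type`), while the defaults, `model_type` strings, and the nesting of the sub-configs are assumptions. `loss.get_loss(loss_type)` is likewise assumed to return a callable taking `(vision_output, text_output)` and producing a scalar loss.

```python
# Hypothetical sketch of src/config.py -- NOT part of this commit.
# Field names come from the attributes src/models.py accesses; defaults mirror
# the removed CLIPConfig dataclass where possible, everything else is assumed.
import transformers


class TinyCLIPTextConfig(transformers.PretrainedConfig):
    model_type = "tinyclip_text"  # assumed identifier

    def __init__(
        self,
        text_model: str = "microsoft/xtremedistil-l6-h256-uncased",
        embed_dims: int = 512,
        projection_layers: int = 3,
        cls_type: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.text_model = text_model
        self.embed_dims = embed_dims
        self.projection_layers = projection_layers
        self.cls_type = cls_type


class TinyCLIPVisionConfig(transformers.PretrainedConfig):
    model_type = "tinyclip_vision"  # assumed identifier

    def __init__(
        self,
        vision_model: str = "edgenext_small",
        embed_dims: int = 512,
        projection_layers: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_model = vision_model
        self.embed_dims = embed_dims
        self.projection_layers = projection_layers


class TinyCLIPConfig(transformers.PretrainedConfig):
    model_type = "tinyclip"  # assumed identifier

    def __init__(
        self,
        text_config: dict | None = None,
        vision_config: dict | None = None,
        freeze_text_base: bool = True,
        freeze_vision_base: bool = True,
        loss_type: str = "contrastive",  # placeholder; real values live in src/loss.py
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.text_config = TinyCLIPTextConfig(**(text_config or {}))
        self.vision_config = TinyCLIPVisionConfig(**(vision_config or {}))
        self.freeze_text_base = freeze_text_base
        self.freeze_vision_base = freeze_vision_base
        self.loss_type = loss_type
```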