sonoisa committed on
Commit 390f508
1 Parent(s): 148e94d

Update all

app.py CHANGED
@@ -14,7 +14,7 @@ import pyminizip
 import transformers
 from transformers import BertJapaneseTokenizer, BertModel
 from huggingface_hub import hf_hub_download, snapshot_download
-# from PIL import Image
+from PIL import Image


 def unicode_normalize(cls, s):
@@ -172,135 +172,135 @@ class ClipTextModel(nn.Module):
         torch.save(self.output_linear.state_dict(), os.path.join(output_dir, "output_linear.bin"))


-# class ClipVisionModel(nn.Module):
-#     def __init__(self, model_name_or_path, device=None):
-#         super(ClipVisionModel, self).__init__()
-
-#         if os.path.exists(model_name_or_path):
-#             # load from file system
-#             visual_projection_state_dict = torch.load(os.path.join(model_name_or_path, "visual_projection.bin"))
-#         else:
-#             # download from the Hugging Face model hub
-#             filename = hf_hub_download(repo_id=model_name_or_path, filename="visual_projection.bin")
-#             visual_projection_state_dict = torch.load(filename)
-
-#         self.model = transformers.CLIPVisionModel.from_pretrained(model_name_or_path)
-#         config = self.model.config
-
-#         self.feature_extractor = transformers.CLIPFeatureExtractor.from_pretrained(model_name_or_path)
-
-#         vision_embed_dim = config.hidden_size
-#         projection_dim = 512
-
-#         self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
-#         self.visual_projection.load_state_dict(visual_projection_state_dict)
-
-#         self.eval()
-
-#         if device is None:
-#             device = "cuda" if torch.cuda.is_available() else "cpu"
-#         self.device = torch.device(device)
-#         self.to(self.device)
-
-#     def forward(
-#         self,
-#         pixel_values=None,
-#         output_attentions=None,
-#         output_hidden_states=None,
-#         return_dict=None,
-#     ):
-#         output_states = self.model(
-#             pixel_values=pixel_values,
-#             output_attentions=output_attentions,
-#             output_hidden_states=output_hidden_states,
-#             return_dict=return_dict,
-#         )
-#         image_embeds = self.visual_projection(output_states[1])
-
-#         return image_embeds
-
-#     @torch.no_grad()
-#     def encode_image(self, images, batch_size=8):
-#         self.eval()
-#         all_embeddings = []
-#         iterator = range(0, len(images), batch_size)
-#         for batch_idx in iterator:
-#             batch = images[batch_idx:batch_idx + batch_size]
-
-#             encoded_input = self.feature_extractor(batch, return_tensors="pt").to(self.device)
-#             model_output = self(**encoded_input)
-#             image_embeddings = model_output.cpu()
-
-#             all_embeddings.extend(image_embeddings)
-
-#         # return torch.stack(all_embeddings).numpy()
-#         return torch.stack(all_embeddings)
-
-#     @staticmethod
-#     def remove_alpha_channel(image):
-#         image.convert("RGBA")
-#         alpha = image.convert('RGBA').split()[-1]
-#         background = Image.new("RGBA", image.size, (255, 255, 255))
-#         background.paste(image, mask=alpha)
-#         image = background.convert("RGB")
-#         return image
-
-#     def save(self, output_dir):
-#         self.model.save_pretrained(output_dir)
-#         self.feature_extractor.save_pretrained(output_dir)
-#         torch.save(self.visual_projection.state_dict(), os.path.join(output_dir, "visual_projection.bin"))
-
-
-# class ClipModel(nn.Module):
-#     def __init__(self, model_name_or_path, device=None):
-#         super(ClipModel, self).__init__()
-
-#         if os.path.exists(model_name_or_path):
-#             # load from file system
-#             repo_dir = model_name_or_path
-#         else:
-#             # download from the Hugging Face model hub
-#             repo_dir = snapshot_download(model_name_or_path)
-
-#         self.text_model = ClipTextModel(repo_dir, device=device)
-#         self.vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"), device=device)
-
-#         with torch.no_grad():
-#             logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
-#             logit_scale.set_(torch.load(os.path.join(repo_dir, "logit_scale.bin")).clone().cpu())
-#             self.logit_scale = logit_scale
-
-#         self.eval()
-
-#         if device is None:
-#             device = "cuda" if torch.cuda.is_available() else "cpu"
-#         self.device = torch.device(device)
-#         self.to(self.device)
-
-#     def forward(self, pixel_values, input_ids, attention_mask, token_type_ids):
-#         image_features = self.vision_model(pixel_values=pixel_values)
-#         text_features = self.text_model(input_ids=input_ids,
-#                                         attention_mask=attention_mask,
-#                                         token_type_ids=token_type_ids)[0]
-
-#         image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-#         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
-
-#         logit_scale = self.logit_scale.exp()
-#         logits_per_image = logit_scale * image_features @ text_features.t()
-#         logits_per_text = logits_per_image.t()
-
-#         return logits_per_image, logits_per_text
-
-#     def save(self, output_dir):
-#         torch.save(self.logit_scale, os.path.join(output_dir, "logit_scale.bin"))
-#         self.text_model.save(output_dir)
-#         self.vision_model.save(os.path.join(output_dir, "vision_model"))
-
-
-class DummyClipModel:
-    def __init__(self, text_model):
-        self.text_model = text_model
+class ClipVisionModel(nn.Module):
+    def __init__(self, model_name_or_path, device=None):
+        super(ClipVisionModel, self).__init__()
+
+        if os.path.exists(model_name_or_path):
+            # load from file system
+            visual_projection_state_dict = torch.load(os.path.join(model_name_or_path, "visual_projection.bin"))
+        else:
+            # download from the Hugging Face model hub
+            filename = hf_hub_download(repo_id=model_name_or_path, filename="visual_projection.bin")
+            visual_projection_state_dict = torch.load(filename)
+
+        self.model = transformers.CLIPVisionModel.from_pretrained(model_name_or_path)
+        config = self.model.config
+
+        self.feature_extractor = transformers.CLIPFeatureExtractor.from_pretrained(model_name_or_path)
+
+        vision_embed_dim = config.hidden_size
+        projection_dim = 512
+
+        self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
+        self.visual_projection.load_state_dict(visual_projection_state_dict)
+
+        self.eval()
+
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = torch.device(device)
+        self.to(self.device)
+
+    def forward(
+        self,
+        pixel_values=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        output_states = self.model(
+            pixel_values=pixel_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        image_embeds = self.visual_projection(output_states[1])
+
+        return image_embeds
+
+    @torch.no_grad()
+    def encode_image(self, images, batch_size=8):
+        self.eval()
+        all_embeddings = []
+        iterator = range(0, len(images), batch_size)
+        for batch_idx in iterator:
+            batch = images[batch_idx:batch_idx + batch_size]
+
+            encoded_input = self.feature_extractor(batch, return_tensors="pt").to(self.device)
+            model_output = self(**encoded_input)
+            image_embeddings = model_output.cpu()
+
+            all_embeddings.extend(image_embeddings)
+
+        # return torch.stack(all_embeddings).numpy()
+        return torch.stack(all_embeddings)
+
+    @staticmethod
+    def remove_alpha_channel(image):
+        image.convert("RGBA")
+        alpha = image.convert('RGBA').split()[-1]
+        background = Image.new("RGBA", image.size, (255, 255, 255))
+        background.paste(image, mask=alpha)
+        image = background.convert("RGB")
+        return image
+
+    def save(self, output_dir):
+        self.model.save_pretrained(output_dir)
+        self.feature_extractor.save_pretrained(output_dir)
+        torch.save(self.visual_projection.state_dict(), os.path.join(output_dir, "visual_projection.bin"))
+
+
+class ClipModel(nn.Module):
+    def __init__(self, model_name_or_path, device=None):
+        super(ClipModel, self).__init__()
+
+        if os.path.exists(model_name_or_path):
+            # load from file system
+            repo_dir = model_name_or_path
+        else:
+            # download from the Hugging Face model hub
+            repo_dir = snapshot_download(model_name_or_path)
+
+        self.text_model = ClipTextModel(repo_dir, device=device)
+        self.vision_model = ClipVisionModel(os.path.join(repo_dir, "vision_model"), device=device)
+
+        with torch.no_grad():
+            logit_scale = nn.Parameter(torch.ones([]) * 2.6592)
+            logit_scale.set_(torch.load(os.path.join(repo_dir, "logit_scale.bin")).clone().cpu())
+            self.logit_scale = logit_scale
+
+        self.eval()
+
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = torch.device(device)
+        self.to(self.device)
+
+    def forward(self, pixel_values, input_ids, attention_mask, token_type_ids):
+        image_features = self.vision_model(pixel_values=pixel_values)
+        text_features = self.text_model(input_ids=input_ids,
+                                        attention_mask=attention_mask,
+                                        token_type_ids=token_type_ids)[0]
+
+        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+        logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logits_per_image.t()
+
+        return logits_per_image, logits_per_text
+
+    def save(self, output_dir):
+        torch.save(self.logit_scale, os.path.join(output_dir, "logit_scale.bin"))
+        self.text_model.save(output_dir)
+        self.vision_model.save(os.path.join(output_dir, "vision_model"))
+
+
+# class DummyClipModel:
+#     def __init__(self, text_model):
+#         self.text_model = text_model

 def encode_text(text, model):
     text = normalize_text(text)
@@ -308,10 +308,10 @@ def encode_text(text, model):
     return text_embedding


-# def encode_image(image_filename, model):
-#     image = Image.open(image_filename)
-#     image_embedding = model.vision_model.encode_image([image]).numpy()
-#     return image_embedding
+def encode_image(image_filename, model):
+    image = Image.open(image_filename)
+    image_embedding = model.vision_model.encode_image([image]).numpy()
+    return image_embedding


 st.title("いらすと検索(日本語CLIPゼロショット)")
@@ -321,30 +321,31 @@ if "model" not in st.session_state:
     description_text.text("日本語CLIPモデル読み込み中... ")
     device = "cuda" if torch.cuda.is_available() else "cpu"
     text_model = ClipTextModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
-    # model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
-    model = DummyClipModel(text_model)
+    model = ClipModel("sonoisa/clip-vit-b-32-japanese-v1", device=device)
+    # model = DummyClipModel(text_model)
     st.session_state.model = model

     print("extract dataset")
     pyminizip.uncompress(
-        "clip_zeroshot_irasuto_image_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1
+        "clip_zeroshot_irasuto_items_20210224.pq.zip", st.secrets["ZIP_PASSWORD"], None, 1
     )

     print("loading dataset")
-    df = pq.read_table("clip_zeroshot_irasuto_image_items_20210224.parquet",
-                       columns=["page", "description", "image_url", "image_vector"]).to_pandas()
-    st.session_state.df = df
+    df = pq.read_table("clip_zeroshot_irasuto_items_20210224.parquet",
+                       columns=["page", "description", "image_url", "sentence_vector", "image_vector"]).to_pandas()

-    # sentence_vectors = np.stack(df["sentence_vector"])
+    sentence_vectors = np.stack(df["sentence_vector"])
     image_vectors = np.stack(df["image_vector"])
-    # st.session_state.sentence_vectors = sentence_vectors
+    st.session_state.sentence_vectors = sentence_vectors
+
+    st.session_state.df = df
     st.session_state.image_vectors = image_vectors

     print("finished loading model and dataset")

 model = st.session_state.model
 df = st.session_state.df
-# sentence_vectors = st.session_state.sentence_vectors
+sentence_vectors = st.session_state.sentence_vectors
 image_vectors = st.session_state.image_vectors

 description_text.text("日本語CLIPモデル(ゼロショット)を用いて、説明文の意味が近い「いらすとや」画像を検索します。\nキーワードを列挙するよりも、自然な文章を入力した方が精度よく検索できます。\n画像は必ずリンク先の「いらすとや」さんのページを開き、そこからダウンロードしてください。")
clip_zeroshot_irasuto_image_items_20210224.pq.zip CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2f602399369a485f1586b7ca04e8ae096868ecce85527928671b08bf5e80c200
-size 54262882
+oid sha256:62eabf2fd3664a3ddfe29bb7ee59027fa37a34a1d05a9704f09ac363ad5acb2f
+size 72554784
clip_zeroshot_irasuto_items_20210224.pq.zip CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:321d909ba0f92425a5107ad26a6d97dc4f7601b2b4f22ab020199f2ba2237ce7
-size 104296063
+oid sha256:3059351ecc86353c53ba25f7cb5e74db0e55b1ba5257402970a20fd04158b5f1
+size 122826331
requirements.txt CHANGED
@@ -4,4 +4,4 @@ pyminizip
 fugashi
 ipadic
 scipy
-#pillow==7.1.2
+pillow==7.1.2
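requirements.txt re-pins pillow because app.py imports PIL again (Image.open in encode_image and the alpha-channel handling in ClipVisionModel.remove_alpha_channel). A small illustrative check, assuming remove_alpha_channel is importable as above; the RGBA test image below is made up:

from PIL import Image

# Hypothetical import, as in the earlier sketch; remove_alpha_channel is the
# @staticmethod re-enabled in app.py.
# from clip_models import ClipVisionModel

img = Image.new("RGBA", (32, 32), (255, 0, 0, 128))   # synthetic image with an alpha channel
rgb = ClipVisionModel.remove_alpha_channel(img)        # pastes onto a white background and converts to RGB
assert rgb.mode == "RGB" and rgb.size == (32, 32)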