VikTsrv commited on
Commit
5006afe
·
1 Parent(s): 061a1bc

fix app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -152,9 +152,6 @@
152
  # demo.launch()
153
 
154
 
155
-
156
-
157
-
158
  import gradio as gr
159
  import torch
160
  import torch.nn as nn
@@ -213,7 +210,7 @@ class ConcatFusionModel(nn.Module):
213
  # ======================
214
  # ЗАГРУЗКА МОДЕЛЕЙ
215
  # ======================
216
- @gr.cache_resource
217
  def load_models():
218
  # Визуальный энкодер (загружаем предобученный из torchvision)
219
  visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
@@ -225,8 +222,10 @@ def load_models():
225
 
226
  # Текстовые энкодеры (загружаем предобученные из Hugging Face)
227
  tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
228
- ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
229
- caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
 
 
230
 
231
  for p in ocr_encoder.parameters():
232
  p.requires_grad = False
@@ -235,7 +234,8 @@ def load_models():
235
 
236
  # Классификационная голова (обученная)
237
  model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
238
- model.load_state_dict(torch.load(os.path.join(BASE_DIR, "concat_model.pth"), map_location=DEVICE))
 
239
  model.to(DEVICE)
240
  model.eval()
241
 
@@ -247,7 +247,8 @@ def load_models():
247
  transforms.Resize(256),
248
  transforms.CenterCrop(224),
249
  transforms.ToTensor(),
250
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
251
  ])
252
 
253
  return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
@@ -274,7 +275,8 @@ def predict(image, caption_text=""):
274
  v = torch.flatten(v, 1)
275
 
276
  # OCR encode
277
- ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
 
278
  with torch.no_grad():
279
  ocr_out = ocr_encoder(
280
  input_ids=ocr_enc["input_ids"].to(DEVICE),
@@ -283,7 +285,8 @@ def predict(image, caption_text=""):
283
  ocr = ocr_out.last_hidden_state[:, 0]
284
 
285
  # Caption encode
286
- cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
 
287
  with torch.no_grad():
288
  cap_out = caption_encoder(
289
  input_ids=cap_enc["input_ids"].to(DEVICE),
@@ -307,7 +310,8 @@ demo = gr.Interface(
307
  fn=predict,
308
  inputs=[
309
  gr.Image(type="pil", label="📸 Загрузите изображение"),
310
- gr.Textbox(label="📝 Подпись (необязательно)", placeholder="Введите текст подписи...")
 
311
  ],
312
  outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
313
  title="Мультимодальный классификатор контента",
@@ -315,4 +319,4 @@ demo = gr.Interface(
315
  )
316
 
317
  if __name__ == "__main__":
318
- demo.launch()
 
152
  # demo.launch()
153
 
154
 
 
 
 
155
  import gradio as gr
156
  import torch
157
  import torch.nn as nn
 
210
  # ======================
211
  # ЗАГРУЗКА МОДЕЛЕЙ
212
  # ======================
213
+ @gr.cache
214
  def load_models():
215
  # Визуальный энкодер (загружаем предобученный из torchvision)
216
  visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
 
222
 
223
  # Текстовые энкодеры (загружаем предобученные из Hugging Face)
224
  tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
225
+ ocr_encoder = AutoModel.from_pretrained(
226
+ "cointegrated/rubert-tiny2").to(DEVICE).eval()
227
+ caption_encoder = AutoModel.from_pretrained(
228
+ "cointegrated/rubert-tiny2").to(DEVICE).eval()
229
 
230
  for p in ocr_encoder.parameters():
231
  p.requires_grad = False
 
234
 
235
  # Классификационная голова (обученная)
236
  model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
237
+ model.load_state_dict(torch.load(os.path.join(
238
+ BASE_DIR, "concat_model.pth"), map_location=DEVICE))
239
  model.to(DEVICE)
240
  model.eval()
241
 
 
247
  transforms.Resize(256),
248
  transforms.CenterCrop(224),
249
  transforms.ToTensor(),
250
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
251
+ 0.229, 0.224, 0.225]),
252
  ])
253
 
254
  return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
 
275
  v = torch.flatten(v, 1)
276
 
277
  # OCR encode
278
+ ocr_enc = tokenizer(ocr_text, truncation=True,
279
+ padding="max_length", max_length=64, return_tensors="pt")
280
  with torch.no_grad():
281
  ocr_out = ocr_encoder(
282
  input_ids=ocr_enc["input_ids"].to(DEVICE),
 
285
  ocr = ocr_out.last_hidden_state[:, 0]
286
 
287
  # Caption encode
288
+ cap_enc = tokenizer(caption_text, truncation=True,
289
+ padding="max_length", max_length=128, return_tensors="pt")
290
  with torch.no_grad():
291
  cap_out = caption_encoder(
292
  input_ids=cap_enc["input_ids"].to(DEVICE),
 
310
  fn=predict,
311
  inputs=[
312
  gr.Image(type="pil", label="📸 Загрузите изображение"),
313
+ gr.Textbox(label="📝 Подпись (необязательно)",
314
+ placeholder="Введите текст подписи...")
315
  ],
316
  outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
317
  title="Мультимодальный классификатор контента",
 
319
  )
320
 
321
  if __name__ == "__main__":
322
+ demo.launch()