VikTsrv commited on
Commit
61c75a3
·
1 Parent(s): dce3e8c
Files changed (1) hide show
  1. app.py +282 -123
app.py CHANGED
@@ -1,3 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # # import gradio as gr
2
  # # import torch
3
  # # import torch.nn as nn
@@ -7,16 +161,19 @@
7
  # # import easyocr
8
  # # import json
9
  # # import os
 
10
 
11
- # # import spaces # добавьте в начале
12
 
13
- # # @spaces.GPU(duration=60) # добавьте перед predict
14
- # # def predict_demo(image, caption_text=""):
15
- # # # ... ваш код
16
  # # # ======================
17
- # # # ФИКСИРУЕМ ПУТИ (важно для Spaces!)
 
 
 
 
 
 
18
  # # # ======================
19
- # # # Модели и веса лежат в той же папке, что и app.py
20
  # # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21
 
22
  # # # Загрузка названий классов
@@ -28,74 +185,82 @@
28
 
29
 
30
  # # # ======================
31
- # # # ЗАГРУЗКА МОДЕЛЕЙ (один раз, с кешированием)
32
  # # # ======================
33
- # # @gr.cache_resource
34
- # # def load_models():
35
- # # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
- # # print(f"Using device: {DEVICE}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # # # Визуальный энкодер
39
- # # visual = models.resnet50(weights=None)
40
- # # visual.fc = nn.Identity()
41
- # # visual.load_state_dict(torch.load(os.path.join(BASE_DIR, "resnet50_encoder.pth"), map_location=DEVICE))
 
 
 
 
42
  # # visual.to(DEVICE)
43
  # # visual.eval()
44
  # # for p in visual.parameters():
45
  # # p.requires_grad = False
46
 
47
- # # # Текстовые энкодеры
48
  # # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
49
- # # ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
50
- # # caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
 
 
51
 
52
  # # for p in ocr_encoder.parameters():
53
  # # p.requires_grad = False
54
  # # for p in caption_encoder.parameters():
55
  # # p.requires_grad = False
56
 
57
- # # # Классификатор
58
- # # class ConcatFusionModel(nn.Module):
59
- # # def __init__(self, num_classes, dropout=0.3):
60
- # # super().__init__()
61
- # # self.classifier = nn.Sequential(
62
- # # nn.Linear(2048 + 312 + 312, 512),
63
- # # nn.BatchNorm1d(512),
64
- # # nn.ReLU(),
65
- # # nn.Dropout(dropout),
66
- # # nn.Linear(512, num_classes)
67
- # # )
68
-
69
- # # def forward(self, v, ocr, cap):
70
- # # x = torch.cat([v, ocr, cap], dim=1)
71
- # # return self.classifier(x)
72
-
73
  # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
74
- # # model.load_state_dict(torch.load(os.path.join(BASE_DIR, "best_concat_model.pth"), map_location=DEVICE))
 
75
  # # model.to(DEVICE)
76
  # # model.eval()
77
 
78
  # # # EasyOCR
79
  # # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
80
 
81
- # # # Трансформы
82
  # # val_transform = transforms.Compose([
83
  # # transforms.Resize(256),
84
  # # transforms.CenterCrop(224),
85
  # # transforms.ToTensor(),
86
- # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
87
  # # ])
88
 
89
- # # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE
90
 
91
 
92
- # # # Загружаем всё при старте
93
- # # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE = load_models()
94
 
95
 
96
  # # # ======================
97
  # # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
98
  # # # ======================
 
99
  # # def predict(image, caption_text=""):
100
  # # image = image.convert("RGB")
101
 
@@ -110,19 +275,23 @@
110
  # # v = torch.flatten(v, 1)
111
 
112
  # # # OCR encode
113
- # # ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
114
- # # ocr_ids = ocr_enc["input_ids"].to(DEVICE)
115
- # # ocr_mask = ocr_enc["attention_mask"].to(DEVICE)
116
  # # with torch.no_grad():
117
- # # ocr_out = ocr_encoder(input_ids=ocr_ids, attention_mask=ocr_mask)
 
 
 
118
  # # ocr = ocr_out.last_hidden_state[:, 0]
119
 
120
  # # # Caption encode
121
- # # cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
122
- # # cap_ids = cap_enc["input_ids"].to(DEVICE)
123
- # # cap_mask = cap_enc["attention_mask"].to(DEVICE)
124
  # # with torch.no_grad():
125
- # # cap_out = caption_encoder(input_ids=cap_ids, attention_mask=cap_mask)
 
 
 
126
  # # cap = cap_out.last_hidden_state[:, 0]
127
 
128
  # # # Предсказание
@@ -140,10 +309,11 @@
140
  # # demo = gr.Interface(
141
  # # fn=predict,
142
  # # inputs=[
143
- # # gr.Image(type="pil", label="Загрузите изображение"),
144
- # # gr.Textbox(label="Подпись (необязательно)", placeholder="Введите текст подписи...")
 
145
  # # ],
146
- # # outputs=gr.Label(num_top_classes=5, label="Предсказанные категории"),
147
  # # title="Мультимодальный классификатор контента",
148
  # # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
149
  # # )
@@ -151,7 +321,6 @@
151
  # # if __name__ == "__main__":
152
  # # demo.launch()
153
 
154
-
155
  # import gradio as gr
156
  # import torch
157
  # import torch.nn as nn
@@ -185,7 +354,7 @@
185
 
186
 
187
  # # ======================
188
- # # ОПРЕДЕЛЕНИЕ МОДЕЛИ
189
  # # ======================
190
  # class ConcatFusionModel(nn.Module):
191
  # def __init__(self, num_classes, dropout=0.3):
@@ -208,53 +377,52 @@
208
 
209
 
210
  # # ======================
211
- # # ЗАГРУЗКА МОДЕЛЕЙ
212
  # # ======================
213
- # @gr.cache
214
- # def load_models():
215
- # # Визуальный энкодер (загружаем предобученный из torchvision)
216
- # visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
217
- # visual.fc = nn.Identity() # убираем классификатор
218
- # visual.to(DEVICE)
219
- # visual.eval()
220
- # for p in visual.parameters():
221
- # p.requires_grad = False
222
-
223
- # # Текстовые энкодеры (загружаем предобученные из Hugging Face)
224
- # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
225
- # ocr_encoder = AutoModel.from_pretrained(
226
- # "cointegrated/rubert-tiny2").to(DEVICE).eval()
227
- # caption_encoder = AutoModel.from_pretrained(
228
- # "cointegrated/rubert-tiny2").to(DEVICE).eval()
229
-
230
- # for p in ocr_encoder.parameters():
231
- # p.requires_grad = False
232
- # for p in caption_encoder.parameters():
233
- # p.requires_grad = False
234
-
235
- # # Классификационная голова (обученная)
236
- # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
237
- # model.load_state_dict(torch.load(os.path.join(
238
- # BASE_DIR, "concat_model.pth"), map_location=DEVICE))
239
- # model.to(DEVICE)
240
- # model.eval()
241
-
242
- # # EasyOCR
243
- # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
244
-
245
- # # Трансформы для изображений
246
- # val_transform = transforms.Compose([
247
- # transforms.Resize(256),
248
- # transforms.CenterCrop(224),
249
- # transforms.ToTensor(),
250
- # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
251
- # 0.229, 0.224, 0.225]),
252
- # ])
253
-
254
- # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
255
-
256
-
257
- # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform = load_models()
258
 
259
 
260
  # # ======================
@@ -332,8 +500,6 @@ import json
332
  import os
333
  import numpy as np
334
 
335
- import spaces
336
-
337
  # ======================
338
  # УСТАНОВКА УСТРОЙСТВА
339
  # ======================
@@ -354,9 +520,9 @@ NUM_CLASSES = len(id2label)
354
 
355
 
356
  # ======================
357
- # ОПРЕДЕЛЕНИЕ МОДЕЛИ (НАРУЖУ, НЕ ВНУТРИ load_models!)
358
  # ======================
359
- class ConcatFusionModel(nn.Module):
360
  def __init__(self, num_classes, dropout=0.3):
361
  super().__init__()
362
  self.classifier = nn.Sequential(
@@ -377,9 +543,9 @@ class ConcatFusionModel(nn.Module):
377
 
378
 
379
  # ======================
380
- # ЗАГРУЗКА МОДЕЛЕЙ (без декоратора, глобально)
381
  # ======================
382
- # Визуальный энкодер
383
  visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
384
  visual.fc = nn.Identity()
385
  visual.to(DEVICE)
@@ -389,20 +555,18 @@ for p in visual.parameters():
389
 
390
  # Текстовые энкодеры
391
  tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
392
- ocr_encoder = AutoModel.from_pretrained(
393
- "cointegrated/rubert-tiny2").to(DEVICE).eval()
394
- caption_encoder = AutoModel.from_pretrained(
395
- "cointegrated/rubert-tiny2").to(DEVICE).eval()
396
 
397
  for p in ocr_encoder.parameters():
398
  p.requires_grad = False
399
  for p in caption_encoder.parameters():
400
  p.requires_grad = False
401
 
402
- # Классификационная голова
403
- model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
404
- model.load_state_dict(torch.load(os.path.join(
405
- BASE_DIR, "concat_model.pth"), map_location=DEVICE))
406
  model.to(DEVICE)
407
  model.eval()
408
 
@@ -414,15 +578,13 @@ val_transform = transforms.Compose([
414
  transforms.Resize(256),
415
  transforms.CenterCrop(224),
416
  transforms.ToTensor(),
417
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
418
- std=[0.229, 0.224, 0.225]),
419
  ])
420
 
421
 
422
  # ======================
423
  # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
424
  # ======================
425
- @spaces.GPU(duration=60)
426
  def predict(image, caption_text=""):
427
  image = image.convert("RGB")
428
 
@@ -437,8 +599,7 @@ def predict(image, caption_text=""):
437
  v = torch.flatten(v, 1)
438
 
439
  # OCR encode
440
- ocr_enc = tokenizer(ocr_text, truncation=True,
441
- padding="max_length", max_length=64, return_tensors="pt")
442
  with torch.no_grad():
443
  ocr_out = ocr_encoder(
444
  input_ids=ocr_enc["input_ids"].to(DEVICE),
@@ -447,8 +608,7 @@ def predict(image, caption_text=""):
447
  ocr = ocr_out.last_hidden_state[:, 0]
448
 
449
  # Caption encode
450
- cap_enc = tokenizer(caption_text, truncation=True,
451
- padding="max_length", max_length=128, return_tensors="pt")
452
  with torch.no_grad():
453
  cap_out = caption_encoder(
454
  input_ids=cap_enc["input_ids"].to(DEVICE),
@@ -472,8 +632,7 @@ demo = gr.Interface(
472
  fn=predict,
473
  inputs=[
474
  gr.Image(type="pil", label="📸 Загрузите изображение"),
475
- gr.Textbox(label="📝 Подпись (необязательно)",
476
- placeholder="Введите текст подписи...")
477
  ],
478
  outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
479
  title="Мультимодальный классификатор контента",
@@ -481,4 +640,4 @@ demo = gr.Interface(
481
  )
482
 
483
  if __name__ == "__main__":
484
- demo.launch()
 
1
+ # # # import gradio as gr
2
+ # # # import torch
3
+ # # # import torch.nn as nn
4
+ # # # from torchvision import models, transforms
5
+ # # # from PIL import Image
6
+ # # # from transformers import AutoModel, AutoTokenizer
7
+ # # # import easyocr
8
+ # # # import json
9
+ # # # import os
10
+
11
+ # # # import spaces # добавьте в начале
12
+
13
+ # # # @spaces.GPU(duration=60) # добавьте перед predict
14
+ # # # def predict_demo(image, caption_text=""):
15
+ # # # # ... ваш код
16
+ # # # # ======================
17
+ # # # # ФИКСИРУЕМ ПУТИ (важно для Spaces!)
18
+ # # # # ======================
19
+ # # # # Модели и веса лежат в той же папке, что и app.py
20
+ # # # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21
+
22
+ # # # # Загрузка названий классов
23
+ # # # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f:
24
+ # # # id2label = json.load(f)
25
+ # # # id2label = {int(k): v for k, v in id2label.items()}
26
+
27
+ # # # NUM_CLASSES = len(id2label)
28
+
29
+
30
+ # # # # ======================
31
+ # # # # ЗАГРУЗКА МОДЕЛЕЙ (один раз, с кешированием)
32
+ # # # # ======================
33
+ # # # @gr.cache_resource
34
+ # # # def load_models():
35
+ # # # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ # # # print(f"Using device: {DEVICE}")
37
+
38
+ # # # # Визуальный энкодер
39
+ # # # visual = models.resnet50(weights=None)
40
+ # # # visual.fc = nn.Identity()
41
+ # # # visual.load_state_dict(torch.load(os.path.join(BASE_DIR, "resnet50_encoder.pth"), map_location=DEVICE))
42
+ # # # visual.to(DEVICE)
43
+ # # # visual.eval()
44
+ # # # for p in visual.parameters():
45
+ # # # p.requires_grad = False
46
+
47
+ # # # # Текстовые энкодеры
48
+ # # # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
49
+ # # # ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
50
+ # # # caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
51
+
52
+ # # # for p in ocr_encoder.parameters():
53
+ # # # p.requires_grad = False
54
+ # # # for p in caption_encoder.parameters():
55
+ # # # p.requires_grad = False
56
+
57
+ # # # # Классификатор
58
+ # # # class ConcatFusionModel(nn.Module):
59
+ # # # def __init__(self, num_classes, dropout=0.3):
60
+ # # # super().__init__()
61
+ # # # self.classifier = nn.Sequential(
62
+ # # # nn.Linear(2048 + 312 + 312, 512),
63
+ # # # nn.BatchNorm1d(512),
64
+ # # # nn.ReLU(),
65
+ # # # nn.Dropout(dropout),
66
+ # # # nn.Linear(512, num_classes)
67
+ # # # )
68
+
69
+ # # # def forward(self, v, ocr, cap):
70
+ # # # x = torch.cat([v, ocr, cap], dim=1)
71
+ # # # return self.classifier(x)
72
+
73
+ # # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
74
+ # # # model.load_state_dict(torch.load(os.path.join(BASE_DIR, "best_concat_model.pth"), map_location=DEVICE))
75
+ # # # model.to(DEVICE)
76
+ # # # model.eval()
77
+
78
+ # # # # EasyOCR
79
+ # # # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
80
+
81
+ # # # # Трансформы
82
+ # # # val_transform = transforms.Compose([
83
+ # # # transforms.Resize(256),
84
+ # # # transforms.CenterCrop(224),
85
+ # # # transforms.ToTensor(),
86
+ # # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
87
+ # # # ])
88
+
89
+ # # # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE
90
+
91
+
92
+ # # # # Загружаем всё при старте
93
+ # # # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE = load_models()
94
+
95
+
96
+ # # # # ======================
97
+ # # # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
98
+ # # # # ======================
99
+ # # # def predict(image, caption_text=""):
100
+ # # # image = image.convert("RGB")
101
+
102
+ # # # # OCR
103
+ # # # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True)
104
+ # # # ocr_text = " ".join(ocr_result) if ocr_result else ""
105
+
106
+ # # # # Image
107
+ # # # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
108
+ # # # with torch.no_grad():
109
+ # # # v = visual(image_tensor)
110
+ # # # v = torch.flatten(v, 1)
111
+
112
+ # # # # OCR encode
113
+ # # # ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
114
+ # # # ocr_ids = ocr_enc["input_ids"].to(DEVICE)
115
+ # # # ocr_mask = ocr_enc["attention_mask"].to(DEVICE)
116
+ # # # with torch.no_grad():
117
+ # # # ocr_out = ocr_encoder(input_ids=ocr_ids, attention_mask=ocr_mask)
118
+ # # # ocr = ocr_out.last_hidden_state[:, 0]
119
+
120
+ # # # # Caption encode
121
+ # # # cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
122
+ # # # cap_ids = cap_enc["input_ids"].to(DEVICE)
123
+ # # # cap_mask = cap_enc["attention_mask"].to(DEVICE)
124
+ # # # with torch.no_grad():
125
+ # # # cap_out = caption_encoder(input_ids=cap_ids, attention_mask=cap_mask)
126
+ # # # cap = cap_out.last_hidden_state[:, 0]
127
+
128
+ # # # # Предсказание
129
+ # # # with torch.no_grad():
130
+ # # # logits = model(v, ocr, cap)
131
+ # # # probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
132
+
133
+ # # # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)}
134
+ # # # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
135
+
136
+
137
+ # # # # ======================
138
+ # # # # GRADIO ИНТЕРФЕЙС
139
+ # # # # ======================
140
+ # # # demo = gr.Interface(
141
+ # # # fn=predict,
142
+ # # # inputs=[
143
+ # # # gr.Image(type="pil", label="Загрузите изображение"),
144
+ # # # gr.Textbox(label="Подпись (необязательно)", placeholder="Введите текст подписи...")
145
+ # # # ],
146
+ # # # outputs=gr.Label(num_top_classes=5, label="Предсказанные категории"),
147
+ # # # title="Мультимодальный классификатор контента",
148
+ # # # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
149
+ # # # )
150
+
151
+ # # # if __name__ == "__main__":
152
+ # # # demo.launch()
153
+
154
+
155
  # # import gradio as gr
156
  # # import torch
157
  # # import torch.nn as nn
 
161
  # # import easyocr
162
  # # import json
163
  # # import os
164
+ # # import numpy as np
165
 
166
+ # # import spaces
167
 
 
 
 
168
  # # # ======================
169
+ # # # УСТАНОВКА УСТРОЙСТВА
170
+ # # # ======================
171
+ # # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
172
+ # # print(f"Using device: {DEVICE}")
173
+
174
+ # # # ======================
175
+ # # # ПУТИ
176
  # # # ======================
 
177
  # # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
178
 
179
  # # # Загрузка названий классов
 
185
 
186
 
187
  # # # ======================
188
+ # # # ОПРЕДЕЛЕНИЕ МОДЕЛИ
189
  # # # ======================
190
+ # # class ConcatFusionModel(nn.Module):
191
+ # # def __init__(self, num_classes, dropout=0.3):
192
+ # # super().__init__()
193
+ # # self.classifier = nn.Sequential(
194
+ # # nn.Linear(2048 + 312 + 312, 512),
195
+ # # nn.BatchNorm1d(512),
196
+ # # nn.ReLU(),
197
+ # # nn.Dropout(dropout),
198
+ # # nn.Linear(512, 256),
199
+ # # nn.BatchNorm1d(256),
200
+ # # nn.ReLU(),
201
+ # # nn.Dropout(0.3),
202
+ # # nn.Linear(256, num_classes)
203
+ # # )
204
+
205
+ # # def forward(self, v, ocr, cap):
206
+ # # x = torch.cat([v, ocr, cap], dim=1)
207
+ # # return self.classifier(x)
208
+
209
 
210
+ # # # ======================
211
+ # # # ЗАГРУЗКА МОДЕЛЕЙ
212
+ # # # ======================
213
+ # # @gr.cache
214
+ # # def load_models():
215
+ # # # Визуальный энкодер (загружаем предобученный из torchvision)
216
+ # # visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
217
+ # # visual.fc = nn.Identity() # убираем классификатор
218
  # # visual.to(DEVICE)
219
  # # visual.eval()
220
  # # for p in visual.parameters():
221
  # # p.requires_grad = False
222
 
223
+ # # # Текстовые энкодеры (загружаем предобученные из Hugging Face)
224
  # # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
225
+ # # ocr_encoder = AutoModel.from_pretrained(
226
+ # # "cointegrated/rubert-tiny2").to(DEVICE).eval()
227
+ # # caption_encoder = AutoModel.from_pretrained(
228
+ # # "cointegrated/rubert-tiny2").to(DEVICE).eval()
229
 
230
  # # for p in ocr_encoder.parameters():
231
  # # p.requires_grad = False
232
  # # for p in caption_encoder.parameters():
233
  # # p.requires_grad = False
234
 
235
+ # # # Классификационная голова (обученная)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
237
+ # # model.load_state_dict(torch.load(os.path.join(
238
+ # # BASE_DIR, "concat_model.pth"), map_location=DEVICE))
239
  # # model.to(DEVICE)
240
  # # model.eval()
241
 
242
  # # # EasyOCR
243
  # # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
244
 
245
+ # # # Трансформы для изображений
246
  # # val_transform = transforms.Compose([
247
  # # transforms.Resize(256),
248
  # # transforms.CenterCrop(224),
249
  # # transforms.ToTensor(),
250
+ # # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
251
+ # # 0.229, 0.224, 0.225]),
252
  # # ])
253
 
254
+ # # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
255
 
256
 
257
+ # # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform = load_models()
 
258
 
259
 
260
  # # # ======================
261
  # # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
262
  # # # ======================
263
+ # # @spaces.GPU(duration=60)
264
  # # def predict(image, caption_text=""):
265
  # # image = image.convert("RGB")
266
 
 
275
  # # v = torch.flatten(v, 1)
276
 
277
  # # # OCR encode
278
+ # # ocr_enc = tokenizer(ocr_text, truncation=True,
279
+ # # padding="max_length", max_length=64, return_tensors="pt")
 
280
  # # with torch.no_grad():
281
+ # # ocr_out = ocr_encoder(
282
+ # # input_ids=ocr_enc["input_ids"].to(DEVICE),
283
+ # # attention_mask=ocr_enc["attention_mask"].to(DEVICE)
284
+ # # )
285
  # # ocr = ocr_out.last_hidden_state[:, 0]
286
 
287
  # # # Caption encode
288
+ # # cap_enc = tokenizer(caption_text, truncation=True,
289
+ # # padding="max_length", max_length=128, return_tensors="pt")
 
290
  # # with torch.no_grad():
291
+ # # cap_out = caption_encoder(
292
+ # # input_ids=cap_enc["input_ids"].to(DEVICE),
293
+ # # attention_mask=cap_enc["attention_mask"].to(DEVICE)
294
+ # # )
295
  # # cap = cap_out.last_hidden_state[:, 0]
296
 
297
  # # # Предсказание
 
309
  # # demo = gr.Interface(
310
  # # fn=predict,
311
  # # inputs=[
312
+ # # gr.Image(type="pil", label="📸 Загрузите изображение"),
313
+ # # gr.Textbox(label="📝 Подпись (необязательно)",
314
+ # # placeholder="Введите текст подписи...")
315
  # # ],
316
+ # # outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
317
  # # title="Мультимодальный классификатор контента",
318
  # # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
319
  # # )
 
321
  # # if __name__ == "__main__":
322
  # # demo.launch()
323
 
 
324
  # import gradio as gr
325
  # import torch
326
  # import torch.nn as nn
 
354
 
355
 
356
  # # ======================
357
+ # # ОПРЕДЕЛЕНИЕ МОДЕЛИ (НАРУЖУ, НЕ ВНУТРИ load_models!)
358
  # # ======================
359
  # class ConcatFusionModel(nn.Module):
360
  # def __init__(self, num_classes, dropout=0.3):
 
377
 
378
 
379
  # # ======================
380
+ # # ЗАГРУЗКА МОДЕЛЕЙ (без декоратора, глобально)
381
  # # ======================
382
+ # # Визуальный энкодер
383
+ # visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
384
+ # visual.fc = nn.Identity()
385
+ # visual.to(DEVICE)
386
+ # visual.eval()
387
+ # for p in visual.parameters():
388
+ # p.requires_grad = False
389
+
390
+ # # Текстовые энкодеры
391
+ # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
392
+ # ocr_encoder = AutoModel.from_pretrained(
393
+ # "cointegrated/rubert-tiny2").to(DEVICE).eval()
394
+ # caption_encoder = AutoModel.from_pretrained(
395
+ # "cointegrated/rubert-tiny2").to(DEVICE).eval()
396
+
397
+ # for p in ocr_encoder.parameters():
398
+ # p.requires_grad = False
399
+ # for p in caption_encoder.parameters():
400
+ # p.requires_grad = False
401
+
402
+ # # Классификационная голова
403
+ # # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
404
+ # # model.load_state_dict(torch.load(os.path.join(
405
+ # # BASE_DIR, "concat_model_head.pth"), map_location=DEVICE))
406
+ # # model.to(DEVICE)
407
+ # # model.eval()
408
+ # # В демо-скрипте
409
+ # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
410
+ # head_state = torch.load("best_head_only.pth", map_location=DEVICE)
411
+ # model.load_state_dict(head_state, strict=False) # strict=False позволяет игнорировать отсутствие энкодеров
412
+ # model.to(DEVICE)
413
+ # model.eval()
414
+
415
+ # # EasyOCR
416
+ # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
417
+
418
+ # # Трансформы
419
+ # val_transform = transforms.Compose([
420
+ # transforms.Resize(256),
421
+ # transforms.CenterCrop(224),
422
+ # transforms.ToTensor(),
423
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406],
424
+ # std=[0.229, 0.224, 0.225]),
425
+ # ])
 
426
 
427
 
428
  # # ======================
 
500
  import os
501
  import numpy as np
502
 
 
 
503
  # ======================
504
  # УСТАНОВКА УСТРОЙСТВА
505
  # ======================
 
520
 
521
 
522
  # ======================
523
+ # ТОЛЬКО ГОЛОВА
524
  # ======================
525
+ class FusionHead(nn.Module):
526
  def __init__(self, num_classes, dropout=0.3):
527
  super().__init__()
528
  self.classifier = nn.Sequential(
 
543
 
544
 
545
  # ======================
546
+ # ЗАГРУЗКА ЭНКОДЕРОВ И ГОЛОВЫ
547
  # ======================
548
+ # Визуальный энкодер (предобученный)
549
  visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
550
  visual.fc = nn.Identity()
551
  visual.to(DEVICE)
 
555
 
556
  # Текстовые энкодеры
557
  tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
558
+ ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
559
+ caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
 
 
560
 
561
  for p in ocr_encoder.parameters():
562
  p.requires_grad = False
563
  for p in caption_encoder.parameters():
564
  p.requires_grad = False
565
 
566
+ # Головаагружаем только веса головы)
567
+ model = FusionHead(NUM_CLASSES, dropout=0.3)
568
+ head_state = torch.load(os.path.join(BASE_DIR, "concat_model_head.pth"), map_location=DEVICE)
569
+ model.load_state_dict(head_state, strict=True) # strict=True, потому что в файле только голова
570
  model.to(DEVICE)
571
  model.eval()
572
 
 
578
  transforms.Resize(256),
579
  transforms.CenterCrop(224),
580
  transforms.ToTensor(),
581
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
582
  ])
583
 
584
 
585
  # ======================
586
  # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
587
  # ======================
 
588
  def predict(image, caption_text=""):
589
  image = image.convert("RGB")
590
 
 
599
  v = torch.flatten(v, 1)
600
 
601
  # OCR encode
602
+ ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
 
603
  with torch.no_grad():
604
  ocr_out = ocr_encoder(
605
  input_ids=ocr_enc["input_ids"].to(DEVICE),
 
608
  ocr = ocr_out.last_hidden_state[:, 0]
609
 
610
  # Caption encode
611
+ cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
 
612
  with torch.no_grad():
613
  cap_out = caption_encoder(
614
  input_ids=cap_enc["input_ids"].to(DEVICE),
 
632
  fn=predict,
633
  inputs=[
634
  gr.Image(type="pil", label="📸 Загрузите изображение"),
635
+ gr.Textbox(label="📝 Подпись (необязательно)", placeholder="Введите текст подписи...")
 
636
  ],
637
  outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
638
  title="Мультимодальный классификатор контента",
 
640
  )
641
 
642
  if __name__ == "__main__":
643
+ demo.launch()