VikTsrv commited on
Commit
061a1bc
·
1 Parent(s): 82a6b44

add files for web demo

Browse files
Files changed (4) hide show
  1. app.py +318 -0
  2. class_names.json +10 -0
  3. concat_model.pth +3 -0
  4. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import gradio as gr
2
+ # import torch
3
+ # import torch.nn as nn
4
+ # from torchvision import models, transforms
5
+ # from PIL import Image
6
+ # from transformers import AutoModel, AutoTokenizer
7
+ # import easyocr
8
+ # import json
9
+ # import os
10
+
11
+ # import spaces # добавьте в начале
12
+
13
+ # @spaces.GPU(duration=60) # добавьте перед predict
14
+ # def predict_demo(image, caption_text=""):
15
+ # # ... ваш код
16
+ # # ======================
17
+ # # ФИКСИРУЕМ ПУТИ (важно для Spaces!)
18
+ # # ======================
19
+ # # Модели и веса лежат в той же папке, что и app.py
20
+ # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
21
+
22
+ # # Загрузка названий классов
23
+ # with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f:
24
+ # id2label = json.load(f)
25
+ # id2label = {int(k): v for k, v in id2label.items()}
26
+
27
+ # NUM_CLASSES = len(id2label)
28
+
29
+
30
+ # # ======================
31
+ # # ЗАГРУЗКА МОДЕЛЕЙ (один раз, с кешированием)
32
+ # # ======================
33
+ # @gr.cache_resource
34
+ # def load_models():
35
+ # DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ # print(f"Using device: {DEVICE}")
37
+
38
+ # # Визуальный энкодер
39
+ # visual = models.resnet50(weights=None)
40
+ # visual.fc = nn.Identity()
41
+ # visual.load_state_dict(torch.load(os.path.join(BASE_DIR, "resnet50_encoder.pth"), map_location=DEVICE))
42
+ # visual.to(DEVICE)
43
+ # visual.eval()
44
+ # for p in visual.parameters():
45
+ # p.requires_grad = False
46
+
47
+ # # Текстовые энкодеры
48
+ # tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
49
+ # ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
50
+ # caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
51
+
52
+ # for p in ocr_encoder.parameters():
53
+ # p.requires_grad = False
54
+ # for p in caption_encoder.parameters():
55
+ # p.requires_grad = False
56
+
57
+ # # Классификатор
58
+ # class ConcatFusionModel(nn.Module):
59
+ # def __init__(self, num_classes, dropout=0.3):
60
+ # super().__init__()
61
+ # self.classifier = nn.Sequential(
62
+ # nn.Linear(2048 + 312 + 312, 512),
63
+ # nn.BatchNorm1d(512),
64
+ # nn.ReLU(),
65
+ # nn.Dropout(dropout),
66
+ # nn.Linear(512, num_classes)
67
+ # )
68
+
69
+ # def forward(self, v, ocr, cap):
70
+ # x = torch.cat([v, ocr, cap], dim=1)
71
+ # return self.classifier(x)
72
+
73
+ # model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
74
+ # model.load_state_dict(torch.load(os.path.join(BASE_DIR, "best_concat_model.pth"), map_location=DEVICE))
75
+ # model.to(DEVICE)
76
+ # model.eval()
77
+
78
+ # # EasyOCR
79
+ # reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
80
+
81
+ # # Трансформы
82
+ # val_transform = transforms.Compose([
83
+ # transforms.Resize(256),
84
+ # transforms.CenterCrop(224),
85
+ # transforms.ToTensor(),
86
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
87
+ # ])
88
+
89
+ # return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE
90
+
91
+
92
+ # # Загружаем всё при старте
93
+ # visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform, DEVICE = load_models()
94
+
95
+
96
+ # # ======================
97
+ # # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
98
+ # # ======================
99
+ # def predict(image, caption_text=""):
100
+ # image = image.convert("RGB")
101
+
102
+ # # OCR
103
+ # ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True)
104
+ # ocr_text = " ".join(ocr_result) if ocr_result else ""
105
+
106
+ # # Image
107
+ # image_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
108
+ # with torch.no_grad():
109
+ # v = visual(image_tensor)
110
+ # v = torch.flatten(v, 1)
111
+
112
+ # # OCR encode
113
+ # ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
114
+ # ocr_ids = ocr_enc["input_ids"].to(DEVICE)
115
+ # ocr_mask = ocr_enc["attention_mask"].to(DEVICE)
116
+ # with torch.no_grad():
117
+ # ocr_out = ocr_encoder(input_ids=ocr_ids, attention_mask=ocr_mask)
118
+ # ocr = ocr_out.last_hidden_state[:, 0]
119
+
120
+ # # Caption encode
121
+ # cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
122
+ # cap_ids = cap_enc["input_ids"].to(DEVICE)
123
+ # cap_mask = cap_enc["attention_mask"].to(DEVICE)
124
+ # with torch.no_grad():
125
+ # cap_out = caption_encoder(input_ids=cap_ids, attention_mask=cap_mask)
126
+ # cap = cap_out.last_hidden_state[:, 0]
127
+
128
+ # # Предсказание
129
+ # with torch.no_grad():
130
+ # logits = model(v, ocr, cap)
131
+ # probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
132
+
133
+ # result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)}
134
+ # return dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
135
+
136
+
137
+ # # ======================
138
+ # # GRADIO ИНТЕРФЕЙС
139
+ # # ======================
140
+ # demo = gr.Interface(
141
+ # fn=predict,
142
+ # inputs=[
143
+ # gr.Image(type="pil", label="Загрузите изображение"),
144
+ # gr.Textbox(label="Подпись (необязательно)", placeholder="Введите текст подписи...")
145
+ # ],
146
+ # outputs=gr.Label(num_top_classes=5, label="Предсказанные категории"),
147
+ # title="Мультимодальный классификатор контента",
148
+ # description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
149
+ # )
150
+
151
+ # if __name__ == "__main__":
152
+ # demo.launch()
153
+
154
+
155
+
156
+
157
+
158
+ import gradio as gr
159
+ import torch
160
+ import torch.nn as nn
161
+ from torchvision import models, transforms
162
+ from PIL import Image
163
+ from transformers import AutoModel, AutoTokenizer
164
+ import easyocr
165
+ import json
166
+ import os
167
+ import numpy as np
168
+
169
+ import spaces
170
+
171
+ # ======================
172
+ # УСТАНОВКА УСТРОЙСТВА
173
+ # ======================
174
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
175
+ print(f"Using device: {DEVICE}")
176
+
177
+ # ======================
178
+ # ПУТИ
179
+ # ======================
180
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
181
+
182
+ # Загрузка названий классов
183
+ with open(os.path.join(BASE_DIR, "class_names.json"), "r") as f:
184
+ id2label = json.load(f)
185
+ id2label = {int(k): v for k, v in id2label.items()}
186
+
187
+ NUM_CLASSES = len(id2label)
188
+
189
+
190
+ # ======================
191
+ # ОПРЕДЕЛЕНИЕ МОДЕЛИ
192
+ # ======================
193
+ class ConcatFusionModel(nn.Module):
194
+ def __init__(self, num_classes, dropout=0.3):
195
+ super().__init__()
196
+ self.classifier = nn.Sequential(
197
+ nn.Linear(2048 + 312 + 312, 512),
198
+ nn.BatchNorm1d(512),
199
+ nn.ReLU(),
200
+ nn.Dropout(dropout),
201
+ nn.Linear(512, 256),
202
+ nn.BatchNorm1d(256),
203
+ nn.ReLU(),
204
+ nn.Dropout(0.3),
205
+ nn.Linear(256, num_classes)
206
+ )
207
+
208
+ def forward(self, v, ocr, cap):
209
+ x = torch.cat([v, ocr, cap], dim=1)
210
+ return self.classifier(x)
211
+
212
+
213
+ # ======================
214
+ # ЗАГРУЗКА МОДЕЛЕЙ
215
+ # ======================
216
+ @gr.cache_resource
217
+ def load_models():
218
+ # Визуальный энкодер (загружаем предобученный из torchvision)
219
+ visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
220
+ visual.fc = nn.Identity() # убираем классификатор
221
+ visual.to(DEVICE)
222
+ visual.eval()
223
+ for p in visual.parameters():
224
+ p.requires_grad = False
225
+
226
+ # Текстовые энкодеры (загружаем предобученные из Hugging Face)
227
+ tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
228
+ ocr_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
229
+ caption_encoder = AutoModel.from_pretrained("cointegrated/rubert-tiny2").to(DEVICE).eval()
230
+
231
+ for p in ocr_encoder.parameters():
232
+ p.requires_grad = False
233
+ for p in caption_encoder.parameters():
234
+ p.requires_grad = False
235
+
236
+ # Классификационная голова (обученная)
237
+ model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
238
+ model.load_state_dict(torch.load(os.path.join(BASE_DIR, "concat_model.pth"), map_location=DEVICE))
239
+ model.to(DEVICE)
240
+ model.eval()
241
+
242
+ # EasyOCR
243
+ reader = easyocr.Reader(["ru", "en"], gpu=(DEVICE.type == "cuda"))
244
+
245
+ # Трансформы для изображений
246
+ val_transform = transforms.Compose([
247
+ transforms.Resize(256),
248
+ transforms.CenterCrop(224),
249
+ transforms.ToTensor(),
250
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
251
+ ])
252
+
253
+ return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
254
+
255
+
256
+ visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform = load_models()
257
+
258
+
259
+ # ======================
260
+ # ФУНКЦИЯ ПРЕДСКАЗАНИЯ
261
+ # ======================
262
+ @spaces.GPU(duration=60)
263
+ def predict(image, caption_text=""):
264
+ image = image.convert("RGB")
265
+
266
+ # OCR
267
+ ocr_result = reader.readtext(np.array(image), detail=0, paragraph=True)
268
+ ocr_text = " ".join(ocr_result) if ocr_result else ""
269
+
270
+ # Image
271
+ image_tensor = val_transform(image).unsqueeze(0).to(DEVICE)
272
+ with torch.no_grad():
273
+ v = visual(image_tensor)
274
+ v = torch.flatten(v, 1)
275
+
276
+ # OCR encode
277
+ ocr_enc = tokenizer(ocr_text, truncation=True, padding="max_length", max_length=64, return_tensors="pt")
278
+ with torch.no_grad():
279
+ ocr_out = ocr_encoder(
280
+ input_ids=ocr_enc["input_ids"].to(DEVICE),
281
+ attention_mask=ocr_enc["attention_mask"].to(DEVICE)
282
+ )
283
+ ocr = ocr_out.last_hidden_state[:, 0]
284
+
285
+ # Caption encode
286
+ cap_enc = tokenizer(caption_text, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
287
+ with torch.no_grad():
288
+ cap_out = caption_encoder(
289
+ input_ids=cap_enc["input_ids"].to(DEVICE),
290
+ attention_mask=cap_enc["attention_mask"].to(DEVICE)
291
+ )
292
+ cap = cap_out.last_hidden_state[:, 0]
293
+
294
+ # Предсказание
295
+ with torch.no_grad():
296
+ logits = model(v, ocr, cap)
297
+ probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
298
+
299
+ result = {id2label[i]: float(probs[i]) for i in range(NUM_CLASSES)}
300
+ return dict(sorted(result.items(), key=lambda x: x[1], reverse=True))
301
+
302
+
303
+ # ======================
304
+ # GRADIO ИНТЕРФЕЙС
305
+ # ======================
306
+ demo = gr.Interface(
307
+ fn=predict,
308
+ inputs=[
309
+ gr.Image(type="pil", label="📸 Загрузите изображение"),
310
+ gr.Textbox(label="📝 Подпись (необязательно)", placeholder="Введите текст подписи...")
311
+ ],
312
+ outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
313
+ title="Мультимодальный классификатор контента",
314
+ description="Модель анализирует изображение + подпись + текст на картинке (EasyOCR)"
315
+ )
316
+
317
+ if __name__ == "__main__":
318
+ demo.launch()
class_names.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "Животные",
3
+ "1": "Кулинария",
4
+ "2": "Путешествия",
5
+ "3": "Развлечения и юмор",
6
+ "4": "СМИ",
7
+ "5": "Торговля и объявления",
8
+ "6": "Увлечения и хобби",
9
+ "7": "Философия и религия"
10
+ }
concat_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d38f4bc713baa58b9c9e2bfd943ef3fe9f79b17b50f46147fb679ce36625af5
3
+ size 333955845
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ transformers
5
+ easyocr
6
+ Pillow
7
+ numpy
8
+ scikit-learn