fix app.py
Browse files
app.py
CHANGED
|
@@ -152,9 +152,6 @@
|
|
| 152 |
# demo.launch()
|
| 153 |
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
import gradio as gr
|
| 159 |
import torch
|
| 160 |
import torch.nn as nn
|
|
@@ -213,7 +210,7 @@ class ConcatFusionModel(nn.Module):
|
|
| 213 |
# ======================
|
| 214 |
# ЗАГРУЗКА МОДЕЛЕЙ
|
| 215 |
# ======================
|
| 216 |
-
@gr.
|
| 217 |
def load_models():
|
| 218 |
# Визуальный энкодер (загружаем предобученный из torchvision)
|
| 219 |
visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
|
|
@@ -225,8 +222,10 @@ def load_models():
|
|
| 225 |
|
| 226 |
# Текстовые энкодеры (загружаем предобученные из Hugging Face)
|
| 227 |
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
|
| 228 |
-
ocr_encoder = AutoModel.from_pretrained(
|
| 229 |
-
|
|
|
|
|
|
|
| 230 |
|
| 231 |
for p in ocr_encoder.parameters():
|
| 232 |
p.requires_grad = False
|
|
@@ -235,7 +234,8 @@ def load_models():
|
|
| 235 |
|
| 236 |
# Классификационная голова (обученная)
|
| 237 |
model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
|
| 238 |
-
model.load_state_dict(torch.load(os.path.join(
|
|
|
|
| 239 |
model.to(DEVICE)
|
| 240 |
model.eval()
|
| 241 |
|
|
@@ -247,7 +247,8 @@ def load_models():
|
|
| 247 |
transforms.Resize(256),
|
| 248 |
transforms.CenterCrop(224),
|
| 249 |
transforms.ToTensor(),
|
| 250 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
|
|
|
|
| 251 |
])
|
| 252 |
|
| 253 |
return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
|
|
@@ -274,7 +275,8 @@ def predict(image, caption_text=""):
|
|
| 274 |
v = torch.flatten(v, 1)
|
| 275 |
|
| 276 |
# OCR encode
|
| 277 |
-
ocr_enc = tokenizer(ocr_text, truncation=True,
|
|
|
|
| 278 |
with torch.no_grad():
|
| 279 |
ocr_out = ocr_encoder(
|
| 280 |
input_ids=ocr_enc["input_ids"].to(DEVICE),
|
|
@@ -283,7 +285,8 @@ def predict(image, caption_text=""):
|
|
| 283 |
ocr = ocr_out.last_hidden_state[:, 0]
|
| 284 |
|
| 285 |
# Caption encode
|
| 286 |
-
cap_enc = tokenizer(caption_text, truncation=True,
|
|
|
|
| 287 |
with torch.no_grad():
|
| 288 |
cap_out = caption_encoder(
|
| 289 |
input_ids=cap_enc["input_ids"].to(DEVICE),
|
|
@@ -307,7 +310,8 @@ demo = gr.Interface(
|
|
| 307 |
fn=predict,
|
| 308 |
inputs=[
|
| 309 |
gr.Image(type="pil", label="📸 Загрузите изображение"),
|
| 310 |
-
gr.Textbox(label="📝 Подпись (необязательно)",
|
|
|
|
| 311 |
],
|
| 312 |
outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
|
| 313 |
title="Мультимодальный классификатор контента",
|
|
@@ -315,4 +319,4 @@ demo = gr.Interface(
|
|
| 315 |
)
|
| 316 |
|
| 317 |
if __name__ == "__main__":
|
| 318 |
-
demo.launch()
|
|
|
|
| 152 |
# demo.launch()
|
| 153 |
|
| 154 |
|
|
|
|
|
|
|
|
|
|
| 155 |
import gradio as gr
|
| 156 |
import torch
|
| 157 |
import torch.nn as nn
|
|
|
|
| 210 |
# ======================
|
| 211 |
# ЗАГРУЗКА МОДЕЛЕЙ
|
| 212 |
# ======================
|
| 213 |
+
# NOTE(review): `@gr.cache` is invalid — Gradio has no `cache` attribute, so this
# decorator raises AttributeError when app.py is imported. Drop the decorator
# (load_models() already runs once at module load) or use `functools.cache`
# (stdlib, memoizes the zero-argument loader) if repeat-call caching is wanted.
|
| 214 |
def load_models():
|
| 215 |
# Визуальный энкодер (загружаем предобученный из torchvision)
|
| 216 |
visual = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
|
|
|
|
| 222 |
|
| 223 |
# Текстовые энкодеры (загружаем предобученные из Hugging Face)
|
| 224 |
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
|
| 225 |
+
ocr_encoder = AutoModel.from_pretrained(
|
| 226 |
+
"cointegrated/rubert-tiny2").to(DEVICE).eval()
|
| 227 |
+
caption_encoder = AutoModel.from_pretrained(
|
| 228 |
+
"cointegrated/rubert-tiny2").to(DEVICE).eval()
|
| 229 |
|
| 230 |
for p in ocr_encoder.parameters():
|
| 231 |
p.requires_grad = False
|
|
|
|
| 234 |
|
| 235 |
# Классификационная голова (обученная)
|
| 236 |
model = ConcatFusionModel(NUM_CLASSES, dropout=0.3)
|
| 237 |
+
model.load_state_dict(torch.load(os.path.join(
|
| 238 |
+
BASE_DIR, "concat_model.pth"), map_location=DEVICE))
|
| 239 |
model.to(DEVICE)
|
| 240 |
model.eval()
|
| 241 |
|
|
|
|
| 247 |
transforms.Resize(256),
|
| 248 |
transforms.CenterCrop(224),
|
| 249 |
transforms.ToTensor(),
|
| 250 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
|
| 251 |
+
0.229, 0.224, 0.225]),
|
| 252 |
])
|
| 253 |
|
| 254 |
return visual, ocr_encoder, caption_encoder, tokenizer, model, reader, val_transform
|
|
|
|
| 275 |
v = torch.flatten(v, 1)
|
| 276 |
|
| 277 |
# OCR encode
|
| 278 |
+
ocr_enc = tokenizer(ocr_text, truncation=True,
|
| 279 |
+
padding="max_length", max_length=64, return_tensors="pt")
|
| 280 |
with torch.no_grad():
|
| 281 |
ocr_out = ocr_encoder(
|
| 282 |
input_ids=ocr_enc["input_ids"].to(DEVICE),
|
|
|
|
| 285 |
ocr = ocr_out.last_hidden_state[:, 0]
|
| 286 |
|
| 287 |
# Caption encode
|
| 288 |
+
cap_enc = tokenizer(caption_text, truncation=True,
|
| 289 |
+
padding="max_length", max_length=128, return_tensors="pt")
|
| 290 |
with torch.no_grad():
|
| 291 |
cap_out = caption_encoder(
|
| 292 |
input_ids=cap_enc["input_ids"].to(DEVICE),
|
|
|
|
| 310 |
fn=predict,
|
| 311 |
inputs=[
|
| 312 |
gr.Image(type="pil", label="📸 Загрузите изображение"),
|
| 313 |
+
gr.Textbox(label="📝 Подпись (необязательно)",
|
| 314 |
+
placeholder="Введите текст подписи...")
|
| 315 |
],
|
| 316 |
outputs=gr.Label(num_top_classes=5, label="🎯 Предсказанные категории"),
|
| 317 |
title="Мультимодальный классификатор контента",
|
|
|
|
| 319 |
)
|
| 320 |
|
| 321 |
if __name__ == "__main__":
|
| 322 |
+
demo.launch()
|