magedsar7an committed · Commit 270a762 · verified · 1 Parent(s): d444bd3

Update app.py

Files changed (1): app.py  +212 −0
app.py CHANGED
@@ -0,0 +1,212 @@
# 🎬 Multilingual Video Classification (Beautiful + Voice Icon)
import os, json, base64
from pathlib import Path

import gradio as gr
import torch, cv2, numpy as np
from PIL import Image
from gtts import gTTS
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    AutoTokenizer, AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM
)

# ---------- CONFIG ----------
MODEL_ID = "magedsar7an/caption-cls-en-small"  # ← your HF model repo
FRAMES_PER_VIDEO = 6
FRAME_SIZE = 384
device = "cuda" if torch.cuda.is_available() else "cpu"

SUPPORTED_LANGS = {
    "en": "English", "ar": "Arabic", "fr": "French", "tr": "Turkish",
    "es": "Spanish", "de": "German", "hi": "Hindi", "id": "Indonesian"
}
MARIAN_TO_EN = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "tr": "Helsinki-NLP/opus-mt-tr-en",
    "es": "Helsinki-NLP/opus-mt-es-en",
    "de": "Helsinki-NLP/opus-mt-de-en",
    "hi": "Helsinki-NLP/opus-mt-hi-en",
    "id": "Helsinki-NLP/opus-mt-id-en",
}
LABEL_TRANSLATIONS = {
    "ar": {"clap": "تصفيق", "drink": "يشرب", "hug": "عناق", "kick_ball": "ركل الكرة",
           "kiss": "قبلة", "run": "يجري", "sit": "يجلس", "wave": "يلوح"},
    "tr": {"clap": "alkış", "drink": "içmek", "hug": "sarılmak", "kick_ball": "topa tekme",
           "kiss": "öpücük", "run": "koşmak", "sit": "oturmak", "wave": "el sallamak"},
    "fr": {"clap": "applaudir", "drink": "boire", "hug": "embrasser", "kick_ball": "frapper le ballon",
           "kiss": "baiser", "run": "courir", "sit": "s’asseoir", "wave": "saluer"},
    "es": {"clap": "aplaudir", "drink": "beber", "hug": "abrazar", "kick_ball": "patear la pelota",
           "kiss": "besar", "run": "correr", "sit": "sentarse", "wave": "saludar"},
    "de": {"clap": "klatschen", "drink": "trinken", "hug": "umarmen", "kick_ball": "den Ball treten",
           "kiss": "küssen", "run": "laufen", "sit": "sitzen", "wave": "winken"},
    "hi": {"clap": "ताली बजाना", "drink": "पीना", "hug": "गले लगाना", "kick_ball": "गेंद को मारना",
           "kiss": "चूमना", "run": "दौड़ना", "sit": "बैठना", "wave": "हाथ हिलाना"},
    "id": {"clap": "bertepuk tangan", "drink": "minum", "hug": "berpelukan", "kick_ball": "menendang bola",
           "kiss": "cium", "run": "berlari", "sit": "duduk", "wave": "melambaikan tangan"},
}

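# Note: MARIAN_TO_EN maps each non-English UI language to a Helsinki-NLP MarianMT X→en
# checkpoint (used by translate_to_en below), and LABEL_TRANSLATIONS localizes the
# predicted English label; labels missing from a language dict fall back to English.
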
# ---------- LOAD MODELS ----------
print("Loading BLIP captioner...")
blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device).eval()

print("Loading English classifier from HF Hub...")
tok = AutoTokenizer.from_pretrained(MODEL_ID)
cls = AutoModelForSequenceClassification.from_pretrained(MODEL_ID).to(device).eval()

# id2label from model config (you embedded it during upload)
cfg_map = getattr(cls.config, "id2label", None)
if not cfg_map:
    raise RuntimeError("id2label not found in config.json; add it to your HF model.")
# normalize keys to int
id2label = {int(k): v for k, v in (cfg_map.items() if isinstance(cfg_map, dict) else enumerate(cfg_map))}
print("✅ Models loaded successfully!")

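# For illustration only (the real mapping comes from the uploaded config.json, and the
# index order depends on how the classifier was trained), id2label ends up shaped like:
#   {0: "clap", 1: "drink", 2: "hug", 3: "kick_ball", 4: "kiss", 5: "run", 6: "sit", 7: "wave"}
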
# ---------- HELPERS ----------
def _resolve_video_path(video):
    if isinstance(video, str):
        return video if os.path.exists(video) else None
    if isinstance(video, dict):
        p = video.get("path") or video.get("name")
        return p if (isinstance(p, str) and os.path.exists(p)) else None
    name = getattr(video, "name", None)
    if isinstance(name, str) and os.path.exists(name):
        return name
    return None

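# extract_frames samples k frame indices uniformly across the clip (falling back to the
# first ~240 frames when OpenCV cannot report a frame count) and resizes each frame so its
# shorter side equals `size`, preserving aspect ratio, before converting BGR → RGB.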
def extract_frames(video_path, k=6, size=384):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return []
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
    idxs = np.linspace(0, max(total - 1, 0), num=k, dtype=int) if total > 0 else np.linspace(0, 240, num=k, dtype=int)
    frames = []
    for i in idxs:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, frame = cap.read()
        if not ok or frame is None:
            continue
        h, w = frame.shape[:2]
        if h <= 0 or w <= 0:
            continue
        if h < w:
            new_h = size; new_w = int(w * (size / h))
        else:
            new_w = size; new_h = int(h * (size / w))
        frame = cv2.resize(frame, (new_w, new_h))
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    cap.release()
    return frames

def blip_caption(img):
    inputs = blip_proc(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        out = blip.generate(**inputs, max_new_tokens=30)
    return blip_proc.decode(out[0], skip_special_tokens=True).strip()

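# translate_to_en pushes the captions through the MarianMT checkpoint for the selected
# language in batches of 16; for lang="en", for an unmapped language, or when translation
# fails, the captions are returned unchanged.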
def translate_to_en(texts, lang):
    if lang == "en":
        return texts
    model_name = MARIAN_TO_EN.get(lang)
    if not model_name:
        return texts
    try:
        tok_tr = AutoTokenizer.from_pretrained(model_name)
        mt = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device).eval()
        outs = []
        for i in range(0, len(texts), 16):
            batch = texts[i:i + 16]
            enc = tok_tr(batch, return_tensors="pt", padding=True, truncation=True).to(device)
            with torch.no_grad():
                gen = mt.generate(**enc, max_new_tokens=120)
            outs.extend(tok_tr.batch_decode(gen, skip_special_tokens=True))
        return outs
    except Exception as e:
        print(f"⚠️ Translation failed: {e}")
        return texts

def classify(texts):
    enc = tok(texts, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = cls(**enc).logits
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    return probs

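# End-to-end flow for classify_video below: sample frames → caption each frame with BLIP →
# (optionally) translate captions to English → score all captions with the text classifier →
# average the per-caption probabilities and take the argmax label → localize it and voice it.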
# ---------- MAIN FN ----------
def classify_video(video, lang):
    try:
        if not video:
            return "<div style='color:orange;'>⚠️ Please upload a video first.</div>"

        video_path = _resolve_video_path(video)
        if not video_path:
            return "<div style='color:red;'>❌ Could not find uploaded video path from Gradio input.</div>"

        frames = extract_frames(video_path, FRAMES_PER_VIDEO, FRAME_SIZE)
        if not frames:
            return "<div style='color:red;'>❌ Could not extract frames. OpenCV could not decode the video.</div>"

        captions = [blip_caption(f) for f in frames]
        en_caps = translate_to_en(captions, lang)
        probs = classify(en_caps)
        pred = id2label[int(np.argmax(probs.mean(axis=0)))]
        localized = LABEL_TRANSLATIONS.get(lang, {}).get(pred, pred)

        # 🔊 TTS (fail-soft if blocked)
        audio_b64 = ""
        try:
            tts = gTTS(localized, lang=lang if lang in SUPPORTED_LANGS else "en")
            audio_path = "pred_voice.mp3"
            tts.save(audio_path)
            with open(audio_path, "rb") as f:
                audio_b64 = base64.b64encode(f.read()).decode()
        except Exception as e:
            print(f"⚠️ TTS failed: {e}")

        # 🎨 Card
        lang_name = SUPPORTED_LANGS.get(lang, "Unknown")
        btn = f"<button onclick=\"new Audio('data:audio/mp3;base64,{audio_b64}').play()\" style='background:#00b4d8;color:white;border:none;border-radius:50%;width:70px;height:70px;cursor:pointer;font-size:1.8em;box-shadow:0 2px 10px rgba(0,180,216,0.5);'>🔊</button>" if audio_b64 else ""
        html = f"""
        <div style='background: linear-gradient(135deg,#141e30,#243b55);border-radius:16px;padding:35px;color:white;text-align:center;font-family:"Poppins",sans-serif;box-shadow:0 4px 20px rgba(0,0,0,0.3);'>
            <h2 style='color:#00b4d8;font-weight:600;margin-bottom:10px;'>🎬 Action Detected</h2>
            <h1 style='font-size:2.5em;margin:12px 0;'>{localized}</h1>
            {btn}
            <p style='opacity:0.8;margin-top:14px;font-size:1.1em;'>({lang_name})</p>
        </div>
        """
        return html

    except Exception as e:
        import traceback; traceback.print_exc()
        return f"<div style='color:red;font-weight:bold;'>❌ Error:<br>{e}</div>"

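# The 🔊 button above plays the prediction by embedding the gTTS MP3 as a base64
# data: URI directly in the returned HTML, so no separate audio file has to be served.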
# ---------- GRADIO UI ----------
custom_css = """
.gradio-container {
    background: linear-gradient(135deg,#0f2027,#203a43,#2c5364);
    color: white;
}
h1,h2,h3,label,p,.description {color: white !important;}
footer {display:none !important;}
"""
title = "🎬 Multilingual Video Classification (Beautiful + Voice Icon)"
description = """
Upload your video and choose a language.
The model predicts the action and shows a **beautiful card** 🌍
Click the 🔊 icon to **hear the word pronounced** in that language.
"""
iface = gr.Interface(
    fn=classify_video,
    inputs=[
        gr.Video(label="🎥 Upload Video", sources=["upload"], format="mp4"),
        gr.Radio(choices=list(SUPPORTED_LANGS.keys()), value="en", label="🌍 Choose Language"),
    ],
    outputs=gr.HTML(label="✨ Prediction Result"),
    title=title,
    description=description,
    theme="gradio/soft",
    css=custom_css,
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))