# Api_model / API_Model.py
# Committed by MO-12 — commit d00772b ("Update API_Model.py", verified)
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
import os, cv2, uuid, json
import numpy as np
import gdown
model_path = "checkpoint_epoch_1.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Class index -> event name. predict_gradio maps the model's argmax through this dict,
# so the order must match the fine-tuned classifier head.
label_map = {0: "Goal", 1: "Card", 2: "Substitution"}

# Download the fine-tuned checkpoint from Google Drive if it is not on disk yet.
if not os.path.exists(model_path):
    print("Downloading checkpoint...")
    url = "https://drive.google.com/uc?id=1dIaptYPq-1fgo0yoBoPlDsbIfs3BEqJI"
    gdown.download(url, model_path, quiet=False)
    # Fail fast with a clear message rather than a less specific error from torch.load below.
    if not os.path.exists(model_path):
        raise RuntimeError(f"Checkpoint download from {url} did not produce {model_path}")

# Build VideoMAE-base with a classification head sized to label_map, then load the
# fine-tuned weights (load_state_dict is strict, so every key must match).
model = VideoMAEForVideoClassification.from_pretrained(
    "MCG-NJU/videomae-base", num_labels=len(label_map)
)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval().to(device)

# NOTE(review): recent transformers releases deprecate VideoMAEFeatureExtractor in favour
# of VideoMAEImageProcessor; consider migrating.
feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
# Number of frames sampled from each segment and fed to the model as one clip.
_CLIP_FRAMES = 16
# (width, height) every decoded frame is resized to; also the summary video's frame size.
_FRAME_SIZE = (224, 224)


def _save_upload(video, dest_path):
    """Persist a Gradio upload to *dest_path*.

    *video* is either a filesystem path (``str`` or ``os.PathLike``), which is
    copied, or a readable binary file-like object, which is streamed to disk
    in chunks instead of being read into memory at once.
    """
    import shutil

    if isinstance(video, (str, os.PathLike)):
        shutil.copy(video, dest_path)
    else:
        with open(dest_path, "wb") as f:
            shutil.copyfileobj(video, f)


def _iter_segments(cap, segment_size):
    """Yield ``(first_frame_index, frames)`` for consecutive chunks read from *cap*.

    Each chunk holds up to *segment_size* frames resized to ``_FRAME_SIZE``;
    the final chunk may be shorter. Only one chunk is held in memory at a time.
    """
    start = 0
    chunk = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        chunk.append(cv2.resize(frame, _FRAME_SIZE))
        if len(chunk) == segment_size:
            yield start, chunk
            start += segment_size
            chunk = []
    if chunk:
        yield start, chunk


def _classify_segment(segment):
    """Run the model on *segment* and return ``(class_index, confidence)``.

    ``_CLIP_FRAMES`` frames are taken at evenly spaced indices across the
    segment; confidence is the highest softmax probability.
    """
    indices = np.linspace(0, len(segment) - 1, _CLIP_FRAMES).astype(int)
    clip = [segment[idx] for idx in indices]
    # NOTE(review): OpenCV decodes frames in BGR order, while Hugging Face image
    # processors treat arrays as RGB. Confirm which channel order the checkpoint
    # was fine-tuned on before adding cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).
    inputs = feature_extractor(clip, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    confidence, pred = torch.max(probs, dim=1)
    return pred.item(), confidence.item()


def _analyze_video(video_path, out_path, segment_seconds, confidence_threshold):
    """Classify fixed-length segments of *video_path* and write confident ones to *out_path*.

    Returns ``(predictions, summary_written)``. ``summary_written`` is True only
    when an output video was opened successfully and received frames. For an
    unreadable video, predictions is ``[{"error": ...}]``.
    """
    cap = cv2.VideoCapture(video_path)
    writer = None
    summary_written = False
    predictions = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) if cap.isOpened() else 0.0
        if not fps > 0:  # rejects 0, negative values and NaN
            return [{"error": "Invalid or unreadable video."}], False
        # max(1, ...) keeps very low frame rates from producing a zero-length step.
        segment_size = max(1, int(fps * segment_seconds))
        for start, segment in _iter_segments(cap, segment_size):
            if len(segment) < _CLIP_FRAMES:
                continue  # too short to sample a full clip
            class_idx, confidence = _classify_segment(segment)
            # Written as `not >` so a NaN confidence is skipped, as in a plain `>` test.
            if not confidence > confidence_threshold:
                continue
            predictions.append({
                "start": round(start / fps, 2),
                "end": round((start + len(segment)) / fps, 2),
                "label": label_map[class_idx],
                "confidence": round(confidence, 3),
            })
            if writer is None:
                # NOTE(review): many browsers cannot play mp4v-encoded MP4; consider
                # "avc1" if the summary does not display in the UI.
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                writer = cv2.VideoWriter(out_path, fourcc, fps, _FRAME_SIZE)
                summary_written = writer.isOpened()
            if summary_written:
                for frame in segment:
                    writer.write(frame)
    finally:
        cap.release()
        if writer is not None:
            writer.release()
    return predictions, summary_written


def predict_gradio(video, *, segment_seconds=5, confidence_threshold=0.70):
    """Detect Goal / Card / Substitution events in an uploaded video.

    The video is split into consecutive ``segment_seconds``-long segments.
    Each segment of at least ``_CLIP_FRAMES`` frames is classified, and
    segments whose confidence exceeds ``confidence_threshold`` are reported and
    concatenated into a 224x224 summary video.

    Args:
        video: Path to the uploaded file (``str``/``os.PathLike``) or a
            readable binary stream, as handed over by Gradio.
        segment_seconds: Segment length in seconds (keyword-only).
        confidence_threshold: Minimum softmax probability (exclusive) for a
            segment to be reported (keyword-only).

    Returns:
        ``(predictions, summary_path)``. ``predictions`` is a list of dicts with
        ``start``/``end`` (seconds), ``label`` and ``confidence``, or
        ``[{"error": ...}]`` when there is no input or it cannot be decoded.
        ``summary_path`` points to ``./temp/<uuid>/summary.mp4``, which is left
        on disk for the caller to serve, or is ``""`` if no summary was written.
    """
    import shutil

    if video is None:
        # Nothing was uploaded: fail gracefully instead of crashing on None.read().
        return [{"error": "No video provided."}], ""

    # Per-request scratch directory; a written summary is served from here.
    work_dir = f"./temp/{uuid.uuid4()}"
    os.makedirs(work_dir, exist_ok=True)
    input_path = os.path.join(work_dir, "input.mp4")
    out_path = f"{work_dir}/summary.mp4"

    summary_written = False
    try:
        _save_upload(video, input_path)
        predictions, summary_written = _analyze_video(
            input_path, out_path, segment_seconds, confidence_threshold
        )
    finally:
        if summary_written:
            # Only the summary must outlive this call; drop the copied upload.
            os.remove(input_path)
        else:
            # Nothing to serve: remove the whole scratch directory.
            shutil.rmtree(work_dir, ignore_errors=True)

    return predictions, (out_path if summary_written else "")