import spaces
import requests
import tempfile
import os
import logging
import cv2
import pandas as pd
import torch

from genconvit.pred_func import df_face, load_genconvit, pred_vid

torch.hub.set_dir('./cache')
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"


def load_model():
    """Load the GenConViT model with its ED and VAE inference weights."""
    try:
        ed_weight = 'genconvit_ed_inference'
        vae_weight = 'genconvit_vae_inference'
        net = 'genconvit'
        fp16 = False
        model = load_genconvit(net, ed_weight, vae_weight, fp16)
        logging.info("Model loaded successfully.")
        return model
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise


model = load_model()


def detect_faces(video_url):
    """Download a video, sample a frame roughly every five seconds, and save face-annotated frames."""
    try:
        video_name = video_url.split('/')[-1]
        response = requests.get(video_url)
        response.raise_for_status()

        # Write the downloaded video to a temporary file so OpenCV can read it
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
            temp_file.write(response.content)
            temp_file_path = temp_file.name

        frames = []
        face_cascade = cv2.CascadeClassifier('./utils/face_detection.xml')
        cap = cv2.VideoCapture(temp_file_path)

        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 fps if the metadata is missing
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps

        os.makedirs('./output', exist_ok=True)

        frame_count = 0
        time_count = 0
        sample_interval = max(int(fps * 5), 1)  # process one frame every ~5 seconds of video
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % sample_interval == 0:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

                # Draw a bounding box around every detected face
                for (x, y, w, h) in faces:
                    cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)

                frame_name = f"./output/{video_name}_{time_count}.jpg"
                frames.append(frame_name)
                cv2.imwrite(frame_name, frame)
                logging.info(f"Processed frame saved: {frame_name}")
                time_count += 1

            frame_count += 1

        cap.release()
        os.unlink(temp_file_path)

        logging.info(f"Total video duration: {duration:.2f} seconds")
        logging.info(f"Total frames processed: {time_count}")

        return frames
    except Exception as e:
        logging.error(f"Error processing video: {e}")
        return []


def genconvit_video_prediction(video_url, factor):
    """Download a video and return a deepfake score from the GenConViT model.

    `factor` is the fraction of the video's frames to sample for prediction.
    """
    try:
        logging.info(f"Processing video URL: {video_url}")
        response = requests.get(video_url)
        response.raise_for_status()

        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
            temp_file.write(response.content)
            temp_file_path = temp_file.name

        num_frames = get_video_frame_count(temp_file_path)
        frames_to_process = max(round(num_frames * factor), 1)
        logging.info(f"Number of frames in video: {num_frames}")
        logging.info(f"Number of frames to process: {frames_to_process}")

        # Extract face crops from the sampled frames and run the model on them
        df = df_face(temp_file_path, frames_to_process, model)
        if len(df) >= 1:
            y, y_val = pred_vid(df, model)
        else:
            # No faces were detected; fall back to a neutral score
            y, y_val = torch.tensor(0).item(), torch.tensor(0.5).item()

        os.unlink(temp_file_path)

        result = {
            'score': round(y_val * 100, 2),
            'frames_processed': frames_to_process
        }

        logging.info(f"Prediction result: {result}")
        return result
    except Exception as e:
        logging.error(f"Error in video prediction: {e}")
        return {
            'score': 0,
            'prediction': 'ERROR',
            'frames_processed': 0
        }


def get_video_frame_count(video_path):
    """Return the number of frames in the video at `video_path`, or 0 on failure."""
    try:
        cap = cv2.VideoCapture(video_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        return frame_count
    except Exception as e:
        logging.error(f"Error getting video frame count: {e}")
        return 0
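

# Minimal usage sketch, assuming this module is run directly as a script rather than
# imported by the Spaces app. The URL below is a hypothetical placeholder, not part of
# the original code; replace it with a reachable MP4 before running.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    sample_url = "https://example.com/sample.mp4"  # hypothetical placeholder URL
    saved_frames = detect_faces(sample_url)
    print(f"Saved {len(saved_frames)} annotated frames")
    prediction = genconvit_video_prediction(sample_url, factor=0.1)
    print(prediction)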