# mimosa-ai / prediction.py
# Uploaded by vivekk3 via huggingface_hub (commit 9c4b01e, verified).
import spaces
import requests
import tempfile
import os
import logging
import cv2
import pandas as pd
import torch
# from genconvit.config import load_config
from genconvit.pred_func import df_face, load_genconvit, pred_vid
# Keep downloaded model files in a local ./cache directory: torch.hub's cache
# via set_dir, and the Hugging Face Hub cache via HUGGINGFACE_HUB_CACHE.
# NOTE(review): huggingface_hub reads HUGGINGFACE_HUB_CACHE when it is first
# imported. If genconvit.pred_func (imported above) already pulls in
# huggingface_hub, this assignment may come too late to take effect; setting
# it before that import would be safer — confirm.
torch.hub.set_dir('./cache')
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"
# Set up logging
# NOTE(review): with basicConfig commented out and no other configuration
# visible in this file, the root logger keeps Python's default WARNING level
# (unless the importing application configures it), so the logging.info(...)
# calls in this module are not emitted.
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def load_model(net='genconvit', ed_weight='genconvit_ed_inference',
               vae_weight='genconvit_vae_inference', fp16=False):
    """Load the GenConViT model via ``load_genconvit``.

    The defaults reproduce the configuration that used to be hard-coded here;
    they are exposed as parameters so other checkpoints can be selected
    without editing this function.

    Args:
        net: Network variant name passed to ``load_genconvit``.
        ed_weight: Name of the "ed" weight checkpoint passed to ``load_genconvit``.
        vae_weight: Name of the "vae" weight checkpoint passed to ``load_genconvit``.
        fp16: Passed through to ``load_genconvit`` (presumably selects
            half-precision weights — confirm against genconvit.pred_func).

    Returns:
        The model object returned by ``load_genconvit``.

    Raises:
        Exception: Any error from ``load_genconvit`` is logged together with
            its traceback and then re-raised unchanged.
    """
    try:
        loaded = load_genconvit(net, ed_weight, vae_weight, fp16)
    except Exception as e:
        # logging.exception records the traceback, unlike logging.error.
        logging.exception(f"Error loading model: {e}")
        raise
    logging.info("Model loaded successfully.")
    return loaded


# Loaded once at import time; genconvit_video_prediction uses this
# module-level instance for every call.
model = load_model()
def detect_faces(video_url, interval_seconds=5, output_dir='./output'):
    """Download a video, sample frames periodically, and save them with faces boxed.

    Once every ``interval_seconds`` of video, the current frame is converted
    to grayscale. Faces are located with the Haar cascade at
    ``./utils/face_detection.xml``, and a rectangle with BGR colour
    ``(255, 0, 0)`` is drawn around each detection. The annotated frame is
    written to ``<output_dir>/<video_name>_<n>.jpg``.

    Args:
        video_url: HTTP(S) URL of the video. Its last ``/``-separated segment
            is used as the prefix of the saved image names.
        interval_seconds: Spacing between sampled frames, in seconds of video.
            Defaults to 5, the interval that used to be hard-coded.
        output_dir: Directory for the annotated frames; created if missing.

    Returns:
        Paths of the frames that were successfully written, in order. An
        empty list is returned on any error (the error is logged with its
        traceback).
    """
    temp_file_path = None
    cap = None
    try:
        video_name = video_url.split('/')[-1]

        # Stream the download to disk rather than holding the whole video in
        # memory, and never wait forever on a stalled server.
        with requests.get(video_url, stream=True, timeout=60) as response:
            response.raise_for_status()  # Raise an exception for HTTP errors
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
                temp_file_path = temp_file.name
                for chunk in response.iter_content(chunk_size=1 << 20):
                    temp_file.write(chunk)

        face_cascade = cv2.CascadeClassifier('./utils/face_detection.xml')
        if face_cascade.empty():
            # An empty classifier would make detectMultiScale fail on every frame.
            raise RuntimeError("Could not load face cascade ./utils/face_detection.xml")

        cap = cv2.VideoCapture(temp_file_path)
        if not cap.isOpened():
            raise RuntimeError(f"Could not open downloaded video: {video_url}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Both the duration and the sampling step below divide by / derive
            # from fps; a 0 value would raise ZeroDivisionError.
            raise ValueError(f"Video reports an invalid FPS value: {fps}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps
        # Never let the step reach 0 (very low FPS or a tiny interval).
        step = max(1, int(fps * interval_seconds))

        os.makedirs(output_dir, exist_ok=True)
        frames = []
        sample_count = 0
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % step == 0:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                faces = face_cascade.detectMultiScale(
                    gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
                )
                for (x, y, w, h) in faces:
                    cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
                frame_name = os.path.join(output_dir, f"{video_name}_{sample_count}.jpg")
                # imwrite signals failure through its return value, not an exception.
                if cv2.imwrite(frame_name, frame):
                    frames.append(frame_name)
                    logging.info(f"Processed frame saved: {frame_name}")
                else:
                    logging.warning(f"Failed to write frame: {frame_name}")
                sample_count += 1
            frame_count += 1

        # No cv2.destroyAllWindows(): no window is ever opened here, and on
        # headless OpenCV builds that call raises, which discarded the result.
        logging.info(f"Total video duration: {duration:.2f} seconds")
        logging.info(f"Total frames processed: {sample_count}")
        return frames
    except Exception as e:
        logging.exception(f"Error processing video: {e}")
        return []
    finally:
        # Release the capture and remove the downloaded file on every path.
        if cap is not None:
            cap.release()
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                logging.warning(f"Could not remove temporary file: {temp_file_path}")
# @spaces.GPU(duration=300)
def genconvit_video_prediction(video_url, factor, frames_to_sample=11):
    """Download a video and score it with the module-level GenConViT ``model``.

    Args:
        video_url: HTTP(S) URL of the video to analyse.
        factor: Fraction of the video's frames. It is used only to compute the
            logged "frames to process" figure and the returned
            ``frames_processed`` value.
        frames_to_sample: Frame count passed to ``df_face``. Defaults to 11,
            the value that used to be hard-coded here.

    Returns:
        On success, ``{'score': ..., 'frames_processed': ...}``. Here
        ``score`` is the second value returned by ``pred_vid``, multiplied by
        100 and rounded to two decimals. It is 50.0 when ``df_face`` returns
        no rows. On any error the result is
        ``{'score': 0, 'prediction': 'ERROR', 'frames_processed': 0}`` and
        the error is logged with its traceback.
    """
    temp_file_path = None
    try:
        logging.info(f"Processing video URL: {video_url}")
        # Stream the download to disk rather than holding the whole video in
        # memory, and never wait forever on a stalled server.
        with requests.get(video_url, stream=True, timeout=60) as response:
            response.raise_for_status()  # Raise an exception for HTTP errors
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
                temp_file_path = temp_file.name
                for chunk in response.iter_content(chunk_size=1 << 20):
                    temp_file.write(chunk)

        num_frames = get_video_frame_count(temp_file_path)
        frames_reported = round(num_frames * factor)
        logging.info(f"Number of frames in video: {num_frames}")
        logging.info(f"Number of frames to process: {frames_reported}")

        # NOTE(review): df_face receives frames_to_sample, not
        # num_frames * factor, so 'frames_processed' below does not describe
        # the work actually done. It is kept unchanged for API compatibility —
        # confirm which value clients expect.
        df = df_face(temp_file_path, frames_to_sample, model)
        if len(df) >= 1:
            _, y_val = pred_vid(df, model)
        else:
            # Nothing to score: fall back to 0.5, as before.
            y_val = 0.5

        result = {
            'score': round(y_val * 100, 2),
            'frames_processed': frames_reported,
        }
        logging.info(f"Prediction result: {result}")
        return result
    except Exception as e:
        logging.exception(f"Error in video prediction: {e}")
        return {
            'score': 0,
            'prediction': 'ERROR',
            'frames_processed': 0
        }
    finally:
        # Remove the downloaded file on success AND on error (it used to leak
        # whenever an exception occurred).
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                logging.warning(f"Could not remove temporary file: {temp_file_path}")
def get_video_frame_count(video_path):
    """Return the frame count OpenCV reports for the video at *video_path*.

    The value comes from ``CAP_PROP_FRAME_COUNT`` and may be inaccurate for
    some containers. Returns 0 in three cases: the video cannot be opened,
    the reported count is negative, or any error occurs (the error is
    logged with its traceback).
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logging.warning(f"Could not open video: {video_path}")
            return 0
        # Defensive clamp: some backends can report a negative count when the
        # stream length is unknown.
        return max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
    except Exception as e:
        logging.exception(f"Error getting video frame count: {e}")
        return 0
    finally:
        # Release the capture even if reading the property raised.
        if cap is not None:
            cap.release()