import torch
import cv2
import numpy as np
from torchvision.transforms.functional import normalize
from tqdm import tqdm
from PIL import Image, ImageOps
import random
import os
import requests
from insightface.app import FaceAnalysis
from facexlib.parsing import init_parsing_model
from typing import Union, Optional, Tuple, List


# --- Helper Functions (Unchanged) ---
def tensor_to_cv2_img(tensor_frame: torch.Tensor) -> np.ndarray:
    """Converts a single RGB torch tensor to a BGR OpenCV image."""
    img_np = (tensor_frame.cpu().numpy() * 255).astype(np.uint8)
    return cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)


def tensor_to_cv2_bgra_img(tensor_frame: torch.Tensor) -> np.ndarray:
    """Converts a single RGBA torch tensor to a BGRA OpenCV image."""
    if tensor_frame.shape[2] != 4:
        raise ValueError("Input tensor must be an RGBA image with 4 channels.")
    img_np = (tensor_frame.cpu().numpy() * 255).astype(np.uint8)
    return cv2.cvtColor(img_np, cv2.COLOR_RGBA2BGRA)


def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Converts a PIL image to a torch tensor."""
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
class VideoMaskGenerator:
    def __init__(self, antelopv2_path=".", device: Optional[torch.device] = None):
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        print(f"Using device: {self.device}")

        providers = ["CUDAExecutionProvider"] if self.device.type == "cuda" else ["CPUExecutionProvider"]

        # Initialize face detection and landmark model (antelopev2 provides both)
        self.detection_model = FaceAnalysis(name="antelopev2", root=antelopv2_path, providers=providers)
        self.detection_model.prepare(ctx_id=0, det_size=(640, 640))

        # Initialize face parsing model
        self.parsing_model = init_parsing_model(model_name="bisenet", device=self.device)
        self.parsing_model.eval()

        print("FaceProcessor initialized successfully.")
    def process(
        self,
        video_path: str,
        face_image: Union[str, Image.Image],
        confidence_threshold: float = 0.5,
        face_crop_scale: float = 1.5,
        dilation_kernel_size: int = 10,
        feather_amount: int = 21,
        random_horizontal_flip_chance: float = 0.0,
        match_angle_and_size: bool = True
    ) -> Tuple[np.ndarray, np.ndarray, int, int, int]:
        """
        Processes a video to replace a face with a provided face image.

        Args:
            video_path (str): Path to the input video file.
            face_image (Union[str, Image.Image]): Path or PIL image of the face to paste.
            confidence_threshold (float): Confidence threshold for face detection.
            face_crop_scale (float): Scale factor for cropping the detected face box.
            dilation_kernel_size (int): Kernel size for mask dilation.
            feather_amount (int): Amount of feathering for the mask edges.
            random_horizontal_flip_chance (float): Chance to flip the source face horizontally.
            match_angle_and_size (bool): Whether to use landmark matching for rotation and scale.

        Returns:
            Tuple[np.ndarray, np.ndarray, int, int, int]:
                - Processed video as a numpy array (F, H, W, C).
                - Generated masks as a numpy array (F, H, W, 1).
                - Width of the processed video.
                - Height of the processed video.
                - Number of frames in the processed video.
        """
        # --- Video Pre-processing ---
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found at: {video_path}")

        cap = cv2.VideoCapture(video_path)
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()

        if not frames:
            raise ValueError("Could not read any frames from the video.")

        video_np = np.array(frames)
        h, w = video_np.shape[1], video_np.shape[2]
        new_h, new_w = (h // 16) * 16, (w // 16) * 16
        y_start = (h - new_h) // 2
        x_start = (w - new_w) // 2
        video_cropped = video_np[:, y_start:y_start + new_h, x_start:x_start + new_w, :]

        num_frames = video_cropped.shape[0]
        target_frames = (num_frames // 4) * 4 + 1
        video_trimmed = video_cropped[:target_frames]
        final_h, final_w, final_frames = video_trimmed.shape[1], video_trimmed.shape[2], video_trimmed.shape[0]
        print(f"Video pre-processed: {final_w}x{final_h}, {final_frames} frames.")

        # --- Face Image Pre-processing & Source Landmark Extraction ---
        if isinstance(face_image, str):
            if face_image.startswith("http"):
                face_image = Image.open(requests.get(face_image, stream=True, timeout=10).raw)
            else:
                face_image = Image.open(face_image)
        face_image = ImageOps.exif_transpose(face_image).convert("RGBA")

        face_rgba_tensor = pil_to_tensor(face_image)
        face_to_paste_cv2 = tensor_to_cv2_bgra_img(face_rgba_tensor)

        source_kpts = None
        if match_angle_and_size:
            # Use insightface (antelopev2) to get landmarks from the source face image
            source_face_bgr = cv2.cvtColor(face_to_paste_cv2, cv2.COLOR_BGRA2BGR)
            source_faces = self.detection_model.get(source_face_bgr)
            if source_faces:
                # Use the landmarks from the first (and likely only) detected face
                source_kpts = source_faces[0].kps
            else:
                print("[Warning] No face or landmarks found in source image. Disabling angle matching.")
                match_angle_and_size = False

        face_to_paste_pil = Image.fromarray((face_rgba_tensor.cpu().numpy() * 255).astype(np.uint8), 'RGBA')
        # --- Main Processing Loop ---
        processed_frames_list = []
        mask_list = []
        for i in tqdm(range(final_frames), desc="Pasting face onto frames"):
            frame_rgb = video_trimmed[i]
            frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)

            # Use insightface for detection and landmarks
            faces = self.detection_model.get(frame_bgr)

            pasted = False
            final_mask = np.zeros((final_h, final_w), dtype=np.uint8)

            if faces:
                largest_face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
                if largest_face.det_score > confidence_threshold:
                    # **MODIFIED BLOCK**: Use insightface landmarks for affine transform
                    if match_angle_and_size and source_kpts is not None:
                        target_kpts = largest_face.kps  # Get landmarks directly from the detected face

                        # Estimate the transformation matrix
                        M, _ = cv2.estimateAffinePartial2D(source_kpts, target_kpts, method=cv2.LMEDS)

                        if M is not None:
                            # Split the RGBA source face for separate warping
                            b, g, r, a = cv2.split(face_to_paste_cv2)
                            source_rgb_cv2 = cv2.merge([r, g, b])

                            # Warp the face and its alpha channel
                            warped_face = cv2.warpAffine(source_rgb_cv2, M, (final_w, final_h))
                            warped_alpha = cv2.warpAffine(a, M, (final_w, final_h))

                            # Blend the warped face onto the frame using the warped alpha channel
                            alpha_float = warped_alpha.astype(np.float32) / 255.0
                            alpha_expanded = np.expand_dims(alpha_float, axis=2)
                            frame_rgb = (1.0 - alpha_expanded) * frame_rgb + alpha_expanded * warped_face
                            frame_rgb = frame_rgb.astype(np.uint8)

                            final_mask = warped_alpha
                            pasted = True

                    # Fallback to simple box-pasting if angle matching is off or fails
                    if not pasted:
                        x1, y1, x2, y2 = map(int, largest_face.bbox)
                        center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
                        side_len = int(max(x2 - x1, y2 - y1) * face_crop_scale)
                        half_side = side_len // 2
                        crop_y1, crop_x1 = max(center_y - half_side, 0), max(center_x - half_side, 0)
                        crop_y2, crop_x2 = min(center_y + half_side, final_h), min(center_x + half_side, final_w)
                        box_w, box_h = crop_x2 - crop_x1, crop_y2 - crop_y1

                        if box_w > 0 and box_h > 0:
                            source_img = face_to_paste_pil.copy()
                            if random.random() < random_horizontal_flip_chance:
                                source_img = source_img.transpose(Image.FLIP_LEFT_RIGHT)
                            face_resized = source_img.resize((box_w, box_h), Image.Resampling.LANCZOS)
                            target_frame_pil = Image.fromarray(frame_rgb)

                            # --- Mask Generation using BiSeNet ---
                            face_crop_bgr = cv2.cvtColor(frame_rgb[crop_y1:crop_y2, crop_x1:crop_x2], cv2.COLOR_RGB2BGR)
                            if face_crop_bgr.size > 0:
                                face_resized_512 = cv2.resize(face_crop_bgr, (512, 512), interpolation=cv2.INTER_AREA)
                                face_rgb_512 = cv2.cvtColor(face_resized_512, cv2.COLOR_BGR2RGB)
                                face_tensor_in = torch.from_numpy(face_rgb_512.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(self.device)
                                with torch.no_grad():
                                    normalized_face = normalize(face_tensor_in, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                    parsing_map = self.parsing_model(normalized_face)[0].argmax(dim=1, keepdim=True)
                                parsing_map_np = parsing_map.squeeze().cpu().numpy().astype(np.uint8)

                                parts_to_include = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]  # All face parts
                                final_mask_512 = np.isin(parsing_map_np, parts_to_include).astype(np.uint8) * 255

                                if dilation_kernel_size > 0:
                                    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilation_kernel_size, dilation_kernel_size))
                                    final_mask_512 = cv2.dilate(final_mask_512, kernel, iterations=1)
                                if feather_amount > 0:
                                    if feather_amount % 2 == 0:
                                        feather_amount += 1
                                    final_mask_512 = cv2.GaussianBlur(final_mask_512, (feather_amount, feather_amount), 0)

                                mask_resized_to_crop = cv2.resize(final_mask_512, (box_w, box_h), interpolation=cv2.INTER_LINEAR)
                                generated_mask_pil = Image.fromarray(mask_resized_to_crop, mode='L')

                                target_frame_pil.paste(face_resized, (crop_x1, crop_y1), mask=generated_mask_pil)
                                frame_rgb = np.array(target_frame_pil)
                                final_mask[crop_y1:crop_y2, crop_x1:crop_x2] = mask_resized_to_crop

            processed_frames_list.append(frame_rgb)
            mask_list.append(final_mask)

        output_video = np.stack(processed_frames_list)
        # Ensure mask has a channel dimension for consistency
        output_masks = np.stack(mask_list)[..., np.newaxis]

        return (output_video, output_masks, final_w, final_h, final_frames)
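

# --- Example usage (illustrative sketch, not part of the original Space code) ---
# Assumptions: the antelopev2 model files live under ./models/antelopev2 (insightface's
# usual layout for root="."), and "input.mp4" / "reference_face.png" are hypothetical
# local files you would replace with your own paths.
if __name__ == "__main__":
    generator = VideoMaskGenerator(antelopv2_path=".")
    video, masks, width, height, num_frames = generator.process(
        video_path="input.mp4",            # hypothetical input video
        face_image="reference_face.png",   # hypothetical face image (converted to RGBA internally)
        confidence_threshold=0.5,
        match_angle_and_size=True,
    )
    print(f"Output video: {video.shape}, masks: {masks.shape}, {width}x{height}, {num_frames} frames")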
