"""Gradio demo that swaps the face in a target image for a source identity.

Pipeline: detect and align the first face in the target image with
insightface, use the first source face's insightface embedding as the
identity latent, run ``StyleTransferModel`` on the aligned crop, and paste
the generated crop back into the target image.
"""

import argparse
import os

import cv2
import gradio as gr
import numpy as np
import torch
from insightface.app import FaceAnalysis
from PIL import Image

import face_align
from StyleTransferModel_128 import StyleTransferModel

# Checkpoint location is hardcoded (the former model-path CLI option was removed).
MODEL_PATH = "reswapper-429500.pth"
# Crop size fed to the model; fixed at 128 to match StyleTransferModel_128.
DEFAULT_RESOLUTION = 128

faceAnalysis = FaceAnalysis(name='buffalo_l')
# ctx_id=-1 runs insightface on CPU, consistent with get_device().
faceAnalysis.prepare(ctx_id=-1, det_size=(512, 512))

# Successfully loaded models keyed by checkpoint path, so each button click
# does not reload the weights from disk. Failed loads are not cached, so a
# missing file can be fixed without restarting the app.
_MODEL_CACHE = {}


def parse_arguments():
    """Parse command-line options (currently only ``--resolution``).

    The Gradio callback does not call this; it uses ``DEFAULT_RESOLUTION``
    so that stray command-line arguments cannot abort a request.
    """
    parser = argparse.ArgumentParser(description='Process command line arguments')
    parser.add_argument('--resolution', type=int, default=128, help='Resolution')
    return parser.parse_args()


def get_device():
    """Return the torch device used for inference (always CPU here)."""
    return torch.device('cpu')


def load_model(model_path):
    """Build ``StyleTransferModel`` and load weights from *model_path*.

    Returns the model in eval mode, or ``None`` if the checkpoint file does
    not exist.
    """
    device = get_device()
    model = StyleTransferModel().to(device)
    try:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}")
        return None
    # strict=False tolerates mismatched keys; report them so a wrong
    # checkpoint does not silently leave layers randomly initialised.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"Warning: checkpoint {model_path} key mismatch; "
            f"missing={list(incompatible.missing_keys)}, "
            f"unexpected={list(incompatible.unexpected_keys)}"
        )
    model.eval()
    return model


def _get_cached_model(model_path):
    """Return a loaded model for *model_path*, loading it on first use (None on failure)."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = load_model(model_path)
        if model is not None:
            _MODEL_CACHE[model_path] = model
    return model


def swap_face(model, target_face, source_face_latent):
    """Run the model on an aligned target crop and a source identity latent.

    Args:
        model: Loaded ``StyleTransferModel``.
        target_face: Target crop as produced by ``getBlob`` (a torch tensor);
            numpy arrays are accepted too.
        source_face_latent: Identity latent from ``getLatent`` (array-like).

    Returns:
        ``(swapped_face, swapped_tensor)``: the generated crop as a PIL image
        and the raw model output tensor.
    """
    device = get_device()
    # getBlob already returns a tensor; as_tensor accepts tensors and arrays
    # alike (torch.from_numpy would reject the tensor).
    target_tensor = torch.as_tensor(target_face, dtype=torch.float32, device=device)
    source_tensor = torch.as_tensor(
        np.asarray(source_face_latent), dtype=torch.float32, device=device
    )
    if source_tensor.dim() == 1:
        # The target blob carries a leading batch axis of 1; give the latent
        # the same batch axis. NOTE(review): confirm the model expects (1, D).
        source_tensor = source_tensor.unsqueeze(0)

    with torch.no_grad():
        swapped_tensor = model(target_tensor, source_tensor)

    swapped_face = postprocess_face(swapped_tensor)
    return swapped_face, swapped_tensor


def _to_rgb_array(image):
    """Return *image* as a numpy array, converting PIL images to RGB first."""
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    return np.array(image)


def _detect_first_face(rgb_image, role):
    """Return the first insightface detection in *rgb_image*.

    Raises:
        ValueError: if no face is detected; *role* names the image in the message.
    """
    # insightface follows the OpenCV BGR channel convention, while Gradio/PIL
    # images are RGB, so detect/embed on a BGR copy.
    faces = faceAnalysis.get(cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
    if not faces:
        raise ValueError(f"No face detected in the {role} image.")
    return faces[0]


def create_target(target_image, resolution):
    """Detect, align and blob the first face of *target_image*.

    Returns:
        ``(target_face_blob, M)``: the model-input tensor from ``getBlob`` and
        the affine matrix returned by ``face_align.norm_crop2`` for the crop.

    Raises:
        ValueError: if no face is found.
    """
    rgb = _to_rgb_array(target_image)
    target_face = _detect_first_face(rgb, "target")
    # Align the RGB pixels (landmarks are identical for RGB/BGR) so the model
    # input stays in the same channel order as the output it is blended with.
    aligned_target_face, M = face_align.norm_crop2(rgb, target_face.kps, resolution)
    target_face_blob = getBlob(aligned_target_face, (resolution, resolution))
    return target_face_blob, M


def create_source(source_image):
    """Return the identity latent of the first face in *source_image*.

    Raises:
        ValueError: if no face is found.
    """
    source_face = _detect_first_face(_to_rgb_array(source_image), "source")
    return getLatent(source_face)


def postprocess_face(swapped_tensor):
    """Convert the model's first output image to a PIL image.

    Assumes an NCHW float tensor with values nominally in [0, 1]; values are
    clipped first so out-of-range outputs saturate instead of wrapping
    around when cast to uint8.
    """
    swapped_array = swapped_tensor.detach().cpu().numpy()
    swapped_array = np.transpose(swapped_array, (0, 2, 3, 1))  # NCHW -> NHWC
    swapped_array = (np.clip(swapped_array, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(swapped_array[0])


def getBlob(aligned_face, size):
    """Resize an HWC uint8 face crop to *size* and return a 1xCxHxW float tensor in [0, 1]."""
    aligned_face = cv2.resize(aligned_face, size)
    aligned_face = aligned_face / 255.0
    aligned_face = np.transpose(aligned_face, (2, 0, 1))  # HWC -> CHW
    aligned_face = np.expand_dims(aligned_face, axis=0)  # add batch axis
    aligned_face = torch.from_numpy(aligned_face).float()
    return aligned_face


def getLatent(source_face):
    """Return the identity latent for an insightface detection (its raw ``embedding``).

    NOTE(review): this is the unnormalised embedding; confirm it matches the
    latent form the checkpoint was trained with (some swappers use
    ``normed_embedding`` or a projected latent instead).
    """
    return source_face.embedding


def blend_swapped_image(swapped_face, target_img, M):
    """Paste the generated crop back into the full target image.

    Args:
        swapped_face: Generated crop (PIL image or HWC uint8 array, RGB).
        target_img: Full target image (PIL image or HWC uint8 array, RGB).
        M: 2x3 affine matrix from ``face_align.norm_crop2`` for this crop.

    Returns:
        The composited image as an HWC uint8 numpy array.
    """
    target = _to_rgb_array(target_img)
    height, width = target.shape[:2]
    face = np.asarray(swapped_face)

    # Assumes M maps target-image coordinates onto the aligned crop (as used
    # by norm_crop2 to cut the crop); invert it to map the crop back.
    inverse_M = cv2.invertAffineTransform(M)
    # Replicate the crop border so interpolation at its edge does not pull in
    # black pixels.
    warped_face = cv2.warpAffine(
        face, inverse_M, (width, height), borderMode=cv2.BORDER_REPLICATE
    )
    # Build the mask at crop size, then warp it the same way, so it covers
    # exactly the region the crop lands on in the target image.
    crop_mask = np.full(face.shape[:2], 255, dtype=np.uint8)
    warped_mask = cv2.warpAffine(crop_mask, inverse_M, (width, height))

    # Image.blend takes a scalar alpha; a per-pixel mask needs Image.composite
    # (mask 255 -> swapped face, 0 -> original target).
    blended = Image.composite(
        Image.fromarray(warped_face),
        Image.fromarray(target),
        Image.fromarray(warped_mask),
    )
    return np.array(blended)


def process_images(target_image, source_image):
    """Gradio callback: swap the source identity into the target image.

    Raises:
        gr.Error: shown in the UI when an image is missing, the model cannot
            be loaded, or no face is detected.
    """
    if target_image is None or source_image is None:
        raise gr.Error("Please provide both a target image and a source image.")

    model = _get_cached_model(MODEL_PATH)
    if model is None:
        raise gr.Error("Error: Could not load the model. Check the path.")

    try:
        target_face_blob, M = create_target(target_image, DEFAULT_RESOLUTION)
        source_latent = create_source(source_image)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    swapped_face, _ = swap_face(model, target_face_blob, source_latent)
    blended = blend_swapped_image(swapped_face, target_image, M)
    return Image.fromarray(blended)


with gr.Blocks() as demo:
    target_image = gr.Image(label="Target Image", type="pil")
    source_image = gr.Image(label="Source Image", type="pil")
    output_image = gr.Image(label="Output Image", type="pil")
    btn = gr.Button("Swap Face")
    btn.click(fn=process_images, inputs=[target_image, source_image], outputs=output_image)

if __name__ == "__main__":
    demo.launch()