# hg / app.py
# victorgg's picture
# Update app.py
# cf1cb9e verified
import argparse
import os
import cv2
import torch
import numpy as np
from PIL import Image
from insightface.app import FaceAnalysis
import face_align
# Module-level face analyser shared by create_target() and create_source(),
# which call faceAnalysis.get(...) and read the .kps / .embedding of the
# first returned face. It is built from the 'buffalo_l' model pack.
faceAnalysis = FaceAnalysis(name='buffalo_l')
# NOTE(review): ctx_id=-1 presumably selects CPU execution in insightface and
# det_size is presumably the detector input size -- confirm against the
# installed insightface version.
faceAnalysis.prepare(ctx_id=-1, det_size=(512, 512))
from StyleTransferModel_128 import StyleTransferModel
import gradio as gr
def parse_arguments():
    """Parse the command-line options understood by this script.

    Only ``--resolution`` is recognised. Arguments this parser does not know
    are ignored rather than treated as fatal: the function is called from
    inside the Gradio click handler, where ``sys.argv`` belongs to whatever
    launched the app, and ``parse_args()`` would terminate the whole process
    with ``SystemExit`` on the first unknown entry.

    Returns:
        argparse.Namespace: namespace with an integer ``resolution``
        attribute (default 128).
    """
    parser = argparse.ArgumentParser(description='Process command line arguments')
    parser.add_argument('--resolution', type=int, default=128, help='Resolution')  # Removed model path
    # parse_known_args returns (namespace, leftover_args); the leftovers are
    # deliberately discarded.
    args, _unknown = parser.parse_known_args()
    return args
def get_device():
    """Return the torch device used for inference, which is always the CPU."""
    cpu_device = torch.device("cpu")
    return cpu_device
def load_model(model_path):
    """Build the style-transfer network and load its weights from disk.

    Args:
        model_path: Path of the checkpoint file holding the state dict.

    Returns:
        The ``StyleTransferModel`` in eval mode on the device returned by
        ``get_device()``, or ``None`` when ``model_path`` does not exist
        (an error line is printed in that case).
    """
    device = get_device()
    model = StyleTransferModel().to(device)
    try:
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints must be loaded here; consider weights_only=True if the
        # installed torch version supports it.
        state_dict = torch.load(model_path, map_location=device)
    except FileNotFoundError:
        print(f"Error: Model file not found at {model_path}")
        return None

    # strict=False tolerates key mismatches between checkpoint and network;
    # report them so a partially matching checkpoint does not go unnoticed.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: checkpoint is missing keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Warning: checkpoint has unexpected keys: {load_result.unexpected_keys}")

    model.eval()
    return model
def swap_face(model, target_face, source_face_latent):
    """Run the swap model on an aligned target face and a source latent.

    Args:
        model: Callable network invoked as ``model(target, source)``.
        target_face: Aligned target face blob, as a numpy array or a torch
            tensor.
        source_face_latent: Identity latent of the source face, as a numpy
            array or a torch tensor.
            NOTE(review): the latent is forwarded unchanged; confirm the
            model accepts it in the shape the caller provides.

    Returns:
        tuple: ``(swapped_face, swapped_tensor)`` where ``swapped_face`` is
        the result of ``postprocess_face`` and ``swapped_tensor`` is the raw
        model output.
    """
    device = get_device()
    # torch.as_tensor accepts ndarrays and tensors alike. torch.from_numpy
    # raises TypeError for a tensor, and getBlob in this file returns a
    # torch.Tensor.
    target_tensor = torch.as_tensor(target_face).to(device)
    source_tensor = torch.as_tensor(source_face_latent).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        swapped_tensor = model(target_tensor, source_tensor)

    swapped_face = postprocess_face(swapped_tensor)
    return swapped_face, swapped_tensor
def create_target(target_image, resolution):
    """Detect, align and convert the target face into a model input blob.

    Args:
        target_image: Target image; anything ``np.array`` can convert
            (e.g. a PIL image).
        resolution: Edge length in pixels of the square aligned crop.

    Returns:
        tuple: ``(target_face_blob, M)`` where ``target_face_blob`` is the
        output of ``getBlob`` for the aligned crop and ``M`` is the matrix
        returned by ``face_align.norm_crop2`` alongside that crop.

    Raises:
        ValueError: If no face is detected in ``target_image``.
    """
    # Convert once and reuse for both detection and alignment.
    # NOTE(review): np.array on a PIL image yields RGB channel order; confirm
    # that the face analyser is meant to receive RGB rather than BGR.
    target_array = np.array(target_image)

    detected_faces = faceAnalysis.get(target_array)
    if len(detected_faces) == 0:
        # Fail with a clear message instead of an opaque IndexError on [0].
        raise ValueError("No face detected in the target image")

    # Only the first detected face is used.
    target_face = detected_faces[0]
    aligned_target_face, M = face_align.norm_crop2(target_array, target_face.kps, resolution)
    target_face_blob = getBlob(aligned_target_face, (resolution, resolution))
    return target_face_blob, M
def create_source(source_image):
    """Extract the identity latent of the first face in the source image.

    Args:
        source_image: Source image; anything ``np.array`` can convert
            (e.g. a PIL image).

    Returns:
        The value of ``getLatent`` for the first detected face.

    Raises:
        ValueError: If no face is detected in ``source_image``.
    """
    detected_faces = faceAnalysis.get(np.array(source_image))
    if len(detected_faces) == 0:
        # Fail with a clear message instead of an opaque IndexError on [0].
        raise ValueError("No face detected in the source image")

    # Only the first detected face is used.
    source_face = detected_faces[0]
    source_latent = getLatent(source_face)
    return source_latent
def postprocess_face(swapped_tensor):
    """Convert the model output tensor into a PIL image.

    Args:
        swapped_tensor: 4-D torch tensor whose axes are permuted with
            ``(0, 2, 3, 1)`` (channels moved last); values are scaled by 255,
            so they are expected to lie in ``[0, 1]``.

    Returns:
        PIL.Image.Image: image built from the first item of the batch.
    """
    # detach() keeps this working even if the tensor still requires grad.
    face_batch = swapped_tensor.detach().cpu().numpy()
    # Move the channel axis last so each item has image layout.
    face_batch = np.transpose(face_batch, (0, 2, 3, 1))
    # Clip before the uint8 cast: values slightly outside [0, 1] would
    # otherwise wrap around and show up as speckle artifacts.
    face_batch = np.clip(face_batch * 255, 0, 255).astype(np.uint8)
    swapped_face = Image.fromarray(face_batch[0])
    return swapped_face
def getBlob(aligned_face, size):
    """Turn an aligned face image into a float32 batch tensor.

    The image is resized to ``size`` with ``cv2.resize``, divided by 255.0,
    its last axis is moved to the front, and a leading batch axis of length
    one is added.

    Args:
        aligned_face: Image array accepted by ``cv2.resize``.
        size: Target size passed straight to ``cv2.resize``.

    Returns:
        torch.Tensor: float32 tensor with a leading batch axis.
    """
    resized = cv2.resize(aligned_face, size)
    scaled = resized / 255.0
    # Last axis first, then prepend the batch axis.
    batched = scaled.transpose(2, 0, 1)[np.newaxis, ...]
    return torch.from_numpy(batched).float()
def getLatent(source_face):
    """Return the ``embedding`` attribute of a detected face, unchanged."""
    latent = source_face.embedding
    return latent
def blend_swapped_image(swapped_face, target_img, M):
    """Paste the swapped face crop back onto the full target image.

    Args:
        swapped_face: Swapped, aligned face crop (PIL image or array).
        target_img: Full target image (PIL image or array).
        M: 2x3 affine matrix obtained together with the aligned crop.
            NOTE(review): assumed to map target-image coordinates to crop
            coordinates (the usual output of ``norm_crop2``), which is why it
            is inverted below -- confirm against the local ``face_align``.

    Returns:
        numpy.ndarray: the target image with the swapped face blended in,
        same height and width as ``target_img``.
    """
    # Convert before reading .shape: a PIL image has no such attribute.
    target_array = np.array(target_img)
    face_array = np.array(swapped_face)
    height, width = target_array.shape[:2]

    # Pasting the crop back needs the crop -> image direction.
    crop_to_target = cv2.invertAffineTransform(M)
    warped_face = cv2.warpAffine(face_array, crop_to_target, (width, height))

    # Build the mask at crop size and warp it the same way, so it is 255 only
    # where the warped face actually lands and 0 elsewhere.
    crop_mask = np.full(face_array.shape[:2], 255, dtype=np.uint8)
    warped_mask = cv2.warpAffine(crop_mask, crop_to_target, (width, height))

    # Image.blend only accepts a scalar alpha; Image.composite takes a
    # per-pixel mask and picks the first image where the mask is 255.
    blended = Image.composite(
        Image.fromarray(warped_face),
        Image.fromarray(target_array),
        Image.fromarray(warped_mask),
    )
    return np.array(blended)
_MODEL_PATH = "reswapper-429500.pth"  # Hardcoded model path
_RESOLUTION = 128  # Fixed crop resolution used for every request
_loaded_models = {}  # model_path -> loaded model, filled lazily


def _get_model(model_path):
    """Return the model for ``model_path``, loading it from disk at most once."""
    model = _loaded_models.get(model_path)
    if model is None:
        model = load_model(model_path)
        # Failed loads are not cached, so a later request can retry.
        if model is not None:
            _loaded_models[model_path] = model
    return model


def process_images(target_image, source_image):
    """Gradio click handler: put the source face onto the target image.

    Args:
        target_image: Image whose face is replaced.
        source_image: Image providing the identity to insert.

    Returns:
        PIL.Image.Image: the target image with the swapped face blended in.

    Raises:
        gr.Error: If an input image is missing, the model cannot be loaded,
            or an ``IndexError``/``ValueError`` occurs while locating the
            faces. Errors are raised rather than returned as text because
            the handler's output is an image component.
    """
    if target_image is None or source_image is None:
        raise gr.Error("Please provide both a target image and a source image.")

    # The model is cached so the checkpoint is not re-read on every click.
    model = _get_model(_MODEL_PATH)
    if model is None:
        raise gr.Error("Error: Could not load the model. Check the path.")

    try:
        target_face_blob, M = create_target(target_image, _RESOLUTION)
        source_latent = create_source(source_image)
    except (IndexError, ValueError) as err:
        raise gr.Error("Could not find a face in one of the images.") from err

    swapped_face, _ = swap_face(model, target_face_blob, source_latent)
    swapped_face = blend_swapped_image(swapped_face, target_image, M)
    return Image.fromarray(swapped_face)
# Gradio UI: two image inputs, one image output, and a button that runs
# process_images on the two inputs and shows the result.
with gr.Blocks() as demo:
    target_image = gr.Image(type="pil", label="Target Image")
    source_image = gr.Image(type="pil", label="Source Image")
    output_image = gr.Image(type="pil", label="Output Image")
    btn = gr.Button("Swap Face")
    btn.click(
        fn=process_images,
        inputs=[target_image, source_image],
        outputs=output_image,
    )

# Start the web server for the interface defined above.
demo.launch()