Spaces:
Runtime error
Runtime error
| import uuid | |
| import logging | |
| import hashlib | |
| import os | |
| import io | |
| import asyncio | |
| from async_lru import alru_cache | |
| import base64 | |
| from queue import Queue | |
| from typing import Dict, Any, List, Optional, Union | |
| from functools import lru_cache | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from liveportrait.config.argument_config import ArgumentConfig | |
| from liveportrait.utils.camera import get_rotation_matrix | |
| from liveportrait.utils.io import resize_to_limit | |
| from liveportrait.utils.crop import prepare_paste_back, paste_back, parse_bbox_from_landmark | |
# Logging setup: DEBUG level, each record carries timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Filesystem locations. DATA_ROOT can be overridden through the environment;
# MODELS_DIR is always the "models" sub-directory of DATA_ROOT.
DATA_ROOT = os.getenv('DATA_ROOT', '/tmp/data')
MODELS_DIR = os.path.join(DATA_ROOT, "models")
def base64_data_uri_to_PIL_Image(base64_string: str) -> Image.Image:
    """
    Decode a base64 string (optionally a full data URI) into a PIL Image.

    When the input contains a comma, only the segment that follows the first
    comma (and precedes any second one) is decoded; this strips a
    "data:<mime>;base64," style prefix. Without a comma the whole string is
    decoded as-is.

    Args:
        base64_string (str): Base64 text, with or without a data-URI prefix.

    Returns:
        Image.Image: The image opened from the decoded bytes.
    """
    has_prefix = ',' in base64_string
    payload = base64_string.split(',')[1] if has_prefix else base64_string
    raw_bytes = base64.b64decode(payload)
    stream = io.BytesIO(raw_bytes)
    return Image.open(stream)
class Engine:
    """
    The main engine class for FacePoke.

    Wraps a LivePortrait pipeline and keeps an in-memory cache of per-image
    intermediate results. `load_image` processes an uploaded image and stores
    the results under a generated id; `transform_image` uses that id to render
    an edited version of the image.
    """

    def __init__(self, live_portrait):
        """
        Initialize the FacePoke engine with necessary models and processors.

        Args:
            live_portrait (LivePortraitPipeline): The LivePortrait model for video generation.
        """
        self.live_portrait = live_portrait
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Maps uid -> processed image data produced by load_image.
        # NOTE(review): nothing in this class ever evicts entries, so the cache
        # grows with every loaded image - confirm whether a bound is needed.
        self.processed_cache = {}

        logger.info("β FacePoke Engine initialized successfully.")

    async def load_image(self, data: bytes) -> Dict[str, Any]:
        """
        Decode an image, run the source-side LivePortrait steps and cache the results.

        Args:
            data: Encoded image file contents, as accepted by `PIL.Image.open`.

        Returns:
            A dict with:
                'u': the id under which the processed data was cached
                     (pass it to `transform_image`),
                'c', 's', 'b', 'a': the 'center', 'size', 'bbox' and 'angle'
                     entries returned by `parse_bbox_from_landmark` for the
                     cropped landmarks.
        """
        image = Image.open(io.BytesIO(data))
        # The downstream steps are fed an array named/treated as RGB; normalise
        # RGBA / grayscale / palette inputs so the channel count is always 3.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_rgb = np.array(image)

        uid = str(uuid.uuid4())

        wrapper = self.live_portrait.live_portrait_wrapper
        inference_cfg = wrapper.cfg

        # Every heavy step runs in a worker thread so the event loop stays responsive.
        img_rgb = await asyncio.to_thread(
            resize_to_limit, img_rgb, inference_cfg.ref_max_shape, inference_cfg.ref_shape_n
        )
        crop_info = await asyncio.to_thread(self.live_portrait.cropper.crop_single_image, img_rgb)
        img_crop_256x256 = crop_info['img_crop_256x256']

        I_s = await asyncio.to_thread(wrapper.prepare_source, img_crop_256x256)
        x_s_info = await asyncio.to_thread(wrapper.get_kp_info, I_s)
        f_s = await asyncio.to_thread(wrapper.extract_feature_3d, I_s)
        x_s = await asyncio.to_thread(wrapper.transform_keypoint, x_s_info)

        processed_data = {
            'img_rgb': img_rgb,
            'crop_info': crop_info,
            'x_s_info': x_s_info,
            'f_s': f_s,
            'x_s': x_s,
            'inference_cfg': inference_cfg
        }
        self.processed_cache[uid] = processed_data

        # Calculate the bounding box
        bbox_info = parse_bbox_from_landmark(crop_info['lmk_crop'], scale=1.0)

        return {
            'u': uid,

            # those aren't easy to serialize
            'c': bbox_info['center'],  # 2x1
            's': bbox_info['size'],  # scalar
            'b': bbox_info['bbox'],  # 4x2
            'a': bbox_info['angle'],  # rad, counterclockwise
            # 'bbox_rot': bbox_info['bbox_rot'].toList(), # 4x2
        }

    @staticmethod
    def _expression_adjustments(params: Dict[str, float]) -> list:
        """
        Build the table of keypoint adjustments for each expression parameter.

        Each entry is `(param_name, [(i, j, k, factor), ...])`, meaning that
        `params[param_name] * factor` is added to keypoint element `[i, j, k]`.
        The factors for 'pupil_x' and 'eyebrow' depend on the sign of the
        corresponding parameter value.
        """
        pupil_x_positive = params.get('pupil_x', 0) > 0
        eyebrow_positive = params.get('eyebrow', 0) > 0
        return [
            ('smile', [
                (0, 20, 1, -0.01), (0, 14, 1, -0.02), (0, 17, 1, 0.0065), (0, 17, 2, 0.003),
                (0, 13, 1, -0.00275), (0, 16, 1, -0.00275), (0, 3, 1, -0.0035), (0, 7, 1, -0.0035)
            ]),
            ('aaa', [
                (0, 19, 1, 0.001), (0, 19, 2, 0.0001), (0, 17, 1, -0.0001)
            ]),
            ('eee', [
                (0, 20, 2, -0.001), (0, 20, 1, -0.001), (0, 14, 1, -0.001)
            ]),
            ('woo', [
                (0, 14, 1, 0.001), (0, 3, 1, -0.0005), (0, 7, 1, -0.0005), (0, 17, 2, -0.0005)
            ]),
            ('wink', [
                (0, 11, 1, 0.001), (0, 13, 1, -0.0003), (0, 17, 0, 0.0003),
                (0, 17, 1, 0.0003), (0, 3, 1, -0.0003)
            ]),
            ('pupil_x', [
                (0, 11, 0, 0.0007 if pupil_x_positive else 0.001),
                (0, 15, 0, 0.001 if pupil_x_positive else 0.0007)
            ]),
            ('pupil_y', [
                (0, 11, 1, -0.001), (0, 15, 1, -0.001)
            ]),
            ('eyes', [
                (0, 11, 1, -0.001), (0, 13, 1, 0.0003), (0, 15, 1, -0.001), (0, 16, 1, 0.0003),
                (0, 1, 1, -0.00025), (0, 2, 1, 0.00025)
            ]),
            ('eyebrow', [
                (0, 1, 1, 0.001 if eyebrow_positive else 0.0003),
                (0, 2, 1, -0.001 if eyebrow_positive else -0.0003),
                (0, 1, 0, 0 if eyebrow_positive else -0.001),
                (0, 2, 0, 0 if eyebrow_positive else 0.001)
            ])
        ]

    @staticmethod
    def _encode_webp(rgb_array) -> bytes:
        """Encode an image array as lossy WebP (quality 82, method 6) and return the bytes."""
        buffered = io.BytesIO()
        Image.fromarray(rgb_array).save(buffered, format="WebP", quality=82, lossless=False, method=6)
        return buffered.getvalue()

    async def transform_image(self, uid: str, params: Dict[str, float]) -> bytes:
        """
        Render an edited version of a previously loaded image.

        Args:
            uid: Id returned as 'u' by `load_image`.
            params: Expression and pose parameters. Recognised keys are
                'smile', 'aaa', 'eee', 'woo', 'wink', 'pupil_x', 'pupil_y',
                'eyes', 'eyebrow', 'rotate_pitch', 'rotate_yaw' and
                'rotate_roll'; missing keys default to 0. The dict is only
                read, never modified.

        Returns:
            The full image (edited face pasted back into the original),
            encoded as WebP bytes.

        Raises:
            ValueError: "cache miss" if `uid` is not in the cache, or
                "Failed to modify image: ..." wrapping any error raised while
                processing (the original exception is attached as the cause).
        """
        # The image must have been processed by load_image first.
        try:
            processed_data = self.processed_cache[uid]
        except KeyError:
            raise ValueError("cache miss") from None

        try:
            wrapper = self.live_portrait.live_portrait_wrapper
            x_s_info = processed_data['x_s_info']

            # Work on a copy so the cached source keypoints stay untouched.
            x_d_new = x_s_info['kp'].clone()

            # Apply modifications based on params
            for param_name, adjustments in self._expression_adjustments(params):
                param_value = params.get(param_name, 0)
                for i, j, k, factor in adjustments:
                    x_d_new[i, j, k] += param_value * factor

            # Special case for pupil_y: an additional shift on top of the
            # 'pupil_y' table entries above.
            # NOTE(review): the previous code also stored an adjusted 'eyes'
            # value back into `params` at this point. Because the adjustments
            # had already been applied, it never influenced the output and only
            # mutated the caller's dict, so it was removed. If a pupil_y -> eyes
            # coupling is wanted it has to be computed before the loop - confirm.
            pupil_y = params.get('pupil_y', 0)
            x_d_new[0, 11, 1] -= pupil_y * 0.001
            x_d_new[0, 15, 1] -= pupil_y * 0.001

            # Apply rotation: requested offsets are added to the source pose.
            R_new = get_rotation_matrix(
                x_s_info['pitch'] + params.get('rotate_pitch', 0),
                x_s_info['yaw'] + params.get('rotate_yaw', 0),
                x_s_info['roll'] + params.get('rotate_roll', 0)
            )
            x_d_new = x_s_info['scale'] * (x_d_new @ R_new) + x_s_info['t']

            # Apply stitching
            x_d_new = await asyncio.to_thread(wrapper.stitching, processed_data['x_s'], x_d_new)

            # Generate the output
            out = await asyncio.to_thread(
                wrapper.warp_decode, processed_data['f_s'], processed_data['x_s'], x_d_new
            )
            I_p = await asyncio.to_thread(wrapper.parse_output, out['out'])

            ####################################################
            # this part is about stitching the image back into the original.
            #
            # this is an expensive operation, not just because of the compute
            # but because the payload will also be bigger (we send back the whole pic)
            #
            # Doing this in the frontend instead is being experimented with; in
            # that case only I_p[0] would need to be encoded here.
            img_rgb = processed_data['img_rgb']
            M_c2o = processed_data['crop_info']['M_c2o']
            mask_ori = await asyncio.to_thread(
                prepare_paste_back,
                processed_data['inference_cfg'].mask_crop, M_c2o,
                dsize=(img_rgb.shape[1], img_rgb.shape[0])
            )
            I_p_to_ori_blend = await asyncio.to_thread(
                paste_back, I_p[0], M_c2o, img_rgb, mask_ori
            )
            ####################################################

            # write it into a webp; encoding is offloaded like the other heavy
            # steps so it does not block the event loop.
            return await asyncio.to_thread(self._encode_webp, I_p_to_ori_blend)

        except Exception as e:
            raise ValueError(f"Failed to modify image: {str(e)}") from e