# Spaces: Running on L40S
import asyncio
import contextlib
import logging
import os

import aiohttp
import requests
import torch
from huggingface_hub import hf_hub_download
# Configure logging
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Configuration: model files live under $DATA_ROOT/models
# (DATA_ROOT falls back to /tmp/data when the variable is unset).
DATA_ROOT = os.getenv('DATA_ROOT', '/tmp/data')
MODELS_DIR = os.path.join(DATA_ROOT, "models")
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face repository that hosts every file listed in MODEL_FILES
HF_REPO_ID = "jbilcke-hf/model-cocktail"

# Files to fetch from HF_REPO_ID, as paths relative to the repository root.
# Commented-out entries are intentionally not downloaded.
MODEL_FILES = [
    "dwpose/dw-ll_ucoco_384.pth",
    "face-detector/s3fd-619a316812.pth",

    # LivePortrait weights
    "liveportrait/spade_generator.pth",
    "liveportrait/warping_module.pth",
    "liveportrait/motion_extractor.pth",
    "liveportrait/stitching_retargeting_module.pth",
    "liveportrait/appearance_feature_extractor.pth",
    "liveportrait/landmark.onnx",

    # Animal mode 🐶🐱 is kept disabled: upstream reportedly says animal mode
    # doesn't support stitching yet (TODO: confirm before enabling), see
    # https://github.com/KwaiVGI/LivePortrait/blob/main/assets/docs/changelog/2024-08-02.md#updates-on-animals-mode
    # "liveportrait-animals/warping_module.pth",
    # "liveportrait-animals/spade_generator.pth",
    # "liveportrait-animals/motion_extractor.pth",
    # "liveportrait-animals/appearance_feature_extractor.pth",
    # "liveportrait-animals/stitching_retargeting_module.pth",
    # "liveportrait-animals/xpose.pth",

    # HACK: listing the insightface files here is a stop-gap; the cleaner fix
    # would probably be in liveportrait/utils/dependencies/insightface/utils/storage.py
    "insightface/models/buffalo_l.zip",
    "insightface/buffalo_l/det_10g.onnx",
    "insightface/buffalo_l/2d106det.onnx",

    "sd-vae-ft-mse/diffusion_pytorch_model.bin",
    "sd-vae-ft-mse/diffusion_pytorch_model.safetensors",
    "sd-vae-ft-mse/config.json",

    # Not used yet:
    # "flux-dev/flux-dev-fp8.safetensors",
    # "flux-dev/flux_dev_quantization_map.json",
    # "pulid-flux/pulid_flux_v0.9.0.safetensors",
    # "pulid-flux/pulid_v1.bin",
]
def create_directory(directory):
    """Ensure ``directory`` exists, creating missing parents as needed.

    Logs whether the directory was newly created or was already present.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
        OSError: for any other ``os.makedirs`` failure (e.g. permissions).
    """
    try:
        os.makedirs(directory)
    except FileExistsError:
        # EAFP instead of exists() + makedirs(): the path may be created by
        # someone else between a check and the call. Only an actual
        # directory counts as "already exists"; a regular file is an error.
        if not os.path.isdir(directory):
            raise
        logger.info(f" Directory already exists: {directory}")
    else:
        logger.info(f" Directory created: {directory}")
def print_directory_structure(startpath):
    """Log a tree view of ``startpath`` and everything beneath it.

    Directories are logged as ``name/`` and files as ``name``, indented four
    spaces per nesting level, in ``os.walk`` (top-down) order.
    """
    for root, _dirs, files in os.walk(startpath):
        # Depth relative to startpath. relpath is used instead of
        # str.replace(startpath, ''), which miscounts when startpath ends
        # with a separator or its text also appears deeper in the path.
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        # normpath strips a trailing separator so the top-level directory's
        # name isn't logged as an empty string.
        logger.info(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            logger.info(f"{subindent}{f}")
async def download_hf_file(filename: str) -> None:
    """Download ``filename`` from ``HF_REPO_ID`` into ``MODELS_DIR``.

    The file ends up at ``MODELS_DIR/filename``. If that path already exists
    the download is skipped; existence is the only check, and the contents
    are not verified.

    Args:
        filename: Path of the file inside the Hugging Face repository.

    Raises:
        Exception: anything raised by ``hf_hub_download`` is logged and
            re-raised after a best-effort removal of the destination file.
    """
    dest = os.path.join(MODELS_DIR, filename)
    os.makedirs(os.path.dirname(dest), exist_ok=True)

    if os.path.exists(dest):
        # Already on disk: only reported at DEBUG level.
        logger.debug(f" β {filename}")
        return

    logger.info(f" β³ Downloading {HF_REPO_ID}/{filename}")
    try:
        # hf_hub_download is a synchronous call; run it in a worker thread so
        # downloads awaited together (e.g. via asyncio.gather) don't block
        # the event loop.
        await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=MODELS_DIR,
        )
    except Exception as e:
        logger.error(f"π¨ Error downloading file from Hugging Face: {e}")
        # Best-effort cleanup: a failure here (missing file, directory at
        # that path, permissions) must not mask the original download error.
        with contextlib.suppress(OSError):
            os.remove(dest)
        raise
    else:
        logger.info(f" β Downloaded {filename}")
async def download_all_models():
    """Make sure every entry of MODEL_FILES is present under MODELS_DIR.

    All downloads are started together and awaited as a group; files that
    already exist on disk are skipped by download_hf_file. If a download
    fails, its exception propagates out of this coroutine.
    """
    logger.info(" π Looking for models...")
    await asyncio.gather(*map(download_hf_file, MODEL_FILES))
    logger.info(" β All models are available")
    # Debugging the downloads? Un-comment the two lines below to log the
    # layout of MODELS_DIR once everything has been fetched:
    #logger.info("💡 Printing directory structure of models:")
    #print_directory_structure(MODELS_DIR)
class ModelLoader:
    """Load and initialize the models this module depends on.

    Captures the module-level device and models directory at construction.
    """

    def __init__(self):
        # NOTE(review): neither attribute is passed to LivePortraitPipeline
        # in load_live_portrait; confirm whether the pipeline should use them.
        self.models_dir = MODELS_DIR
        self.device = DEVICE

    async def load_live_portrait(self):
        """Construct a LivePortraitPipeline in a worker thread and return it.

        The liveportrait package is imported lazily, only when this method
        runs. The pipeline is built via asyncio.to_thread so its construction
        does not run on the event loop thread.
        """
        from liveportrait.config.inference_config import InferenceConfig
        from liveportrait.config.crop_config import CropConfig
        from liveportrait.live_portrait_pipeline import LivePortraitPipeline

        logger.info(" β³ Loading LivePortrait models...")

        inference_cfg = InferenceConfig(
            flag_stitching=True,  # stitching is recommended to stay enabled
            flag_relative=True,   # use relative motion
            flag_pasteback=True,  # paste the animated face crop back into the original image space
            flag_do_crop=True,    # crop the source portrait into the face-cropping space
            flag_do_rot=True,     # apply rotation when flag_do_crop is True
        )
        pipeline = await asyncio.to_thread(
            LivePortraitPipeline,
            inference_cfg=inference_cfg,
            crop_cfg=CropConfig(),
        )

        logger.info(" β LivePortrait models loaded successfully.")
        return pipeline
async def initialize_models():
    """Fetch any missing model files, then build and return the LivePortrait pipeline.

    Returns:
        Whatever ``ModelLoader.load_live_portrait`` produces.
    """
    logger.info("π Starting model initialization...")

    # Every file must be on disk before the pipeline tries to load it.
    await download_all_models()

    pipeline = await ModelLoader().load_live_portrait()

    logger.info("β Model initialization completed.")
    return pipeline
# Initial setup, executed once when this module is imported.
def _setup_storage():
    """Create the directory that model files are downloaded into."""
    logger.info("π Setting up storage directories...")
    create_directory(MODELS_DIR)
    logger.info("β Storage directories setup completed.")


_setup_storage()