# Spaces: Running on L40S
import os | |
import requests | |
from rich.console import Console | |
from tqdm import tqdm | |
import subprocess | |
import sys | |
try: | |
import folder_paths | |
except ModuleNotFoundError: | |
import sys | |
sys.path.append(os.path.join(os.path.dirname(__file__), "../../..")) | |
import folder_paths | |
# Catalogue of model packs offered for download, keyed by the name shown to
# the user. Each entry holds:
#   "size"         - approximate download size (presumably in MB; not read by
#                    the code in this file - TODO confirm where it is used)
#   "download_url" - a single URL or a list of URLs; Google Drive links are
#                    fetched with gdown, everything else over plain HTTP
#   "destination"  - folder name, resolved against folder_paths.models_dir
#                    by main()
models_to_download = {
    "DeepBump": {
        "size": 25.5,
        "download_url": "https://github.com/HugoTini/DeepBump/raw/master/deepbump256.onnx",
        "destination": "deepbump",
    },
    "Face Swap": {
        "size": 660,
        "download_url": [
            "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth",
            "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
            "https://huggingface.co/deepinsight/inswapper/resolve/main/inswapper_128.onnx",
        ],
        "destination": "insightface",
    },
    "GFPGAN (face enhancement)": {
        "size": 332,
        "download_url": [
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
            # TODO: provide a way to selectively download models from "packs".
            # Other GFPGAN checkpoints that could be offered later:
            # https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth
            # https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth
            # https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth
        ],
        "destination": "face_restore",
    },
    "FILM: Frame Interpolation for Large Motion": {
        "size": 402,
        "download_url": [
            # A Google Drive *folder*: fetched as a whole via gdown.download_folder.
            "https://drive.google.com/drive/folders/131_--QrieM4aQbbLWrUtbO2cGbX8-war",
        ],
        "destination": "FILM",
    },
}
from pathlib import Path
from urllib.parse import urlparse

# Shared rich console used for every user-facing message and prompt below.
console = Console()
def download_model(download_url, destination, timeout=60):
    """Download one model file, or every file in a list of URLs, into *destination*.

    Google Drive links are fetched with ``gdown``, which is pip-installed
    from the melMass fork on first use if it is not importable. All other
    URLs are streamed over HTTP with ``requests`` behind a ``tqdm`` progress
    bar. HTTP data is written to ``<file>.part`` first and renamed to its
    final name only once the body has been fully received. A download that
    fails or is interrupted therefore never leaves a truncated file under
    the final name.

    Args:
        download_url: A URL string, or a (possibly nested) list of URL strings
            that are downloaded one after another into the same folder.
        destination: Folder the file(s) are written into. For plain HTTP
            downloads it must already exist.
        timeout: Seconds ``requests`` waits to connect / between received
            bytes before giving up. Only used for non-Google-Drive URLs.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within *timeout*.
        subprocess.CalledProcessError: if installing gdown fails.
    """
    if isinstance(download_url, list):
        for url in download_url:
            download_model(url, destination, timeout=timeout)
        return

    if "drive.google.com" in download_url:
        _download_from_google_drive(download_url, destination)
        return

    filename = os.path.basename(urlparse(download_url).path)
    destination_path = os.path.join(destination, filename)
    # main() skips any file that already exists, so the real name must only
    # appear once the download is complete.
    partial_path = destination_path + ".part"

    with requests.get(download_url, stream=True, timeout=timeout) as response:
        # Without this an HTTP error page would be saved as the model file.
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        try:
            with open(partial_path, "wb") as file, tqdm(
                total=total_size or None,  # None = unknown size for tqdm
                unit="B",
                unit_scale=True,
                desc=destination_path,
                ncols=80,
            ) as progress_bar:
                for data in response.iter_content(chunk_size=4096):
                    file.write(data)
                    progress_bar.update(len(data))
        except BaseException:
            # Includes KeyboardInterrupt: drop the incomplete file, then re-raise.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

    os.replace(partial_path, destination_path)
    console.print(
        f"Downloaded model from {download_url} to {destination_path}",
        style="bold green",
    )


def _download_from_google_drive(download_url, destination):
    """Fetch a Google Drive file or folder into *destination* with gdown (installed on demand)."""
    try:
        import gdown
    except ImportError:
        import importlib

        print("Installing gdown")
        subprocess.check_call(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                "git+https://github.com/melMass/gdown@main",
            ]
        )
        # Let the import system notice the package that was just installed.
        importlib.invalidate_caches()
        import gdown

    if "/folders/" in download_url:
        try:
            gdown.download_folder(download_url, output=destination, resume=True)
        except TypeError:
            # Presumably a gdown version whose download_folder has no `resume`.
            gdown.download_folder(download_url, output=destination)
        return

    # A trailing path separator tells gdown that `output` is a directory, so
    # the remote file name is kept inside it.
    gdown.download(
        download_url, os.path.join(destination, ""), quiet=False, resume=True
    )
def ask_user_for_downloads(models_to_download):
    """Ask the user which of *models_to_download* should be fetched.

    The models are printed with 1-based numbers, and the user answers with
    a comma-separated list of those numbers. A completely blank answer
    selects every model. Empty items inside a non-empty answer (e.g. the
    trailing comma in ``"1,2,"``) are ignored rather than meaning "all".
    Unknown numbers are reported and skipped.

    Args:
        models_to_download: Mapping of model name -> model details, in the
            order they should be listed.

    Returns:
        A new dict of the selected name -> details pairs. The input mapping
        is never returned or modified itself.
    """
    console.print("Choose models to download:")
    choices = {}
    for index, model_name in enumerate(models_to_download, start=1):
        choices[str(index)] = model_name
        console.print(f"{index}. {model_name}")
    console.print(
        "Enter the numbers of the models you want to download (comma-separated):"
    )
    user_input = console.input(">> ")

    if not user_input.strip():
        # Nothing typed: download everything (as a copy, not the caller's dict).
        return dict(models_to_download)

    selected = {}
    for choice in user_input.split(","):
        choice = choice.strip()
        if not choice:
            continue  # stray or trailing comma
        if choice in choices:
            model_name = choices[choice]
            selected[model_name] = models_to_download[model_name]
        else:
            console.print(f"Invalid choice: {choice}. Skipping.")
    return selected
def handle_interrupt():
    """Report in red on the console that the user cancelled the run."""
    message = "Interrupted by user."
    console.print(message, style="bold red")
def main(models_to_download, skip_input=False):
    """Download every model pack whose files are not all present on disk yet.

    Each entry's ``"destination"`` is resolved against
    ``folder_paths.models_dir`` and created if needed. Files that already
    exist are reported and left alone, and only the missing files of a pack
    are downloaded. Unless *skip_input* is true, the user first picks which
    of the incomplete packs to fetch. A Ctrl+C at any point is caught and
    reported instead of printing a traceback.

    Args:
        models_to_download: Mapping of display name -> details dict holding a
            ``"download_url"`` (URL string or list of URLs) and a
            ``"destination"`` folder relative to ``folder_paths.models_dir``.
            The mapping and its entries are not modified.
        skip_input: When true, download every incomplete pack without asking.
    """
    try:
        pending = _collect_missing_models(models_to_download)
        if not pending:
            console.print("No new models to download.")
            return

        if not skip_input:
            pending = ask_user_for_downloads(pending)

        for model_name, model_details in pending.items():
            console.print(f"Downloading {model_name}...")
            download_model(model_details["download_url"], model_details["destination"])
    except KeyboardInterrupt:
        handle_interrupt()


def _flatten_urls(urls):
    """Yield every URL from a URL string or an arbitrarily nested list of them."""
    if isinstance(urls, list):
        for url in urls:
            yield from _flatten_urls(url)
    else:
        yield urls


def _collect_missing_models(models_to_download):
    """Return {name: details} for each pack with at least one file missing on disk.

    The returned details are shallow copies of the input entries. In each
    copy, ``"destination"`` is replaced by the absolute folder (created
    here), and ``"download_url"`` lists only the URLs whose target file does
    not exist yet.
    """
    missing = {}
    for model_name, model_details in models_to_download.items():
        destination = os.path.join(folder_paths.models_dir, model_details["destination"])
        os.makedirs(destination, exist_ok=True)

        missing_urls = []
        for url in _flatten_urls(model_details["download_url"]):
            # NOTE: for Google Drive links the URL path does not end in the
            # real file name (e.g. a folder id), so such packs are always queued.
            filename = os.path.basename(urlparse(url).path)
            if os.path.exists(os.path.join(destination, filename)):
                console.print(
                    f"Checkpoint '{filename}' for {model_name} already exists in '{destination}'"
                )
            else:
                missing_urls.append(url)

        if missing_urls:
            missing[model_name] = {
                **model_details,
                "destination": destination,
                "download_url": missing_urls,
            }
    return missing
if __name__ == "__main__":
    # Command-line entry point: `-y/--yes` downloads every missing pack
    # without the interactive selection prompt.
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-y", "--yes", action="store_true", help="skip user input")
    cli_args = arg_parser.parse_args()
    main(models_to_download, skip_input=cli_args.yes)