import os | |
import urllib.request | |
def download_models(dest_dir="pretrained_models"):
    """Download the GenConViT ED and VAE inference checkpoints if missing.

    Each checkpoint is fetched from the Deressa/GenConViT Hugging Face repo
    into *dest_dir*, which is created if needed. A checkpoint file that already
    exists is left untouched, so repeated calls only download what is absent.

    Args:
        dest_dir: Directory the checkpoints are stored in. Defaults to
            ``pretrained_models`` relative to the current working directory,
            which is the location the original code hard-coded.

    Raises:
        urllib.error.URLError: if a download fails (``HTTPError`` and
            ``ContentTooShortError`` are subclasses). Any partial file is
            removed before the error propagates.
        OSError: if the directory or file operations fail.
    """
    models = (
        (
            "ED",
            "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_ed_inference.pth",
            "genconvit_ed_inference.pth",
        ),
        (
            "VAE",
            "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_vae_inference.pth",
            "genconvit_vae_inference.pth",
        ),
    )
    os.makedirs(dest_dir, exist_ok=True)
    for label, url, filename in models:
        path = os.path.join(dest_dir, filename)
        if os.path.isfile(path):
            continue  # already downloaded; skip
        print(f"Downloading {label} model")
        _download_file(url, path)


def _download_file(url, path):
    """Download *url* to *path* atomically, printing throttled progress.

    Data is written to ``path + ".part"`` and only renamed to *path* once the
    transfer finishes. An interrupted download therefore never leaves a
    truncated file under the final name, which the existence check in
    ``download_models`` would otherwise mistake for a complete checkpoint.
    """
    tmp_path = path + ".part"
    last_reported = -1

    def progress(block_num, block_size, total_size):
        nonlocal last_reported
        if total_size <= 0:
            return  # no Content-Length from the server; percentage unknown
        # block_num * block_size overshoots on the final (short) block; clamp.
        percent = min(block_num * block_size / total_size * 100, 100.0)
        # Print at most once per whole percent instead of once per block.
        if int(percent) > last_reported:
            last_reported = int(percent)
            print(f"Downloading... {percent:.2f}%")

    try:
        urllib.request.urlretrieve(url, tmp_path, reporthook=progress)
        os.replace(tmp_path, path)
    finally:
        # On success tmp_path has been renamed away; on failure drop the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
# Fetch any missing checkpoints. This runs at module top level, so it executes
# both when the script is run directly and when the module is imported.
# NOTE(review): if this module is ever imported for its function alone, consider
# an `if __name__ == "__main__":` guard — confirm no caller relies on import-time download.
download_models()