import contextlib
import os
import urllib.request
# Checkpoints fetched by default, as (label, url, filename) triples.
# The label only appears in the "Downloading <label> model" message.
_DEFAULT_MODELS = (
    (
        "ED",
        "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_ed_inference.pth",
        "genconvit_ed_inference.pth",
    ),
    (
        "VAE",
        "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_vae_inference.pth",
        "genconvit_vae_inference.pth",
    ),
)


def _report_progress(block_num, block_size, total_size):
    """Print the download progress as a percentage (``urlretrieve`` report hook)."""
    if total_size <= 0:
        # Total size is unknown, so no percentage can be computed.
        return
    # block_num * block_size overshoots on the final, shorter block; clamp so
    # the reported value never exceeds 100%.
    received = min(block_num * block_size, total_size)
    print(f"Downloading... {received / total_size * 100:.2f}%")


def _fetch_atomically(url, destination, reporthook):
    """Download *url* to *destination* without leaving a partial file there."""
    partial = destination + ".part"
    try:
        urllib.request.urlretrieve(url, partial, reporthook=reporthook)
        # Only a fully retrieved file is moved into place, so a later
        # os.path.isfile(destination) check never sees a truncated download.
        os.replace(partial, destination)
    finally:
        # After a successful replace the partial file is already gone; after a
        # failure this removes whatever was written so far.
        with contextlib.suppress(FileNotFoundError):
            os.remove(partial)


def download_models(*, model_dir="./pretrained_models", models=None):
    """Download every model checkpoint that is not already present on disk.

    Called with no arguments, this fetches the GenConViT ED and VAE inference
    checkpoints into ``./pretrained_models`` (relative to the current working
    directory), skipping any file that already exists.

    Args:
        model_dir: Directory the checkpoints are stored in. It is created if
            it does not exist.
        models: Iterable of ``(label, url, filename)`` triples describing what
            to fetch. ``None`` selects the default ED and VAE checkpoints.

    Raises:
        OSError: If the directory cannot be created, a download fails, or the
            file cannot be written. No partial file is left at the
            destination path in that case.
    """
    if models is None:
        models = _DEFAULT_MODELS

    os.makedirs(model_dir, exist_ok=True)

    for label, url, filename in models:
        destination = os.path.join(model_dir, filename)
        if os.path.isfile(destination):
            continue
        print(f"Downloading {label} model")
        _fetch_atomically(url, destination, _report_progress)
# Runs at module top level, so any missing checkpoints are fetched whenever
# this file is executed or imported.
download_models()