File size: 1,051 Bytes
9c4b01e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os
import urllib.request

def download_models():
    ED_MODEL_URL = "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_ed_inference.pth"
    VAE_MODEL_URL = "https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_vae_inference.pth"

    ED_MODEL_PATH = "./pretrained_models/genconvit_ed_inference.pth"
    VAE_MODEL_PATH = "./pretrained_models/genconvit_vae_inference.pth"

    os.makedirs("pretrained_models", exist_ok=True)

    def progress(block_num, block_size, total_size):
        progress_amount = block_num * block_size
        if total_size > 0:
            percent = (progress_amount / total_size) * 100
            print(f"Downloading... {percent:.2f}%")

    if not os.path.isfile(ED_MODEL_PATH):
        print("Downloading ED model")
        urllib.request.urlretrieve(ED_MODEL_URL, ED_MODEL_PATH, reporthook=progress)

    if not os.path.isfile(VAE_MODEL_PATH):
        print("Downloading VAE model")
        urllib.request.urlretrieve(VAE_MODEL_URL, VAE_MODEL_PATH, reporthook=progress)

# Module-level call: the download check runs whenever this module is executed
# AND whenever it is imported.
# NOTE(review): if other code imports this module only to reuse
# download_models(), consider guarding this call with
# `if __name__ == "__main__":` — confirm no caller relies on import-time
# downloading first.
download_models()