moistdio's picture
Upload folder using huggingface_hub
6831a54 verified
import torch
import os
import time
import httpx
import warnings
import gradio.networking
import safetensors.torch
from tqdm import tqdm
def gradio_url_ok_fix(url: str) -> bool:
try:
for _ in range(5):
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
r = httpx.head(url, timeout=999, verify=False)
if r.status_code in (200, 401, 302):
return True
time.sleep(0.500)
except (ConnectionError, httpx.ConnectError):
return False
return False
def build_loaded(module, loader_name):
original_loader_name = loader_name + '_origin'
if not hasattr(module, original_loader_name):
setattr(module, original_loader_name, getattr(module, loader_name))
original_loader = getattr(module, original_loader_name)
def loader(*args, **kwargs):
result = None
try:
result = original_loader(*args, **kwargs)
except Exception as e:
result = None
exp = str(e) + '\n'
for path in list(args) + list(kwargs.values()):
if isinstance(path, str):
if os.path.exists(path):
exp += f'File corrupted: {path} \n'
corrupted_backup_file = path + '.corrupted'
if os.path.exists(corrupted_backup_file):
os.remove(corrupted_backup_file)
os.replace(path, corrupted_backup_file)
if os.path.exists(path):
os.remove(path)
exp += f'Forge has tried to move the corrupted file to {corrupted_backup_file} \n'
exp += f'You may try again now and Forge will download models again. \n'
raise ValueError(exp)
return result
setattr(module, loader_name, loader)
return
def always_show_tqdm(*args, **kwargs):
kwargs['disable'] = False
if 'name' in kwargs:
del kwargs['name']
return tqdm(*args, **kwargs)
def patch_all_basics():
import logging
from huggingface_hub import file_download
file_download.tqdm = always_show_tqdm
from transformers.dynamic_module_utils import logger
logger.setLevel(logging.ERROR)
gradio.networking.url_ok = gradio_url_ok_fix
build_loaded(safetensors.torch, 'load_file')
build_loaded(torch, 'load')
return