|
from pathlib import Path |
|
import torch |
|
import urllib |
|
import requests |
|
import subprocess |
|
|
|
def attempt_download_from_hub(repo_id, hf_token=None):
    """Download the first ``.pth`` file listed in a Hugging Face Hub model repo.

    Args:
        repo_id: Repository identifier forwarded to ``list_repo_files`` and
            ``hf_hub_download``.
        hf_token: Optional access token forwarded to both Hub calls.

    Returns:
        Whatever ``hf_hub_download`` returns for the selected file, or ``None``
        when the repository is not found, ``repo_id`` fails validation, or the
        repository lists no file ending in ``.pth``.
    """
    # Imported inside the function so huggingface_hub is only needed when
    # this helper is actually called.
    from huggingface_hub import hf_hub_download, list_repo_files
    # Import the exception types from the public ``utils`` package rather than
    # underscore-prefixed private modules, which are not a stable import path.
    from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError

    try:
        repo_files = list_repo_files(repo_id=repo_id, repo_type='model', token=hf_token)
        # Take the first checkpoint in listing order; ``None`` when there is
        # none (indexing ``[0]`` on an empty list would raise IndexError).
        model_file = next((f for f in repo_files if f.endswith('.pth')), None)
        if model_file is None:
            return None
        return hf_hub_download(
            repo_id=repo_id,
            filename=model_file,
            repo_type='model',
            token=hf_token,
        )
    except (RepositoryNotFoundError, HFValidationError):
        return None
|
|
|
|
|
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
    """Download ``url`` to ``file``, retrying once with ``curl`` on failure.

    The first attempt uses ``torch.hub.download_url_to_file``. If that raises,
    or the resulting file is missing or not larger than ``min_bytes``, the
    partial file is removed and ``curl`` is invoked on ``url2`` (or ``url``
    when ``url2`` is not given).

    Args:
        file: Destination path (str or ``Path``).
        url: Primary download URL.
        url2: Optional URL used for the curl retry instead of ``url``.
        min_bytes: Size threshold used to validate the downloaded file.
        error_msg: Message for the final exception; a default message naming
            the file and ``min_bytes`` is used when empty.

    Raises:
        Exception: If, after all attempts, the file does not exist or is
            smaller than ``min_bytes``. The incomplete file is deleted first.
    """
    file = Path(file)
    assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
    try:
        torch.hub.download_url_to_file(url, str(file), progress=True)
        # Explicit check instead of ``assert`` so it still runs under ``-O``.
        if not (file.exists() and file.stat().st_size > min_bytes):
            raise RuntimeError(assert_msg)
    except Exception:
        # Drop any partial download before retrying.
        file.unlink(missing_ok=True)
        try:
            # Argument list (no shell) so quotes or shell metacharacters in
            # the URL or path cannot break or alter the command.
            subprocess.run(
                ['curl', '-L', url2 or url, '-o', str(file), '--retry', '3', '-C', '-'],
                check=False,
            )
        except OSError:
            # curl could not be started; the size check below reports failure.
            pass
    finally:
        if not file.exists() or file.stat().st_size < min_bytes:
            file.unlink(missing_ok=True)
            raise Exception(error_msg or assert_msg)
|
|
|
def attempt_download(file, repo='Megvii-BaseDetection/YOLOX', release='0.1.0'):
    """Return a local path for ``file``, downloading it if it does not exist.

    Behaviour, in order:
      1. If ``file`` (stripped of whitespace and single quotes) exists, its
         path is returned unchanged.
      2. If ``file`` is an http(s) URL, it is downloaded into the current
         directory under its final path component (query string removed),
         unless a file of that name is already present.
      3. Otherwise the file name is looked up among the assets of a GitHub
         release of ``repo`` and downloaded from there when listed.

    Args:
        file: Local path or http(s) URL (str or ``Path``).
        repo: GitHub ``owner/name`` repository queried for release assets.
        release: Release tag tried first; also the final fallback tag.

    Returns:
        The local file path as a ``str``.

    Raises:
        Exception: Propagated from ``safe_download`` when a download fails.
    """
    # Local import: a bare ``import urllib`` does not guarantee that the
    # ``urllib.parse`` submodule has been loaded.
    from urllib.parse import unquote

    def github_assets(repository, version='latest'):
        """Return ``(tag_name, [asset names])`` for a release of ``repository``."""
        # The latest release lives at ``/releases/latest``; only explicit tags
        # go under ``/releases/tags/<tag>``.
        endpoint = 'latest' if version == 'latest' else f'tags/{version}'
        response = requests.get(
            f'https://api.github.com/repos/{repository}/releases/{endpoint}',
            timeout=10,  # avoid hanging forever on an unresponsive API
        ).json()
        return response['tag_name'], [x['name'] for x in response['assets']]

    file = Path(str(file).strip().replace("'", ''))
    if file.exists():
        return str(file)

    # Decode percent-escapes such as '%2F' before taking the final component.
    name = Path(unquote(str(file))).name

    if str(file).startswith(('http:/', 'https:/')):
        # ``Path`` collapses '//' to '/', so restore the scheme separator.
        url = str(file).replace(':/', '://')
        file = name.split('?')[0]  # drop any query string from the file name
        if not Path(file).is_file():
            safe_download(file=file, url=url, min_bytes=1E5)
        return file

    # Default asset names, used when the GitHub API cannot be queried.
    assets = [
        'yolov6n.pt', 'yolov6s.pt', 'yolov6m.pt', 'yolov6l.pt',
        'yolov6n6.pt', 'yolov6s6.pt', 'yolov6m6.pt', 'yolov6l6.pt']
    try:
        tag, assets = github_assets(repo, release)
    except Exception:
        try:
            tag, assets = github_assets(repo)  # fall back to the latest release
        except Exception:
            try:
                # Last tag printed by ``git tag`` in the current working directory.
                tag = subprocess.check_output(
                    ['git', 'tag'], stderr=subprocess.STDOUT).decode().split()[-1]
            except Exception:
                tag = release

    file.parent.mkdir(parents=True, exist_ok=True)
    if name in assets:
        safe_download(
            file,
            url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
            url2=f'https://storage.googleapis.com/{repo}/{tag}/{name}',
            min_bytes=1E5,
            error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')

    return str(file)