# mpt-7b / model_download_utils.py
# Committer: irenedea
# Commit 3ff9962 (verified): "LLM-foundry update", March 26, 2024 23:50:31
"""Utility functions for downloading models."""
import copy
import logging
import os
import shutil
import subprocess
import time
import warnings
from http import HTTPStatus
from typing import Optional
from urllib.parse import urljoin
import huggingface_hub as hf_hub
import requests
import tenacity
import yaml
from bs4 import BeautifulSoup
from requests.packages.urllib3.exceptions import InsecureRequestWarning
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME
from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME
# Glob patterns for files that are always skipped when downloading from the Hugging Face Hub.
DEFAULT_IGNORE_PATTERNS = [
    '*.ckpt',
    '*.h5',
    '*.msgpack',
]

# Glob patterns used to exclude one weights format when both are present in a repo.
# They match e.g. `pytorch_model.bin` / `pytorch_model.bin.index.json` and
# `model.safetensors` / `model.safetensors.index.json` respectively.
PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'

# The only files fetched when just the tokenizer is requested.
TOKENIZER_FILES = [
    'special_tokens_map.json',
    'tokenizer.json',
    'tokenizer.model',
    'tokenizer_config.json',
]

# NOTE(review): not referenced anywhere in the visible part of this file; kept for external importers.
ORAS_PASSWD_PLACEHOLDER = '<placeholder_for_passwd>'
# Executable name of the oras CLI, looked up on PATH via shutil.which.
ORAS_CLI = 'oras'

log = logging.getLogger(__name__)
@tenacity.retry(
    retry=tenacity.retry_if_not_exception_type(
        (ValueError, hf_hub.utils.RepositoryNotFoundError),
    ),
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_exponential(min=1, max=10),
)
def download_from_hf_hub(
    model: str,
    save_dir: str,
    prefer_safetensors: bool = True,
    tokenizer_only: bool = False,
    token: Optional[str] = None,
):
    """Downloads model files from a Hugging Face Hub model repo.

    Only supports models stored in Safetensors and PyTorch formats for now. If both formats are available, only the
    Safetensors weights will be downloaded unless `prefer_safetensors` is set to False.

    The call is retried (up to 3 attempts, exponential backoff) unless it fails with a ValueError or a
    RepositoryNotFoundError.

    Args:
        model (str): The Hugging Face Hub repo ID.
        save_dir (str): The local path to the directory where the model files will be downloaded.
        prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
            available. Defaults to True.
        tokenizer_only (bool): If true, only download tokenizer files. The repo must still contain supported
            model weights; otherwise a ValueError is raised.
        token (str, optional): The HuggingFace API token, used both to list the repo files and to download them.
            If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN` environment variable.

    Raises:
        RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized.
        ValueError: If the model repo doesn't contain any supported model weights.
    """
    # The token must be forwarded here as well: listing the files of a private or
    # gated repo fails without it, even if the download below would be authorized.
    repo_files = set(hf_hub.list_repo_files(model, token=token))

    # Copy so that appending a weights pattern below never mutates the module-level default.
    ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS)

    safetensors_available = SAFE_WEIGHTS_NAME in repo_files or SAFE_WEIGHTS_INDEX_NAME in repo_files
    pytorch_available = PYTORCH_WEIGHTS_NAME in repo_files or PYTORCH_WEIGHTS_INDEX_NAME in repo_files

    if safetensors_available and pytorch_available:
        # Both formats exist: skip the non-preferred one to avoid downloading the weights twice.
        if prefer_safetensors:
            log.info('Safetensors available and preferred. Excluding pytorch weights.')
            ignore_patterns.append(PYTORCH_WEIGHTS_PATTERN)
        else:
            log.info('Pytorch available and preferred. Excluding safetensors weights.')
            ignore_patterns.append(SAFE_WEIGHTS_PATTERN)
    elif safetensors_available:
        log.info('Only safetensors available. Ignoring weights preference.')
    elif pytorch_available:
        log.info('Only pytorch available. Ignoring weights preference.')
    else:
        raise ValueError(
            f'No supported model weights found in repo {model}.' +
            ' Please make sure the repo contains either safetensors or pytorch weights.',
        )

    # None means "no allow-list", i.e. download everything not ignored.
    allow_patterns = TOKENIZER_FILES if tokenizer_only else None

    download_start = time.time()
    hf_hub.snapshot_download(
        model,
        local_dir=save_dir,
        local_dir_use_symlinks=False,
        ignore_patterns=ignore_patterns,
        allow_patterns=allow_patterns,
        token=token,
    )
    download_duration = time.time() - download_start
    log.info(f'Downloaded model {model} from Hugging Face Hub in {download_duration} seconds')
def _extract_links_from_html(html: str):
    """Extracts links from HTML content.

    Only `<a>` tags that carry an `href` attribute are considered. Anchors without one
    (e.g. `<a name="top">`) are skipped instead of raising a KeyError.

    Args:
        html (str): The HTML content.

    Returns:
        list[str]: The raw `href` values of the `<a>` tags, in document order.
    """
    soup = BeautifulSoup(html, 'html.parser')
    return [anchor['href'] for anchor in soup.find_all('a', href=True)]
def _resolve_child_path(dir_url: str, dir_path: str, link: str) -> Optional[str]:
    """Maps a link from a directory listing to a download path relative to the base URL.

    Returns None for links that must not be followed:
    - links carrying a query or fragment (e.g. the column-sort links some servers add to listings);
    - links to the directory itself;
    - links resolving outside of it (parent directory, other hosts, absolute paths elsewhere on the server).
    Following those would recurse without bound, save junk files, or write outside of `save_dir`.

    Args:
        dir_url (str): Normalized absolute URL of the directory whose listing contained `link`; ends with '/'.
        dir_path (str): Path of that directory relative to the base URL ('' for the base itself).
        link (str): The raw href value taken from the listing.

    Returns:
        Optional[str]: `dir_path` extended with the link's path below `dir_url`, or None if the link is skipped.
    """
    from urllib.parse import urlsplit

    child_url = urljoin(dir_url, link)
    parts = urlsplit(child_url)
    if parts.query or parts.fragment:
        return None
    if not child_url.startswith(dir_url):
        return None
    suffix = child_url[len(dir_url):]
    # An empty suffix is a self-link. A leading '/' would make os.path.join discard save_dir.
    if not suffix or suffix.startswith('/'):
        return None
    return dir_path + suffix


def _recursive_download(
    session: requests.Session,
    base_url: str,
    path: str,
    save_dir: str,
    ignore_cert: bool = False,
):
    """Downloads all files/subdirectories from a directory on a remote server.

    A URL ending in '/' is treated as a directory. Its response is parsed as an HTML listing, and every link that
    resolves strictly below that directory is downloaded recursively. Any other URL is treated as a file and written
    to `<save_dir>/<path>`.

    Args:
        session: A requests.Session through which to make requests to the remote server.
        base_url (str): The base URL where the files are located.
        path (str): The path from the base URL to the files to download. The full URL for the download is
            `urljoin(base_url, path)`.
        save_dir (str): The directory to save downloaded files to.
        ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server.
            Defaults to False.
            WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.

    Raises:
        PermissionError: If the remote server returns a 401 Unauthorized status code.
        ValueError: If the remote server returns a 404 Not Found status code.
        RuntimeError: If the remote server returns any other status code than 200 OK.
    """
    url = urljoin(base_url, path)
    log.debug(f'Requesting {url}')
    response = session.get(url, verify=not ignore_cert)

    if response.status_code == HTTPStatus.UNAUTHORIZED:
        raise PermissionError(f'Not authorized to download file from {url}. Received status code {response.status_code}. ')
    elif response.status_code == HTTPStatus.NOT_FOUND:
        raise ValueError(f'Could not find file at {url}. Received status code {response.status_code}')
    elif response.status_code != HTTPStatus.OK:
        raise RuntimeError(f'Could not download file from {url}. Received unexpected status code {response.status_code}')

    # A URL without a trailing slash is assumed to be a file.
    if not url.endswith('/'):
        save_path = os.path.join(save_dir, path)
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(save_path, 'wb') as f:
            f.write(response.content)
        log.info(f'Downloaded file {save_path}')
        return

    # Otherwise the response should be an HTML directory listing whose links point to further files/directories.
    child_links = _extract_links_from_html(response.content.decode())
    log.debug(f'Found links in {url}: {child_links}')

    # Normalize the directory URL the same way urljoin normalizes the child URLs, so the prefix check is consistent.
    dir_url = urljoin(url, '.')
    for child_link in child_links:
        child_path = _resolve_child_path(dir_url, path, child_link)
        if child_path is None:
            log.debug(f'Skipping link {child_link!r} in {url}')
            continue
        _recursive_download(session, base_url, child_path, save_dir, ignore_cert=ignore_cert)
@tenacity.retry(
    retry=tenacity.retry_if_not_exception_type((PermissionError, ValueError)),
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_exponential(min=1, max=10),
)
def download_from_http_fileserver(url: str, save_dir: str, ignore_cert: bool = False):
    """Downloads files from a remote HTTP file server.

    The call is retried (up to 3 attempts, exponential backoff) unless it fails with a PermissionError or a
    ValueError.

    Args:
        url (str): The base URL of the directory where the files are located. A trailing '/' is appended if
            missing.
        save_dir (str): The directory to save downloaded files to.
        ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server.
            Defaults to False.
            WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.

    Raises:
        PermissionError: If the server answers 401 Unauthorized.
        ValueError: If the server answers 404 Not Found.
        RuntimeError: If the server answers with any other non-200 status code.
    """
    # urljoin('http://host/models', 'x') yields 'http://host/x': without a trailing slash the base would be
    # treated as a single file, and links from its listing would resolve against the parent directory.
    if not url.endswith('/'):
        url += '/'

    with requests.Session() as session:
        with warnings.catch_warnings():
            if ignore_cert:
                # The user explicitly opted out of verification; don't warn on every request.
                warnings.simplefilter('ignore', category=InsecureRequestWarning)
            _recursive_download(session, url, '', save_dir, ignore_cert=ignore_cert)
def download_from_oras(
    model: str,
    config_file: str,
    credentials_dir: str,
    save_dir: str,
    tokenizer_only: bool = False,
    concurrency: int = 10,
):
    """Download from an OCI-compliant registry using oras.

    Args:
        model (str): The name of the model to download.
        config_file (str): Path to a YAML config file that maps model and tokenizer names to registry paths.
            Model paths are looked up under its `models` key, tokenizer paths under its `tokenizers` key.
        credentials_dir (str): Path to a directory containing credentials for the registry. It is expected to contain three
            files: `username`, `password`, and `registry`, each of which contains the corresponding credential.
        save_dir (str): Path to the directory where files will be downloaded.
        tokenizer_only (bool): If true, only download the tokenizer files.
        concurrency (int): The number of concurrent downloads to run.

    Raises:
        Exception: If the oras CLI cannot be found on PATH.
        ValueError: If one of the credential files cannot be read.
        KeyError: If the config file has no entry for `model` under the selected key.
        subprocess.CalledProcessError: If `oras pull` exits with a non-zero status. The command attached to the
            error contains neither the username nor the password.
    """
    if shutil.which(ORAS_CLI) is None:
        raise Exception(f'oras cli command `{ORAS_CLI}` is not found. Please install oras: https://oras.land/docs/installation ')

    def _read_secrets_file(secret_file_path: str):
        """Returns the whitespace-stripped contents of a credentials file."""
        try:
            with open(secret_file_path, encoding='utf-8') as f:
                return f.read().strip()
        except Exception as error:
            raise ValueError(f'secrets file {secret_file_path} failed to be read') from error

    credentials = {
        name: _read_secrets_file(os.path.join(credentials_dir, name))
        for name in ('username', 'password', 'registry')
    }

    with open(config_file, 'r', encoding='utf-8') as f:
        configs = yaml.safe_load(f.read())

    config_type = 'tokenizers' if tokenizer_only else 'models'
    path = configs[config_type][model]
    registry = credentials['registry']

    cmd_without_creds = [
        ORAS_CLI, 'pull', f'{registry}/{path}', '-o', save_dir, '--verbose', '--concurrency', str(concurrency),
    ]
    log.info(f"CMD for oras cli to run: {' '.join(cmd_without_creds)}")

    # Pipe the password through stdin rather than `--password <value>`, so it never appears in the
    # process table (e.g. `ps`) or in the `cmd` attribute of a CalledProcessError.
    cmd_to_run = cmd_without_creds + ['--username', credentials['username'], '--password-stdin']
    try:
        subprocess.run(cmd_to_run, check=True, input=credentials['password'].encode('utf-8'))
    except subprocess.CalledProcessError as e:
        # `from None` keeps the original error, whose cmd includes the credentials, out of the traceback.
        raise subprocess.CalledProcessError(e.returncode, cmd_without_creds, e.output, e.stderr) from None