Spaces:
Sleeping
Sleeping
import argparse | |
import json | |
import os | |
from pathlib import Path | |
from tempfile import TemporaryDirectory | |
from typing import Optional, Tuple, Union | |
import torch | |
try: | |
from huggingface_hub import ( | |
create_repo, | |
get_hf_file_metadata, | |
hf_hub_download, | |
hf_hub_url, | |
repo_type_and_id_from_hf_id, | |
upload_folder, | |
list_repo_files, | |
) | |
from huggingface_hub.utils import EntryNotFoundError | |
_has_hf_hub = True | |
except ImportError: | |
_has_hf_hub = False | |
try: | |
import safetensors.torch | |
_has_safetensors = True | |
except ImportError: | |
_has_safetensors = False | |
from .factory import create_model_from_pretrained, get_model_config, get_tokenizer | |
from .tokenizer import HFTokenizer | |
# File names used for the artifacts this module writes for a Hugging Face Hub repo.
HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin"  # weights written with torch.save (default)
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"  # weights written with safetensors
HF_CONFIG_NAME = "open_clip_config.json"  # JSON holding the model and preprocess config
def save_config_for_hf(
        model,
        config_path: Union[str, Path],
        model_config: Optional[dict],
):
    """Write the Hub config JSON for ``model`` to ``config_path``.

    The file contains two top-level keys: ``model_cfg`` (the given
    ``model_config``, stored as-is) and ``preprocess_cfg`` (the ``mean`` and
    ``std`` read from ``model.visual.image_mean`` / ``model.visual.image_std``).

    Args:
        model: Object exposing ``visual.image_mean`` and ``visual.image_std``.
        config_path: Destination file, as a string or ``Path``.
        model_config: Model configuration to embed; may be ``None``.
    """
    preprocess_cfg = {
        'mean': model.visual.image_mean,
        'std': model.visual.image_std,
    }
    hf_config = {
        'model_cfg': model_config,
        'preprocess_cfg': preprocess_cfg,
    }

    # Normalise to Path: a plain `str` has no `.open()`, so calling this
    # function with a string path used to raise AttributeError.
    with Path(config_path).open('w') as f:
        json.dump(hf_config, f, indent=2)
def save_for_hf(
        model,
        tokenizer: HFTokenizer,
        model_config: dict,
        save_directory: str,
        safe_serialization: Union[bool, str] = False,
        skip_weights: bool = False,
):
    """Write weights, tokenizer files and the config JSON into a directory.

    Args:
        model: Model whose ``state_dict()`` is serialized.
        tokenizer: Tokenizer saved via its ``save_pretrained`` method.
        model_config: Configuration forwarded to ``save_config_for_hf``.
        save_directory: Output directory; created (with parents) if missing.
        safe_serialization: ``True`` writes only the safetensors file,
            ``False`` writes only the ``torch.save`` file, ``"both"`` writes
            both.
        skip_weights: When true, no weight file is written at all.
    """
    out_dir = Path(save_directory)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not skip_weights:
        state = model.state_dict()

        wants_safetensors = safe_serialization is True or safe_serialization == "both"
        if wants_safetensors:
            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
            safetensors.torch.save_file(state, out_dir / HF_SAFE_WEIGHTS_NAME)

        wants_torch_bin = safe_serialization is False or safe_serialization == "both"
        if wants_torch_bin:
            torch.save(state, out_dir / HF_WEIGHTS_NAME)

    # Tokenizer and config are written regardless of `skip_weights`.
    tokenizer.save_pretrained(out_dir)
    save_config_for_hf(model, out_dir / HF_CONFIG_NAME, model_config=model_config)
def push_to_hf_hub(
        model,
        tokenizer,
        model_config: Optional[dict],
        repo_id: str,
        commit_message: str = 'Add model',
        token: Optional[str] = None,
        revision: Optional[str] = None,
        private: bool = False,
        create_pr: bool = False,
        model_card: Optional[dict] = None,
        safe_serialization: Union[bool, str] = False,
):
    """Serialize a model to a temporary directory and upload it to the Hub.

    Args:
        model: Model to serialize (passed to ``save_for_hf``).
        tokenizer: Tokenizer to upload. Anything that is not an
            ``HFTokenizer`` is replaced by
            ``HFTokenizer('openai/clip-vit-large-patch14')``.
        model_config: Model configuration stored in the config JSON.
        repo_id: Target repo; the owner may be omitted, in which case it is
            resolved from the URL returned by ``create_repo``.
        commit_message: Commit message for the upload.
        token: Hub access token, forwarded to every Hub call made here.
        revision: Revision to check for a README and to upload to.
        private: Whether a newly created repo is private.
        create_pr: Forwarded to ``upload_folder``.
        model_card: Content for a generated ``README.md``; only used when the
            repo has no README at ``revision``.
        safe_serialization: Forwarded to ``save_for_hf``.

    Returns:
        Whatever ``upload_folder`` returns.

    Raises:
        ImportError: If ``huggingface_hub`` could not be imported.
    """
    if not _has_hf_hub:
        # Without this guard the first Hub call fails with an opaque NameError.
        raise ImportError(
            "`pip install huggingface_hub` to push models to the Hugging Face Hub"
        )

    if not isinstance(tokenizer, HFTokenizer):
        # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14
        tokenizer = HFTokenizer('openai/clip-vit-large-patch14')

    # Create repo if it doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

    # Infer complete repo_id from repo_url
    # Can be different from the input `repo_id` if repo_owner was implicit
    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
    repo_id = f"{repo_owner}/{repo_name}"

    # Best-effort probe of the repo; the listing itself is not used, a failure
    # is only reported. The token is forwarded so private repos can be listed.
    try:
        list_repo_files(repo_id, token=token)
    except Exception as e:
        print('Repo does not exist', e)

    # The token is required here as well, otherwise the metadata request for a
    # private repo fails instead of reporting whether README.md exists.
    try:
        get_hf_file_metadata(
            hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision),
            token=token,
        )
        has_readme = True
    except EntryNotFoundError:
        has_readme = False

    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
        save_for_hf(
            model,
            tokenizer=tokenizer,
            model_config=model_config,
            save_directory=tmpdir,
            safe_serialization=safe_serialization,
        )

        # Add readme if it does not exist
        if not has_readme:
            model_card = model_card or {}
            model_name = repo_id.split('/')[-1]
            readme_path = Path(tmpdir) / "README.md"
            readme_text = generate_readme(model_card, model_name)
            # Explicit encoding: model cards may contain non-ASCII text and the
            # platform default encoding is not always UTF-8.
            readme_path.write_text(readme_text, encoding="utf-8")

        # Upload model and return
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            revision=revision,
            create_pr=create_pr,
            commit_message=commit_message,
            token=token,
        )
def push_pretrained_to_hf_hub(
        model_name,
        pretrained: str,
        repo_id: str,
        precision: str = 'fp32',
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        commit_message: str = 'Add model',
        token: Optional[str] = None,
        revision: Optional[str] = None,
        private: bool = False,
        create_pr: bool = False,
        model_card: Optional[dict] = None,
):
    """Build a pretrained model by name and push it to the Hub.

    The model comes from ``create_model_from_pretrained``, its configuration
    from ``get_model_config`` and its tokenizer from ``get_tokenizer``. The
    upload is delegated to ``push_to_hf_hub`` with
    ``safe_serialization='both'``.
    """
    creation_kwargs = dict(
        pretrained=pretrained,
        precision=precision,
        image_mean=image_mean,
        image_std=image_std,
    )
    # The second element of the returned pair is not needed for the upload.
    model, _ = create_model_from_pretrained(model_name, **creation_kwargs)

    model_config = get_model_config(model_name)
    assert model_config

    hub_kwargs = dict(
        repo_id=repo_id,
        commit_message=commit_message,
        token=token,
        revision=revision,
        private=private,
        create_pr=create_pr,
        model_card=model_card,
        safe_serialization='both',
    )
    push_to_hf_hub(
        model=model,
        tokenizer=get_tokenizer(model_name),
        model_config=model_config,
        **hub_kwargs,
    )
def generate_readme(model_card: dict, model_name: str):
    """Render a README (YAML front matter plus Markdown body) for a model.

    Args:
        model_card: Optional keys read here: ``license`` (defaults to
            ``'mit'``), ``description``, ``details`` (mapping; scalar, list /
            tuple and dict values are rendered), ``usage``, ``comparison`` and
            ``citation`` (one entry or a list / tuple of entries).
            ``details['Dataset']`` also feeds the ``datasets:`` front-matter
            list and may be a single value or a list / tuple of values.
        model_name: Name used in the top-level heading.

    Returns:
        The README text as a single string.
    """
    parts = [
        "---\n",
        "tags:\n- clip\n",
        "library_name: open_clip\n",
        "pipeline_tag: zero-shot-image-classification\n",
        f"license: {model_card.get('license', 'mit')}\n",
    ]

    if 'details' in model_card and 'Dataset' in model_card['details']:
        dataset = model_card['details']['Dataset']
        # The details section below accepts list values, so accept them here
        # too instead of failing on `.lower()`.
        dataset_names = dataset if isinstance(dataset, (list, tuple)) else [dataset]
        parts.append('datasets:\n')
        parts.extend(f"- {str(name).lower()}\n" for name in dataset_names)

    parts.append("---\n")
    parts.append(f"# Model card for {model_name}\n")

    if 'description' in model_card:
        parts.append(f"\n{model_card['description']}\n")

    if 'details' in model_card:
        parts.append("\n## Model Details\n")
        for k, v in model_card['details'].items():
            # Nested bullets need a two-space indent; with less, Markdown
            # renders them as siblings of the parent item.
            if isinstance(v, (list, tuple)):
                parts.append(f"- **{k}:**\n")
                parts.extend(f"  - {vi}\n" for vi in v)
            elif isinstance(v, dict):
                parts.append(f"- **{k}:**\n")
                parts.extend(f"  - {ki}: {vi}\n" for ki, vi in v.items())
            else:
                parts.append(f"- **{k}:** {v}\n")

    if 'usage' in model_card:
        parts.append("\n## Model Usage\n")
        parts.append(model_card['usage'])
        parts.append('\n')

    if 'comparison' in model_card:
        parts.append("\n## Model Comparison\n")
        parts.append(model_card['comparison'])
        parts.append('\n')

    if 'citation' in model_card:
        parts.append("\n## Citation\n")
        citation = model_card['citation']
        citations = citation if isinstance(citation, (list, tuple)) else [citation]
        parts.extend(f"```bibtex\n{c}\n```\n" for c in citations)

    return "".join(parts)
# Command-line entry point: push a named pretrained model to the Hugging Face Hub.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Push to Hugging Face Hub")
    # --model and --repo-id are mandatory: without them `None` would be passed
    # into the push path and fail late with an unrelated error.
    parser.add_argument(
        "--model", type=str, required=True, help="Name of the model to use.",
    )
    parser.add_argument(
        "--pretrained", type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--repo-id", type=str, required=True,
        help="Destination HF Hub repo-id ie 'organization/model_id'.",
    )
    parser.add_argument(
        "--precision", type=str, default='fp32',
    )
    parser.add_argument(
        '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
        help='Override default image mean value of dataset')
    parser.add_argument(
        '--image-std', type=float, nargs='+', default=None, metavar='STD',
        help='Override default image std deviation of dataset')
    args = parser.parse_args()

    print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}')

    # FIXME add support to pass model_card json / template from file via cmd line

    push_pretrained_to_hf_hub(
        args.model,
        args.pretrained,
        args.repo_id,
        precision=args.precision,
        image_mean=args.image_mean,  # override image mean/std if trained w/ non defaults
        image_std=args.image_std,
    )

    print(f'{args.model} saved.')