Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
from __future__ import annotations | |
import shutil | |
import tempfile | |
import gradio as gr | |
from huggingface_hub import HfApi | |
from huggingface_hub import create_repo | |
from huggingface_hub.repocard import RepoCard | |
# Text shown on the creator page itself (passed to gr.Interface).
title = 'dreambooth space creator'
description = '''
With this Space, you can create a dreambooth gradio demos for models that are loadable with `gradio.Interface.load` in [Model Hub](https://huggingface.co/models).
The Space will be created under your account and private.
You need a token with write permission (See: https://huggingface.co/settings/tokens).
You can specify multiple model names by listing them separated by commas.
If you specify multiple model names, the resulting Space will show all the outputs of those models side by side for the given inputs.
'''
article = ''

# Each row: (space name, comma-separated model names, demo title).
_EXAMPLE_ROWS = (
    (
        'Dungeons-and-Diffusion',
        '0xJustin/Dungeons-and-Diffusion',
        'Demo for 0xJustin/Dungeons-and-Diffusion',
    ),
    (
        'compare-image-classification-models',
        'google/vit-base-patch16-224, microsoft/resnet-50',
        'Compare Image Classification Models',
    ),
    (
        'compare-text-generation-models',
        'EleutherAI/gpt-j-6B, EleutherAI/gpt-neo-1.3B',
        'Compare Text Generation Models',
    ),
)

# Every example fills the six form fields in order; the token, description
# and article fields are left blank.
examples = [
    [example_space, example_models, '', example_title, '', '']
    for example_space, example_models, example_title in _EXAMPLE_ROWS
]
# Module-level Hugging Face Hub client; the model lookup and the Space
# creation code in this file both go through it.
api = HfApi()
def check_if_model_exists(model_name: str) -> bool:
    """Return True if a Hub search for *model_name* yields that exact model id."""
    for candidate in api.list_models(search=model_name):
        # The search may return other models too; only an exact id match counts.
        if candidate.modelId == model_name:
            return True
    return False
def check_if_model_loadable(model_name: str) -> bool:
    """Return True if ``gr.Interface.load`` can load *model_name* from the Hub.

    Any exception raised while loading is treated as "not loadable".
    """
    try:
        gr.Interface.load(model_name, src='models')
    except Exception:
        loadable = False
    else:
        loadable = True
    return loadable
def get_model_io_types(
        model_name: str) -> tuple[tuple[str, ...], tuple[str, ...]]:
    """Load *model_name* as a Gradio interface and describe its components.

    Returns:
        A pair ``(inputs, outputs)`` holding the ``str()`` form of every
        input component and every output component of the loaded interface.
    """
    loaded = gr.Interface.load(model_name, src='models')
    input_types = tuple(str(component) for component in loaded.input_components)
    output_types = tuple(
        str(component) for component in loaded.output_components)
    return input_types, output_types
def check_if_model_io_is_consistent(model_names: list[str]) -> bool:
    """Return True if every model has the same input and output types.

    The first model is the reference; each later model's
    ``get_model_io_types`` result is compared against it, stopping at the
    first mismatch.

    Args:
        model_names: Model ids to compare.

    Returns:
        True when all models agree. A list with fewer than two models
        (including an empty list) is trivially consistent, and no model is
        loaded in that case.
    """
    # Nothing to compare against; also avoids indexing into an empty list.
    if len(model_names) < 2:
        return True
    reference = get_model_io_types(model_names[0])
    return all(
        get_model_io_types(name) == reference for name in model_names[1:])
def save_space_info(dirname: str, filename: str, content: str) -> None:
    """Write *content* to the file *filename* inside directory *dirname*.

    An existing file is overwritten. The text is always encoded as UTF-8 so
    that non-ASCII content is written the same way regardless of the
    platform's default locale encoding.

    Args:
        dirname: Existing directory to write into.
        filename: Name of the file to create inside *dirname*.
        content: Text to store in the file.
    """
    with open(f'{dirname}/{filename}', 'w', encoding='utf-8') as f:
        f.write(content)
def _format_model_list(header: str, names: list[str]) -> str:
    """Return *header* followed by each of *names* on its own line."""
    return header + ''.join(f'\n{name}' for name in names)


def run(space_name: str, model_names_str: str, hf_token: str, title: str,
        description: str, article: str) -> str:
    """Validate the form input and create a private Gradio Space for it.

    The models are checked for existence, loadability and matching
    input/output types. A private Space is then created under the account
    the token belongs to, and the app template plus the TITLE, DESCRIPTION,
    ARTICLE and MODEL_NAMES files are uploaded to it.

    Args:
        space_name: Name of the Space to create (surrounding whitespace is
            ignored).
        model_names_str: Comma-separated model ids. Empty entries, e.g. from
            a trailing comma, are ignored.
        hf_token: Hugging Face token used for every authenticated call
            (surrounding whitespace is ignored).
        title: Text stored in the Space's TITLE file.
        description: Text stored in the Space's DESCRIPTION file.
        article: Text stored in the Space's ARTICLE file.

    Returns:
        A human-readable status message: either the reason the request was
        rejected / failed, or ``'Successfully created: <url>'``.
    """
    space_name = space_name.strip()
    hf_token = hf_token.strip()
    # Drop blank entries so that input such as "a/b, " or ",," never turns
    # into an empty model name.
    model_names = [
        name for name in (part.strip() for part in model_names_str.split(','))
        if name
    ]

    if not space_name:
        return 'Space Name must be specified.'
    if not model_names:
        return 'Model Names must be specified.'
    if not hf_token:
        return 'Hugging Face Token must be specified.'

    missing_models = [
        name for name in model_names if not check_if_model_exists(name)
    ]
    if missing_models:
        return _format_model_list('The following models were not found: ',
                                  missing_models)

    non_loadable_models = [
        name for name in model_names if not check_if_model_loadable(name)
    ]
    if non_loadable_models:
        return _format_model_list(
            'The following models are not loadable with gradio.Interface.load: ',
            non_loadable_models)

    if not check_if_model_io_is_consistent(model_names):
        return 'The inputs and outputs of each model must be the same.'

    try:
        # Resolving the account is part of the guarded section so that an
        # invalid token is reported as a message instead of raising.
        user_name = api.whoami(token=hf_token)['name']
        repo_id = f'{user_name}/{space_name}'
        space_url = api.create_repo(repo_id=repo_id,
                                    repo_type='space',
                                    private=True,
                                    token=hf_token,
                                    space_sdk='gradio')
        # The Space was just created as private, so reading and updating its
        # card must be authenticated with the same token.
        card = RepoCard.load(repo_id, repo_type='space', token=hf_token)
        # Update to any version you like :)
        card.data.sdk_version = '3.7.1'
        # Push! Make sure to specify repo_type
        card.push_to_hub(repo_id, token=hf_token, repo_type='space')
    except Exception as e:
        return str(e)

    with tempfile.TemporaryDirectory() as temp_dir:
        shutil.copy('assets/template.py', f'{temp_dir}/app.py')
        save_space_info(temp_dir, 'TITLE', title)
        save_space_info(temp_dir, 'DESCRIPTION', description)
        save_space_info(temp_dir, 'ARTICLE', article)
        # One model id per line.
        save_space_info(temp_dir, 'MODEL_NAMES', '\n'.join(model_names))
        api.upload_folder(repo_id=repo_id,
                          folder_path=temp_dir,
                          path_in_repo='.',
                          token=hf_token,
                          repo_type='space')

    return f'Successfully created: {space_url}'
# Form fields, in the same order as the positional parameters of run().
input_components = [
    gr.Textbox(
        label='Space Name',
        placeholder=
        'e.g. demo-resnet-50. The Space will be created under your account and private.'
    ),
    gr.Textbox(label='Model Names', placeholder='e.g. microsoft/resnet-50'),
    gr.Textbox(
        label='Hugging Face Token',
        placeholder=
        'This should be a token with write permission. See: https://huggingface.co/settings/tokens'
    ),
    gr.Textbox(label='Title (Optional)'),
    gr.Textbox(label='Description (Optional)'),
    gr.Textbox(label='Article (Optional)'),
]

# Build the creator UI around run() and start serving it when the module is
# executed; the status message returned by run() is shown in a single textbox.
demo = gr.Interface(
    fn=run,
    inputs=input_components,
    outputs=gr.Textbox(label='Output'),
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=False,
)
demo.launch(enable_queue=True, share=False)