hysts
Check model inputs/outputs
8326154
raw
history blame
4.77 kB
#!/usr/bin/env python
from __future__ import annotations
import shutil
import tempfile
import gradio as gr
from huggingface_hub import HfApi
title = 'Model Demo Creation'
description = '''
With this Space, you can create a demo Space for models that are loadable with `gradio.Interface.load` in Model Hub.
'''
article = ''

# Example rows for the form below. Each row lines up positionally with the six
# interface inputs: (space name, comma-separated model names, token, title,
# description, article). Only the model names and the title are pre-filled.
examples = [
    ['', model_list, '', demo_title, '', '']
    for model_list, demo_title in (
        ('google/vit-base-patch16-224',
         'Demo for google/vit-base-patch16-224'),
        ('google/vit-base-patch16-224, microsoft/resnet-50',
         'Compare Image Classification Models'),
        ('EleutherAI/gpt-j-6B, EleutherAI/gpt-neo-1.3B',
         'Compare Text Generation Models'),
    )
]
# Module-wide Hugging Face Hub client, shared by the model search in
# check_if_model_exists and the repo creation/upload in run. It is constructed
# without a token; the write operations in run pass the user's token
# explicitly on each call.
api = HfApi()
def check_if_model_exists(model_name: str) -> bool:
    """Return True if a Hub search for *model_name* finds that exact model id.

    The search results are scanned in order and the scan stops at the first
    result whose ``modelId`` equals *model_name* (case-sensitive comparison).
    """
    for candidate in api.list_models(search=model_name):
        if candidate.modelId == model_name:
            return True
    return False
def check_if_model_loadable(model_name: str) -> bool:
    """Return whether ``gr.Interface.load`` succeeds for *model_name*.

    Any ``Exception`` raised while loading counts as "not loadable". The
    loaded interface itself is discarded.
    """
    loadable = True
    try:
        gr.Interface.load(model_name, src='models')
    except Exception:
        loadable = False
    return loadable
def get_model_io_types(
        model_name: str) -> tuple[tuple[str, ...], tuple[str, ...]]:
    """Load *model_name* with gradio and describe its I/O components.

    Returns a pair ``(inputs, outputs)``. Each element is a tuple holding the
    ``str()`` of every input (respectively output) component, in order.

    NOTE(review): comparing models by ``str(component)`` assumes that string
    is a meaningful description of the component. Confirm this for the
    gradio version in use.
    """
    interface = gr.Interface.load(model_name, src='models')
    return (
        tuple(str(component) for component in interface.input_components),
        tuple(str(component) for component in interface.output_components),
    )
def check_if_model_io_is_consistent(model_names: list[str]) -> bool:
    """Return True if all models share the first model's I/O components.

    Zero or one model is trivially consistent. In that case nothing is
    loaded, and an empty list no longer fails on ``model_names[0]``.
    Otherwise each remaining model is loaded via ``get_model_io_types`` only
    as needed, and checking stops at the first mismatch.
    """
    if len(model_names) <= 1:
        return True
    reference = get_model_io_types(model_names[0])
    # The comparison is on (inputs, outputs) pairs, so a difference in either
    # side makes the set of models inconsistent.
    return all(get_model_io_types(name) == reference
               for name in model_names[1:])
def save_space_info(dirname: str, filename: str, content: str) -> None:
    """Write *content* to the file *filename* inside directory *dirname*.

    An existing file is overwritten. The text is encoded as UTF-8 explicitly,
    so non-ASCII titles, descriptions and articles are stored the same way
    regardless of the platform's default locale encoding.
    """
    with open(f'{dirname}/{filename}', 'w', encoding='utf-8') as f:
        f.write(content)
def _format_model_list(header: str, model_names: list[str]) -> str:
    """Return *header* followed by each model name on its own line."""
    return '\n'.join([header, *model_names])


def run(space_name: str, model_names_str: str, hf_token: str, title: str,
        description: str, article: str) -> str:
    """Validate the requested models and create a private demo Space for them.

    Args:
        space_name: Repository id for the new Space
            (``<user_name>/<space_name>``).
        model_names_str: Comma-separated Hub model ids.
        hf_token: Token passed to ``create_repo`` and ``upload_folder``.
        title: Text saved to the Space's ``TITLE`` file.
        description: Text saved to the Space's ``DESCRIPTION`` file.
        article: Text saved to the Space's ``ARTICLE`` file.

    Returns:
        A status message for the UI. It is one of: a validation error, the
        text of a Hub error, or a success message containing the Space URL.
    """
    # str.split always returns at least one element, so blank entries (an
    # empty field, ", ," or a trailing comma) must be dropped before the
    # emptiness check. Otherwise '' would reach check_if_model_exists as a
    # search query.
    model_names = [
        name for name in (part.strip() for part in model_names_str.split(','))
        if name
    ]
    if not model_names:
        return 'Model Name cannot be empty.'

    missing_models = [
        name for name in model_names if not check_if_model_exists(name)
    ]
    if missing_models:
        return _format_model_list('The following models were not found: ',
                                  missing_models)

    non_loadable_models = [
        name for name in model_names if not check_if_model_loadable(name)
    ]
    if non_loadable_models:
        return _format_model_list(
            'The following models are not loadable with gradio.Interface.load: ',
            non_loadable_models)

    if not check_if_model_io_is_consistent(model_names):
        return 'The inputs and outputs of each model must be the same.'

    # The Space is always created private. Any exception from the Hub is
    # reported back to the user as text.
    try:
        space_url = api.create_repo(repo_id=space_name,
                                    repo_type='space',
                                    private=True,
                                    token=hf_token,
                                    space_sdk='gradio')
    except Exception as e:
        return str(e)

    with tempfile.TemporaryDirectory() as temp_dir:
        # Copy the app template and write its configuration files next to it.
        # These files are presumably read by assets/template.py at startup,
        # which is not visible here.
        shutil.copy('assets/template.py', f'{temp_dir}/app.py')
        save_space_info(temp_dir, 'TITLE', title)
        save_space_info(temp_dir, 'DESCRIPTION', description)
        save_space_info(temp_dir, 'ARTICLE', article)
        save_space_info(temp_dir, 'MODEL_NAMES', '\n'.join(model_names))
        try:
            api.upload_folder(repo_id=space_name,
                              folder_path=temp_dir,
                              path_in_repo='.',
                              token=hf_token,
                              repo_type='space')
        except Exception as e:
            # The repository already exists at this point. Report the failure
            # as text, consistent with the create_repo failure path above.
            return f'Created {space_url}, but uploading its files failed: {e}'
    return f'Successfully created: {space_url}'
# Placeholder for the token field; implicit concatenation yields the exact
# original single-line text.
_TOKEN_PLACEHOLDER = ('This should be a token with write permission. '
                      'See: https://huggingface.co/settings/tokens')

# Build the creation form. The six inputs line up positionally with the
# parameters of `run` and with the columns of each row in `examples`.
demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Textbox(label='Space Name',
                   placeholder='<user_name>/<space_name>'),
        gr.Textbox(label='Model Names',
                   placeholder='e.g. microsoft/resnet-50'),
        gr.Textbox(label='Hugging Face Token',
                   placeholder=_TOKEN_PLACEHOLDER),
        gr.Textbox(label='Title (Optional)'),
        gr.Textbox(label='Description (Optional)'),
        gr.Textbox(label='Article (Optional)'),
    ],
    outputs=gr.Textbox(label='Output'),
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=False,
)
demo.launch(enable_queue=True, share=False)