hysts
commited on
Commit
•
8326154
1
Parent(s):
d6bd851
Check model inputs/outputs
Browse files
app.py
CHANGED
@@ -56,6 +56,26 @@ def check_if_model_loadable(model_name: str) -> bool:
|
|
56 |
return True
|
57 |
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
def save_space_info(dirname: str, filename: str, content: str) -> None:
|
60 |
with open(f'{dirname}/{filename}', 'w') as f:
|
61 |
f.write(content)
|
@@ -87,6 +107,9 @@ def run(space_name: str, model_names_str: str, hf_token: str, title: str,
|
|
87 |
message += f'\n{model_name}'
|
88 |
return message
|
89 |
|
|
|
|
|
|
|
90 |
try:
|
91 |
space_url = api.create_repo(repo_id=space_name,
|
92 |
repo_type='space',
|
|
|
56 |
return True
|
57 |
|
58 |
|
59 |
+
def get_model_io_types(
|
60 |
+
model_name: str) -> tuple[tuple[str, ...], tuple[str, ...]]:
|
61 |
+
iface = gr.Interface.load(model_name, src='models')
|
62 |
+
inputs = tuple(map(str, iface.input_components))
|
63 |
+
outputs = tuple(map(str, iface.output_components))
|
64 |
+
return inputs, outputs
|
65 |
+
|
66 |
+
|
67 |
+
def check_if_model_io_is_consistent(model_names: list[str]) -> bool:
|
68 |
+
if len(model_names) == 1:
|
69 |
+
return True
|
70 |
+
|
71 |
+
inputs0, outputs0 = get_model_io_types(model_names[0])
|
72 |
+
for name in model_names[1:]:
|
73 |
+
inputs, outputs = get_model_io_types(name)
|
74 |
+
if inputs != inputs0 or outputs != outputs0:
|
75 |
+
return False
|
76 |
+
return True
|
77 |
+
|
78 |
+
|
79 |
def save_space_info(dirname: str, filename: str, content: str) -> None:
|
80 |
with open(f'{dirname}/{filename}', 'w') as f:
|
81 |
f.write(content)
|
|
|
107 |
message += f'\n{model_name}'
|
108 |
return message
|
109 |
|
110 |
+
if not check_if_model_io_is_consistent(model_names):
|
111 |
+
return 'The inputs and outputs of each model must be the same.'
|
112 |
+
|
113 |
try:
|
114 |
space_url = api.create_repo(repo_id=space_name,
|
115 |
repo_type='space',
|