hysts commited on
Commit
8326154
1 Parent(s): d6bd851

Check model inputs/outputs

Browse files
Files changed (1) hide show
  1. app.py +23 -0
app.py CHANGED
@@ -56,6 +56,26 @@ def check_if_model_loadable(model_name: str) -> bool:
56
  return True
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def save_space_info(dirname: str, filename: str, content: str) -> None:
60
  with open(f'{dirname}/{filename}', 'w') as f:
61
  f.write(content)
@@ -87,6 +107,9 @@ def run(space_name: str, model_names_str: str, hf_token: str, title: str,
87
  message += f'\n{model_name}'
88
  return message
89
 
 
 
 
90
  try:
91
  space_url = api.create_repo(repo_id=space_name,
92
  repo_type='space',
 
56
  return True
57
 
58
 
59
+ def get_model_io_types(
60
+ model_name: str) -> tuple[tuple[str, ...], tuple[str, ...]]:
61
+ iface = gr.Interface.load(model_name, src='models')
62
+ inputs = tuple(map(str, iface.input_components))
63
+ outputs = tuple(map(str, iface.output_components))
64
+ return inputs, outputs
65
+
66
+
67
+ def check_if_model_io_is_consistent(model_names: list[str]) -> bool:
68
+ if len(model_names) == 1:
69
+ return True
70
+
71
+ inputs0, outputs0 = get_model_io_types(model_names[0])
72
+ for name in model_names[1:]:
73
+ inputs, outputs = get_model_io_types(name)
74
+ if inputs != inputs0 or outputs != outputs0:
75
+ return False
76
+ return True
77
+
78
+
79
  def save_space_info(dirname: str, filename: str, content: str) -> None:
80
  with open(f'{dirname}/{filename}', 'w') as f:
81
  f.write(content)
 
107
  message += f'\n{model_name}'
108
  return message
109
 
110
+ if not check_if_model_io_is_consistent(model_names):
111
+ return 'The inputs and outputs of each model must be the same.'
112
+
113
  try:
114
  space_url = api.create_repo(repo_id=space_name,
115
  repo_type='space',