katielink commited on
Commit
bed5b68
1 Parent(s): 511bdd4

Fix filtering

Browse files
__pycache__/model_list.cpython-311.pyc CHANGED
Binary files a/__pycache__/model_list.cpython-311.pyc and b/__pycache__/model_list.cpython-311.pyc differ
 
app.py CHANGED
@@ -23,7 +23,9 @@ def main():
23
  placeholder=
24
  'You can search for titles with regular expressions. e.g. (?<!sur)face',
25
  max_lines=1)
 
26
  case_sensitive = gr.Checkbox(label='Case Sensitive')
 
27
  filter_names = gr.CheckboxGroup(choices=[
28
  'Paper',
29
  'Code',
@@ -36,10 +38,10 @@ def main():
36
  'Scientific',
37
  ]
38
 
39
- # prev: paper_sessions
40
  data_types = gr.CheckboxGroup(choices=data_type_names,
41
  value=data_type_names,
42
  label='Training Data Type(s)')
 
43
  search_button = gr.Button('Search')
44
 
45
  number_of_models = gr.Textbox(label='Number of Models Found')
 
23
  placeholder=
24
  'You can search for titles with regular expressions. e.g. (?<!sur)face',
25
  max_lines=1)
26
+
27
  case_sensitive = gr.Checkbox(label='Case Sensitive')
28
+
29
  filter_names = gr.CheckboxGroup(choices=[
30
  'Paper',
31
  'Code',
 
38
  'Scientific',
39
  ]
40
 
 
41
  data_types = gr.CheckboxGroup(choices=data_type_names,
42
  value=data_type_names,
43
  label='Training Data Type(s)')
44
+
45
  search_button = gr.Button('Search')
46
 
47
  number_of_models = gr.Textbox(label='Number of Models Found')
model_list.py CHANGED
@@ -15,7 +15,7 @@ class ModelList:
15
 
16
  self.table_header = '''
17
  <tr>
18
- <td width="40%">Model Name</td>
19
  <td width="10%">Data Type(s)</td>
20
  <td width="10%">Year Published</td>
21
  <td width="10%">Paper</td>
@@ -50,9 +50,10 @@ class ModelList:
50
  rows.append(row)
51
  self.table['html_table_content'] = rows
52
 
53
- def render(self, search_query: str, case_sensitive: bool,
54
- filter_names: list[str],
55
- data_types: list[str]) -> tuple[int, str]:
 
56
  df = self.table
57
  if search_query:
58
  if case_sensitive:
@@ -60,15 +61,14 @@ class ModelList:
60
  else:
61
  df = df[df.name_lowercase.str.contains(search_query.lower())]
62
  has_paper = 'Paper' in filter_names
63
- has_github = 'Github' in filter_names
64
- has_model = 'Hub Model' in filter_names or 'Other Weights' in filter_names
65
  df = self.filter_table(df, has_paper, has_github, has_model, data_types)
66
  return len(df), self.to_html(df, self.table_header)
67
 
68
  @staticmethod
69
  def filter_table(df: pd.DataFrame, has_paper: bool, has_github: bool,
70
- has_model: bool,
71
- data_types: list[str]) -> pd.DataFrame:
72
  if has_paper:
73
  df = df[~df.paper.isna()]
74
  if has_github:
 
15
 
16
  self.table_header = '''
17
  <tr>
18
+ <td width="20%">Model Name</td>
19
  <td width="10%">Data Type(s)</td>
20
  <td width="10%">Year Published</td>
21
  <td width="10%">Paper</td>
 
50
  rows.append(row)
51
  self.table['html_table_content'] = rows
52
 
53
+ def render(self, search_query: str,
54
+ case_sensitive: bool,
55
+ filter_names: list[str],
56
+ data_types: list[str]) -> tuple[int, str]:
57
  df = self.table
58
  if search_query:
59
  if case_sensitive:
 
61
  else:
62
  df = df[df.name_lowercase.str.contains(search_query.lower())]
63
  has_paper = 'Paper' in filter_names
64
+ has_github = 'Code' in filter_names
65
+ has_model = 'Model Weights' in filter_names
66
  df = self.filter_table(df, has_paper, has_github, has_model, data_types)
67
  return len(df), self.to_html(df, self.table_header)
68
 
69
  @staticmethod
70
  def filter_table(df: pd.DataFrame, has_paper: bool, has_github: bool,
71
+ has_model: bool, data_types: list[str]) -> pd.DataFrame:
 
72
  if has_paper:
73
  df = df[~df.paper.isna()]
74
  if has_github: