davanstrien HF staff committed on
Commit
dba982b
1 Parent(s): 3352738

add linked models option

Browse files
Files changed (1) hide show
  1. app.py +42 -11
app.py CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
6
  from dotenv import load_dotenv
7
  from qdrant_client import QdrantClient, models
8
  from sentence_transformers import SentenceTransformer
 
9
 
10
  load_dotenv()
11
 
@@ -22,7 +23,7 @@ client = QdrantClient(
22
  )
23
 
24
 
25
- def format_results(results):
26
  markdown = (
27
  "<h1 style='text-align: center;'> &#x2728; Dataset Search Results &#x2728;"
28
  " </h1> \n\n"
@@ -35,12 +36,31 @@ def format_results(results):
35
  markdown += header + "\n"
36
  markdown += f"**Downloads:** {download_number}\n\n"
37
  markdown += f"{result.payload['section_text']} \n"
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  return markdown
40
 
41
 
42
  @lru_cache(maxsize=100_000)
43
- def search(query: str, limit: Optional[int] = 10):
 
 
 
 
 
 
 
 
44
  query_ = sentence_embedding_model.encode(
45
  f"Represent this sentence for searching relevant passages:{query}"
46
  )
@@ -49,7 +69,7 @@ def search(query: str, limit: Optional[int] = 10):
49
  query_vector=query_,
50
  limit=limit,
51
  )
52
- return format_results(results)
53
 
54
 
55
  @lru_cache(maxsize=100_000)
@@ -69,25 +89,30 @@ def hub_id_qdrant_id(hub_id):
69
  return matches[0][0].id
70
  except IndexError as e:
71
  raise gr.Error(
72
- f"Hub id {hub_id} not in out database. This could be because it is very new"
73
  " or because it doesn't have much documentation."
74
  ) from e
75
 
76
 
77
  @lru_cache()
78
- def recommend(hub_id, limit: Optional[int] = 10):
79
  positive_id = hub_id_qdrant_id(hub_id)
80
  results = client.recommend(
81
  collection_name=collection_name, positive=[positive_id], limit=limit
82
  )
83
- return format_results(results)
84
 
85
 
86
- def query(search_term, search_type, limit: Optional[int] = 10):
 
 
 
 
 
87
  if search_type == "Recommend similar datasets":
88
- return recommend(search_term, limit)
89
  else:
90
- return search(search_term, limit)
91
 
92
 
93
  with gr.Blocks() as demo:
@@ -120,10 +145,16 @@ with gr.Blocks() as demo:
120
  step=1,
121
  value=10,
122
  label="Maximum number of results",
123
- help="This is the maximum number of results that will be returned",
124
  )
 
 
 
 
 
125
  results = gr.Markdown()
126
- find_similar_btn.click(query, [search_term, search_type, max_results], results)
 
 
127
 
128
 
129
  demo.launch()
 
6
  from dotenv import load_dotenv
7
  from qdrant_client import QdrantClient, models
8
  from sentence_transformers import SentenceTransformer
9
+ from huggingface_hub import list_models
10
 
11
  load_dotenv()
12
 
 
23
  )
24
 
25
 
26
+ def format_results(results, show_associated_models=True):
27
  markdown = (
28
  "<h1 style='text-align: center;'> &#x2728; Dataset Search Results &#x2728;"
29
  " </h1> \n\n"
 
36
  markdown += header + "\n"
37
  markdown += f"**Downloads:** {download_number}\n\n"
38
  markdown += f"{result.payload['section_text']} \n"
39
+ if show_associated_models:
40
+ if linked_models := get_models_for_dataset(hub_id):
41
+ linked_models = [
42
+ f"[{model}](https://huggingface.co/{model})"
43
+ for model in linked_models
44
+ ]
45
+ markdown += (
46
+ "<details><summary>Models trained on this dataset</summary>\n\n"
47
+ )
48
+ markdown += "- " + "\n- ".join(linked_models) + "\n\n"
49
+ markdown += "</details>\n\n"
50
 
51
  return markdown
52
 
53
 
54
@lru_cache(maxsize=100_000)
def get_models_for_dataset(id):
    """Return the hub ids of models trained on the dataset ``id``.

    Results are deduplicated and sorted so the returned list is
    deterministic — important both for the lru_cache and for stable
    rendering of the links in the search results markdown.

    Note: ``id`` shadows the builtin of the same name; kept as-is so the
    cached call signature stays backward-compatible.
    """
    # list_models() already returns an iterable; the original
    # list(iter(...)) double-wrap was redundant.
    results = list(list_models(filter=f"dataset:{id}"))
    if results:
        # Deduplicate by model id; sort for a deterministic ordering
        # (a bare set comprehension made the link order arbitrary).
        results = sorted({result.id for result in results})
    return results
60
+
61
+
62
+ @lru_cache(maxsize=200_000)
63
+ def search(query: str, limit: Optional[int] = 10, show_linked_models: bool = False):
64
  query_ = sentence_embedding_model.encode(
65
  f"Represent this sentence for searching relevant passages:{query}"
66
  )
 
69
  query_vector=query_,
70
  limit=limit,
71
  )
72
+ return format_results(results, show_associated_models=show_linked_models)
73
 
74
 
75
  @lru_cache(maxsize=100_000)
 
89
  return matches[0][0].id
90
  except IndexError as e:
91
  raise gr.Error(
92
+ f"Hub id {hub_id} not in the database. This could be because it is very new"
93
  " or because it doesn't have much documentation."
94
  ) from e
95
 
96
 
97
# Bounded cache: the bare @lru_cache() used previously is unbounded
# (maxsize=None) over an unbounded hub_id key space — a slow memory leak.
# 100_000 matches the bound used by the sibling caches in this file.
@lru_cache(maxsize=100_000)
def recommend(hub_id, limit: Optional[int] = 10, show_linked_models: bool = False):
    """Return markdown for datasets similar to ``hub_id``.

    Args:
        hub_id: dataset id on the Hub to find neighbours for.
        limit: maximum number of recommendations to return.
        show_linked_models: when True, each result also lists models
            trained on that dataset.

    Raises:
        gr.Error (via hub_id_qdrant_id) when ``hub_id`` is not in the
        vector database.
    """
    positive_id = hub_id_qdrant_id(hub_id)
    results = client.recommend(
        collection_name=collection_name, positive=[positive_id], limit=limit
    )
    return format_results(results, show_associated_models=show_linked_models)
104
 
105
 
106
def query(
    search_term,
    search_type,
    limit: Optional[int] = 10,
    show_linked_models: bool = False,
):
    """Dispatch a UI request to recommendation or semantic search.

    ``search_type`` selects the backend: the "Recommend similar datasets"
    option routes to recommend(); anything else falls through to search().
    """
    handler = recommend if search_type == "Recommend similar datasets" else search
    return handler(search_term, limit, show_linked_models)
116
 
117
 
118
  with gr.Blocks() as demo:
 
145
  step=1,
146
  value=10,
147
  label="Maximum number of results",
 
148
  )
149
# Checkbox controlling whether results list models trained on each dataset.
show_linked_models = gr.Checkbox(
    label="Show associated models",
    # gr.Checkbox takes `value=` for its initial state; `default=` is not a
    # valid keyword in gradio 3.x and raises a TypeError at launch.
    value=False,
)
153
+
154
  results = gr.Markdown()
155
+ find_similar_btn.click(
156
+ query, [search_term, search_type, max_results, show_linked_models], results
157
+ )
158
 
159
 
160
  demo.launch()