nouamanetazi HF staff commited on
Commit
cf61e60
1 Parent(s): da75a62

update routes

Browse files
Files changed (3) hide show
  1. pages/search_engine.py +23 -32
  2. server/api.py +37 -2
  3. streamlit_app.py +11 -0
pages/search_engine.py CHANGED
@@ -50,25 +50,12 @@ def paginator(label, articles, articles_per_page=10, on_sidebar=True):
50
 
51
  return itertools.islice(enumerate(articles), min_index, max_index)
52
 
53
-
54
  def page():
55
- st.set_page_config(
56
- page_title="HF Search Engine",
57
- page_icon="🔎",
58
- layout="wide",
59
- initial_sidebar_state="auto",
60
- # menu_items={
61
- # "Get Help": "https://www.extremelycoolapp.com/help",
62
- # "Report a bug": "https://www.extremelycoolapp.com/bug",
63
- # "About": "# This is a header. This is an *extremely* cool app!",
64
- # },
65
- )
66
-
67
  ### SIDEBAR
68
  search_backend = st.sidebar.selectbox(
69
- "Search Engine",
70
- ["hfapi", "custom"],
71
- format_func=lambda x: {"hfapi": "Huggingface API", "custom": "Sentence Bert"}[x],
72
  )
73
  limit_results = st.sidebar.number_input("Limit results", min_value=0, value=10)
74
 
@@ -112,22 +99,22 @@ def page():
112
  if search_query != "":
113
  response = requests.post(search_url, headers=headers, json=search_body).json()
114
 
115
- record_list = []
116
  _ = [
117
- record_list.append(
118
  {
119
- "modelId": record["modelId"],
120
- "tags": record["tags"],
121
- "downloads": record["downloads"],
122
- "likes": record["likes"],
 
123
  }
124
  )
125
- for record in response.get("value")
126
  ]
127
 
128
- # filter results
129
 
130
- if record_list:
131
  st.write(f'Search results ({response.get("count")}):')
132
 
133
  if response.get("count") > 100:
@@ -135,16 +122,20 @@ def page():
135
  else:
136
  shown_results = response.get("count")
137
 
138
- for i, record in paginator(
139
  f"Select results (showing {shown_results} of {response.get('count')} results)",
140
- record_list,
141
  ):
142
  col1, col2, col3 = st.columns([5,1,1])
143
- col1.metric("Model", record["modelId"])
144
- col2.metric("N° downloads", numerize(record["downloads"]))
145
- col3.metric("N° likes", numerize(record["likes"]))
146
- st.button(f"View model", on_click=lambda record=record: webbrowser.open(f"https://huggingface.co/{record['modelId']})"), key=record["modelId"])
147
- st.markdown(f"**Tags:** {' • '.join(record['tags'])}")
 
 
 
 
148
 
149
  # TODO: embed huggingface spaces
150
  # import streamlit.components.v1 as components
 
50
 
51
  return itertools.islice(enumerate(articles), min_index, max_index)
52
 
 
53
  def page():
 
 
 
 
 
 
 
 
 
 
 
 
54
  ### SIDEBAR
55
  search_backend = st.sidebar.selectbox(
56
+ "Search method",
57
+ ["semantic", "bm25", "hfapi"],
58
+ format_func=lambda x: {"hfapi": "Keyword search", "bm25": "BM25 search", "semantic": "Semantic Search"}[x],
59
  )
60
  limit_results = st.sidebar.number_input("Limit results", min_value=0, value=10)
61
 
 
99
  if search_query != "":
100
  response = requests.post(search_url, headers=headers, json=search_body).json()
101
 
102
+ hit_list = []
103
  _ = [
104
+ hit_list.append(
105
  {
106
+ "modelId": hit["modelId"],
107
+ "tags": hit["tags"],
108
+ "downloads": hit["downloads"],
109
+ "likes": hit["likes"],
110
+ "readme": hit.get("readme", None),
111
  }
112
  )
113
+ for hit in response.get("value")
114
  ]
115
 
 
116
 
117
+ if hit_list:
118
  st.write(f'Search results ({response.get("count")}):')
119
 
120
  if response.get("count") > 100:
 
122
  else:
123
  shown_results = response.get("count")
124
 
125
+ for i, hit in paginator(
126
  f"Select results (showing {shown_results} of {response.get('count')} results)",
127
+ hit_list,
128
  ):
129
  col1, col2, col3 = st.columns([5,1,1])
130
+ col1.metric("Model", hit["modelId"])
131
+ col2.metric("N° downloads", numerize(hit["downloads"]))
132
+ col3.metric("N° likes", numerize(hit["likes"]))
133
+ st.button(f"View model on 🤗", on_click=lambda hit=hit: webbrowser.open(f"https://huggingface.co/{hit['modelId']}"), key=hit["modelId"])
134
+ st.markdown(f"**Tags:** {' • '.join(hit['tags'])}")
135
+
136
+ if hit["readme"]:
137
+ with st.expander("See README"):
138
+ st.write(hit["readme"])
139
 
140
  # TODO: embed huggingface spaces
141
  # import streamlit.components.v1 as components
server/api.py CHANGED
@@ -46,8 +46,8 @@ def hf_api():
46
  return json.dumps({"value": hits, "count": count})
47
 
48
 
49
- @app.route("/custom/search", methods=["POST"])
50
- def main():
51
  request_data = request.get_json()
52
  query = request_data.get("query")
53
  filters = json.loads(request_data.get("filters"))
@@ -58,6 +58,41 @@ def main():
58
 
59
  # TODO: filters
60
  hits = hf_search(query=query, method="retrieve & rerank", limit=limit)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  return json.dumps({"value": hits, "count": len(hits)})
63
 
 
46
  return json.dumps({"value": hits, "count": count})
47
 
48
 
49
+ @app.route("/semantic/search", methods=["POST"])
50
+ def semantic_search():
51
  request_data = request.get_json()
52
  query = request_data.get("query")
53
  filters = json.loads(request_data.get("filters"))
 
58
 
59
  # TODO: filters
60
  hits = hf_search(query=query, method="retrieve & rerank", limit=limit)
61
+ hits = [
62
+ {
63
+ "modelId": hit["modelId"],
64
+ "tags": hit["tags"],
65
+ "downloads": hit["downloads"],
66
+ "likes": hit["likes"],
67
+ "readme": hit.get("readme", None),
68
+ }
69
+ for hit in hits
70
+ ]
71
+ return json.dumps({"value": hits, "count": len(hits)})
72
+
73
+ @app.route("/bm25/search", methods=["POST"])
74
+ def bm25_search():
75
+ request_data = request.get_json()
76
+ query = request_data.get("query")
77
+ filters = json.loads(request_data.get("filters"))
78
+ limit = request_data.get("limit", 5)
79
+ print("query", query)
80
+ print("filters", filters)
81
+ print("limit", limit)
82
+
83
+ # TODO: filters
84
+ hits = hf_search(query=query, method="bm25", limit=limit)
85
+ hits = [
86
+ {
87
+ "modelId": hit["modelId"],
88
+ "tags": hit["tags"],
89
+ "downloads": hit["downloads"],
90
+ "likes": hit["likes"],
91
+ "readme": hit.get("readme", None),
92
+ }
93
+ for hit in hits
94
+ ]
95
+ pprint(hits)
96
 
97
  return json.dumps({"value": hits, "count": len(hits)})
98
 
streamlit_app.py CHANGED
@@ -11,6 +11,17 @@ def set_record(record):
11
 
12
 
13
  if not st.session_state["selected_record"]: # search engine page
 
 
 
 
 
 
 
 
 
 
 
14
  search_engine_page()
15
 
16
  else: # a record has been selected
 
11
 
12
 
13
  if not st.session_state["selected_record"]: # search engine page
14
+ st.set_page_config(
15
+ page_title="HuggingFace Search Engine",
16
+ page_icon="🔎",
17
+ layout="wide",
18
+ initial_sidebar_state="auto",
19
+ # menu_items={
20
+ # "Get Help": "https://www.extremelycoolapp.com/help",
21
+ # "Report a bug": "https://www.extremelycoolapp.com/bug",
22
+ # "About": "# This is a header. This is an *extremely* cool app!",
23
+ # },
24
+ )
25
  search_engine_page()
26
 
27
  else: # a record has been selected