Spaces:
Running
Running
thomasht86
committed on
Commit
•
8996eb9
1
Parent(s):
580ca24
deploy at 2024-08-24 17:35:22.783475
Browse files- main copy.py +861 -0
- main.py +11 -43
main copy.py
ADDED
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fasthtml_hf import setup_hf_backup
|
2 |
+
from fasthtml.common import (
|
3 |
+
picolink,
|
4 |
+
serve,
|
5 |
+
Div,
|
6 |
+
Title,
|
7 |
+
Main,
|
8 |
+
Input,
|
9 |
+
Button,
|
10 |
+
A,
|
11 |
+
Section,
|
12 |
+
H2,
|
13 |
+
Ul,
|
14 |
+
Li,
|
15 |
+
P,
|
16 |
+
Img,
|
17 |
+
Details,
|
18 |
+
MarkdownJS,
|
19 |
+
HighlightJS,
|
20 |
+
Summary,
|
21 |
+
Script,
|
22 |
+
I,
|
23 |
+
Form,
|
24 |
+
RedirectResponse,
|
25 |
+
dataclass,
|
26 |
+
Favicon,
|
27 |
+
database,
|
28 |
+
get_key,
|
29 |
+
Table,
|
30 |
+
Thead,
|
31 |
+
Tr,
|
32 |
+
Th,
|
33 |
+
Tbody,
|
34 |
+
Td,
|
35 |
+
FileResponse,
|
36 |
+
fast_app,
|
37 |
+
Beforeware,
|
38 |
+
Hidden,
|
39 |
+
Request,
|
40 |
+
H3,
|
41 |
+
Style,
|
42 |
+
)
|
43 |
+
from fasthtml.components import Nav, Article, Header, Mark
|
44 |
+
from fasthtml.pico import Search, Grid, Fieldset, Label
|
45 |
+
from starlette.middleware import Middleware
|
46 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
47 |
+
from starlette.middleware.sessions import SessionMiddleware
|
48 |
+
from vespa.application import Vespa
|
49 |
+
import json
|
50 |
+
import os
|
51 |
+
import re
|
52 |
+
import time
|
53 |
+
from hmac import compare_digest
|
54 |
+
from io import StringIO
|
55 |
+
import csv
|
56 |
+
import tempfile
|
57 |
+
from enum import Enum
|
58 |
+
from typing import Tuple as T
|
59 |
+
from urllib.parse import quote
|
60 |
+
import uuid
|
61 |
+
|
62 |
+
# Toggle for local development: enables hot reload (passed to `fast_app` as
# `live=`) and loads settings from a local .env file.
DEV_MODE = False

if DEV_MODE:
    print("Running in DEV_MODE - Hot reload enabled")
    print("Loading environment variables from .env")
    # Imported lazily so python-dotenv is only required in development.
    from dotenv import load_dotenv

    load_dotenv()
else:
    print("DEV_MODE disabled - environment variables loaded from system")

# The Vespa endpoint is mandatory; abort start-up when it is missing.
vespa_app_url = os.getenv("VESPA_APP_URL", None)
if vespa_app_url is None:
    print("Please set the VESPA_APP_URL environment variable")
    # `exit` is injected by the `site` module and is not guaranteed to exist
    # (e.g. `python -S`); SystemExit is always available.
    raise SystemExit(1)

# Admin credentials.
# NOTE(review): both fall back to "admin" when unset - confirm that a
# deployment always overrides ADMIN_PWD.
ADMIN_NAME = os.getenv("ADMIN_NAME", "admin")
ADMIN_PWD = os.getenv("ADMIN_PWD", "admin")

vespa_app: Vespa = Vespa(
    url=vespa_app_url,
    vespa_cloud_secret_token=os.getenv("VESPA_CLOUD_SECRET_TOKEN"),
)
# Connectivity probe; a failure is only reported, start-up continues.
status = vespa_app.get_application_status()
if status is None:
    print("Could not connect to Vespa application")
else:
    print("Connected to Vespa application!")
|
90 |
+
|
91 |
+
# Font Awesome kit and favicon, added to every page through `headers`.
fa = Script(src="https://kit.fontawesome.com/664eb1a115.js", crossorigin="anonymous")
favicon = Favicon(
    "https://search.vespa.ai/favicon.ico",
    "https://search.vespa.ai/favicon.ico",
)

# Database holding the log of user queries.
DB_FILE = "db/vespa.db"
db = database(DB_FILE)
queries = db.t.queries
if queries not in db.t:
    # You can pass a dict, or kwargs, to most MiniDataAPI methods.
    queries.create(
        dict(qid=int, query=str, ranking=str, sess_id=str, timestamp=int), pk="qid"
    )
    # Add autoincrement to the qid column
    # NOTE(review): `qid` already exists after `create()` above, and SQLite
    # does not allow adding a PRIMARY KEY column through ALTER TABLE; the
    # result of this call is also never consumed. Confirm whether this
    # statement has any effect and drop it if not.
    db.query("ALTER TABLE queries ADD COLUMN qid INTEGER PRIMARY KEY AUTOINCREMENT")
Query = queries.dataclass()

# Attach an instance method to the Query dataclass that renders the integer
# `timestamp` field as a human readable local-time string.
Query.get_datetime = lambda self: time.strftime(
    "%Y-%m-%d %H:%M:%S", time.localtime(self.timestamp)
)

# Status code 303 is a redirect that can change POST to GET,
# so it's appropriate for a login page.
login_redir = RedirectResponse("/login", status_code=303)
|
116 |
+
|
117 |
+
|
118 |
+
def user_auth_before(req, sess):
    """Beforeware hook: copy the session's auth marker into the request scope.

    Args:
        req: The incoming request; its ``scope`` mapping receives ``"auth"``.
        sess: The session mapping for this request.

    Returns:
        ``login_redir`` when the session carries no truthy ``"auth"`` value,
        otherwise ``None`` so the request continues to its handler.
    """
    # The `auth` key in the request scope is automatically provided
    # to any handler which requests it, and can not be injected
    # by the user using query params, cookies, etc, so it should
    # be secure to use.
    print(f"Session Data before route: {sess}")
    auth = req.scope["auth"] = sess.get("auth", None)
    print(f"Auth: {auth}")
    if not auth:
        return login_redir
|
128 |
+
|
129 |
+
|
130 |
+
# The spinner is hidden unless htmx marks it with `htmx-request`.
spinner_css = Style("""
    .htmx-indicator {
        display: none; /* Hide spinner by default */
    }

    .htmx-indicator.htmx-request {
        display: block;
    }
""")

# Elements injected into the <head> of every page (passed to `fast_app`).
headers = (
    picolink,
    MarkdownJS(),
    HighlightJS(langs=["json", "python"]),
    favicon,
    fa,
    spinner_css,
)

# Read file contents once before starting the server.
# An explicit encoding keeps this independent of the platform default.
with open("README.md", encoding="utf-8") as f:
    README = f.read()
with open("main.py", encoding="utf-8") as f:
    SOURCE = f.read()

# Sesskey
sess_key_path = "session/.sesskey"
# Make sure session directory exists
os.makedirs("session", exist_ok=True)
|
159 |
+
|
160 |
+
|
161 |
+
# Middleware
|
162 |
+
class XFrameOptionsMiddleware(BaseHTTPMiddleware):
    """Add framing headers so the app can be embedded by huggingface.co."""

    async def dispatch(self, request, call_next):
        """Run the downstream handler, then attach the framing headers."""
        response = await call_next(request)
        response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
        # `ALLOW-FROM` is obsolete and ignored by current browsers; the CSP
        # `frame-ancestors` directive is its replacement. `setdefault` keeps
        # any policy a handler has already set.
        response.headers.setdefault(
            "Content-Security-Policy",
            "frame-ancestors 'self' https://huggingface.co",
        )
        return response
|
167 |
+
|
168 |
+
class SessionLoggingMiddleware(BaseHTTPMiddleware):
    """Debug middleware that prints the session before and after each request."""

    async def dispatch(self, request, call_next):
        """Print the session, run the downstream handler, print it again."""
        # NOTE(review): this writes full session contents to stdout on every
        # request - confirm it is meant to stay enabled outside debugging.
        print(f"Before request: Session data: {request.session}")
        response = await call_next(request)
        print(f"After request: Session data: {request.session}")
        return response
|
174 |
+
|
175 |
+
class DebugSessionMiddleware(SessionMiddleware):
    """SessionMiddleware variant that prints the ASGI scope around each call."""

    async def __call__(self, scope, receive, send):
        """Print the scope, delegate to SessionMiddleware, print it again."""
        print(f"DebugSessionMiddleware: Before processing - Scope: {scope}")
        await super().__call__(scope, receive, send)
        print(f"DebugSessionMiddleware: After processing - Scope: {scope}")
|
180 |
+
|
181 |
+
from starlette.middleware.cors import CORSMiddleware
|
182 |
+
|
183 |
+
# ASGI middleware stack handed to `fast_app`.
middlewares = [
    Middleware(
        SessionMiddleware,
        secret_key=get_key(fname=sess_key_path),
        max_age=3600,
        # same_site='lax',
    ),
    # NOTE(review): every origin is allowed - confirm this is intended.
    Middleware(CORSMiddleware, allow_origins=['*']),
    Middleware(XFrameOptionsMiddleware),
    Middleware(SessionLoggingMiddleware),
    # Middleware(DebugSessionMiddleware, secret_key=get_key(fname=sess_key_path)),
]

# Run `user_auth_before` ahead of every route except the skipped paths below.
# Handlers behind a skipped path do not get `auth` from the request scope and
# have to read it from the session themselves.
bware = Beforeware(
    user_auth_before,
    skip=[
        r"/favicon\.ico",
        r"/static/.*",
        r".*\.css",
        r".*\.js",
        "/",
        "/login",
        "/search",
        "/document/.*",
        "/expand/.*",
        "/source",
        "/about",
    ],
)

app, rt = fast_app(
    before=bware,
    live=DEV_MODE,
    hdrs=headers,
    middleware=middlewares,
    key_fname=sess_key_path,
    same_site="None",
)

# The key stays available to the rest of the module, but it is the secret used
# to sign session cookies and is therefore deliberately not printed.
sesskey = get_key(fname=sess_key_path)
|
224 |
+
|
225 |
+
|
226 |
+
# enum class for rank profiles
|
227 |
+
class RankProfile(str, Enum):
    """Ranking profiles selectable in the UI.

    Members are ``str`` subclasses, so they compare equal to the plain
    strings submitted by the ranking radio buttons.
    """

    bm25 = "bm25"
    semantic = "semantic"
    fusion = "fusion"
|
231 |
+
|
232 |
+
|
233 |
+
def get_navbar(admin: bool):
    """Build the navigation bar shown at the top of every page.

    Args:
        admin: Truthy when an admin is logged in. The shield icon then links
            to ``/admin`` instead of ``/login`` and the logout icon is shown.

    Returns:
        The ``Nav`` element.
    """
    print(f"In get_navbar: {admin}")
    bar = Nav(
        Ul(
            Li(
                A(
                    Img(src="https://vespa.ai/assets/vespa-ai-logo-heather.svg"),
                    href="https://cloud.vespa.ai",
                    target="_blank",
                    style="margin: 10px;",
                ),
            )
        ),
        Ul(H2("Vespa-fastHTML demo")),
        Ul(
            # A question mark icon with link to an about page
            A(
                I(cls="fa fa-question-circle fa-2x"),
                href="/about",
                style="margin: 10px;",
                title="About this app",
            ),
            A(
                I(cls="fab fa-slack fa-2x"),
                href="https://slack.vespa.ai/",
                style="margin: 10px;",
                target="_blank",
                title="Join Vespa Slack channel",
            ),
            A(
                I(cls="fab fa-github fa-2x"),
                href="https://github.com/vespa-engine/sample-apps/tree/master/examples/fasthtml-demo",
                style="margin: 10px;",
                target="_blank",
                title="View source code on GitHub",
            ),
            A(
                I(cls="fa fa-code fa-2x"),
                href="/source",
                style="margin: 10px;",
                title="View source code",
            ),
            # Shield icon: links to /login, or to /admin once an admin is
            # logged in. The tooltip is shown on hover.
            A(
                I(cls="fa fa-shield fa-2x"),
                href="/login" if not admin else "/admin",
                style="margin: 10px;",
                title="Admin login",
            ),
            # Logout icon, hidden unless an admin is logged in
            A(
                I(cls="fa fa-sign-out fa-2x"),
                href="/logout",
                style="margin: 10px;" if admin else "display: none;",
                title="Logout",
            ),
        ),
        # 10px margin to right of navbar
        style="margin-right: 10px;",
    )
    return bar
|
294 |
+
|
295 |
+
|
296 |
+
def spinner_div(hidden: bool = False):
    """Return the wrapper ``Div`` holding the ``#spinner`` htmx indicator.

    Args:
        hidden: When true the wrapper is rendered with ``display: none``.
    """
    return Div(
        A(
            id="spinner",
            aria_busy="true",
            cls="htmx-indicator",
            style="font-size: 2em;",
        ),
        style="text-align: center; margin-top: 40px;"
        if not hidden
        else "display: none;",
    )
|
308 |
+
|
309 |
+
|
310 |
+
@app.route("/")
def get(sess):
    """Render the landing page: search bar, ranking selector, example queries.

    Args:
        sess: The session mapping; only its ``"auth"`` entry is read.

    Returns:
        A ``(Title, Nav, Main)`` tuple for FastHTML to render.
    """
    # Can not get auth directly, as it is skipped in beforeware
    auth = sess.get("auth", False)
    # Named so that it does not shadow the module-level `queries` table.
    example_queries = [
        "Breast Cancer Cells Feed on Cholesterol",
        "Treating Asthma With Plants vs. Pills",
        "Testing Turmeric on Smokers",
        "The Role of Pesticides in Parkinson's Disease",
    ]
    example_buttons = [
        Button(
            query,
            # Percent-encode the text: it contains spaces and punctuation.
            hx_get="/search?userquery=" + quote(query),
            hx_include="input[name=ranking]:checked",
            hx_target="#results",
            hx_indicator="#spinner",
            # json.dumps yields a correctly escaped JavaScript string literal;
            # wrapping the raw text in single quotes breaks on an apostrophe
            # (as in "Parkinson's").
            hx_on_click=f"document.getElementById('userquery').value={json.dumps(query)}",
            style="margin: 10px; padding: 5px;",
            cls="secondary outline",
            id=f"example-{qid}",
        )
        for qid, query in enumerate(example_queries)
    ]
    return (
        Title("Vespa demo"),
        get_navbar(auth),
        Main(
            # Search bar
            Search(
                Input(
                    type="search",
                    placeholder="Ask/search for medical information?",
                    id="userquery",
                ),
                # Get search results on button click with search-input as query parameter
                Button(
                    "Search",
                    hx_get="/search",
                    # include userquery and id of selected ranking radio button
                    hx_include="#userquery, input[name=ranking]:checked",
                    hx_target="#results",
                    hx_indicator="#spinner",
                ),
                style="margin: 10% 10px 0 0;",
            ),
            Fieldset(
                Input(type="radio", id="bm25", name="ranking", value="bm25"),
                Label("BM25", htmlfor="bm25"),
                Input(type="radio", id="semantic", name="ranking", value="semantic"),
                Label("Semantic", htmlfor="semantic"),
                Input(
                    type="radio",
                    id="fusion",
                    name="ranking",
                    value="fusion",
                    checked="",
                ),
                Label("Reciprocal Rank fusion", htmlfor="fusion"),
                style="margin: 10px; text-align: center;",
                id="ranking",
            ),
            H3("Example queries"),
            # Buttons with predefined search queries
            Grid(
                *example_buttons,
                # Make the grid buttons have same height and distribute evenly and center align
                style="grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));",
            ),
            # Target section for search results; initially holds only the spinner
            Section(
                spinner_div(),
                id="results",
                hx_swap="innerHTML",
                style="margin: 20px;",
            ),
            style="margin: 0 auto; width: 70%;",
            id="main",
        ),
    )
|
410 |
+
|
411 |
+
|
412 |
+
@dataclass
class Login:
    """Credentials submitted by the admin login form."""

    name: str
    pwd: str
|
416 |
+
|
417 |
+
|
418 |
+
@app.get("/login")
def get_login_form(sess, error: bool = False):
    """Render the admin login page.

    Args:
        sess: The session mapping; only its ``"auth"`` entry is read.
        error: When true, an "Incorrect password" message is shown above
            the form.

    Returns:
        A ``(Title, Nav, Main)`` tuple for FastHTML to render.
    """
    auth = sess.get("auth", False)
    frm = Form(
        Input(id="name", placeholder="Name"),
        Input(id="pwd", type="password", placeholder="Password"),
        Button("login"),
        action="/login",
        method="post",
    )
    err_msg = P("Incorrect password", style="color: red;") if error else ""
    return (
        Title("Admin login"),
        get_navbar(auth),
        Main(
            err_msg,
            frm,
            style="width: 50%; margin: 10% auto;",
        ),
    )
|
438 |
+
|
439 |
+
|
440 |
+
@app.post("/login")
def post(login: Login, sess):
    """Check the submitted password and mark the session as authenticated.

    Args:
        login: The submitted form fields.
        sess: The session mapping; ``"auth"`` is set on success.

    Returns:
        A 303 redirect to ``/admin`` on success, or to
        ``/login?error=True`` when the password does not match.
    """
    # Constant-time comparison, so timing does not reveal a partial match.
    # NOTE(review): `login.name` is not compared against ADMIN_NAME - confirm
    # whether the user name is meant to be checked as well.
    if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
        # Incorrect password - add error message
        return RedirectResponse("/login?error=True", status_code=303)
    # Without this marker `user_auth_before` redirects /admin straight back
    # to /login, so a successful login could never take effect.
    sess["auth"] = True
    print(f"Session after setting auth: {sess}")
    response = RedirectResponse("/admin", status_code=303)
    print(f"Cookies being set: {response.headers.get('Set-Cookie')}")
    return response
|
449 |
+
|
450 |
+
|
451 |
+
@app.get("/logout")
def logout(sess):
    """Clear the session's auth marker and redirect to the landing page."""
    sess["auth"] = False
    return RedirectResponse("/")
|
455 |
+
|
456 |
+
|
457 |
+
def replace_hi_with_strong(text):
    """Split *text* on ``<hi>``/``</hi>`` markers into renderable pieces.

    Text found between an opening and a closing marker is wrapped in a
    ``Mark`` element; all other text is kept as a plain string. The markers
    themselves are dropped.

    Args:
        text: The marked-up text. ``None`` or ``""`` yields an empty list.

    Returns:
        A list of strings and ``Mark`` elements, in document order, without
        empty fragments.
    """
    if not text:
        return []
    elements = []
    highlighted = False
    # The capturing group makes re.split keep the markers as separate parts.
    for part in re.split(r"(<hi>|</hi>)", text):
        if part == "<hi>":
            highlighted = True
        elif part == "</hi>":
            highlighted = False
        elif part:
            # Empty fragments (e.g. text starting with a marker) are skipped.
            elements.append(Mark(part) if highlighted else part)
    return elements
|
471 |
+
|
472 |
+
|
473 |
+
def log_query_to_db(query, ranking, sess):
    """Record a search in the ``queries`` table and in the user's session.

    Args:
        query: The search text.
        ranking: Name of the ranking profile that was used.
        sess: The session mapping; ``"user_id"`` and ``"queries"`` are
            created or updated.

    Returns:
        The dict (``query``, ``ranking``, ``timestamp``) appended to the
        session history.
    """
    # Sessions are stored in the cookie, so the history has to stay well below
    # the ~4 KB browser cookie limit. Only the most recent entry is read back
    # (by the document view), so a short history is sufficient.
    max_session_queries = 10

    if 'user_id' not in sess:
        sess['user_id'] = str(uuid.uuid4())

    timestamp = int(time.time())
    # Log the per-visitor id. The module-level `sesskey` is the secret used to
    # sign session cookies and must not be written to the table, where it would
    # surface in the admin view and the CSV export.
    queries.insert(
        Query(query=query, ranking=ranking, sess_id=sess['user_id'], timestamp=timestamp)
    )

    query_data = {
        'query': query,
        'ranking': ranking,
        'timestamp': timestamp,
    }
    history = list(sess.get('queries', []))
    history.append(query_data)
    # Re-assign rather than mutate in place so the change is always stored.
    sess['queries'] = history[-max_session_queries:]

    return query_data
|
494 |
+
|
495 |
+
|
496 |
+
def parse_results(records):
    """Turn search hits into a list of ``Article`` elements.

    Args:
        records: Iterable of dicts with the keys ``"id"``, ``"title"`` and
            ``"body"``.

    Returns:
        One ``Article`` per record: a title linking to the document view, a
        300-character preview with a "Show more" button, and a hidden input
        carrying the full body for the ``/expand`` endpoint.
    """
    return [
        Article(
            Header(
                H2(
                    A(
                        result["title"],
                        hx_get=f"/document/{result['id']}",
                        hx_target="#results",
                    )
                )
            ),
            Div(
                P(
                    *replace_hi_with_strong(
                        result["body"][:300] + "..."
                    ),  # Display first 300 characters of body
                ),
                Div(
                    # Button with "Show more" - center align
                    Button(
                        "Show more",
                        hx_post=f"/expand/{result['id']}?expand=true",
                        hx_target=f"#{result['id']}",
                        hx_include=f"#{result['id']}-full",
                        cls="outline secondary",
                        # Style to fill whole width of parent div
                        style="width: 100%;",
                    ),
                    style="text-align: center;",
                ),
                id=result["id"],
            ),
            Hidden(result["body"], id=f"{result['id']}-full"),
        )
        for result in records
    ]
|
533 |
+
|
534 |
+
|
535 |
+
@app.post("/expand/{docid}")
async def expand(request: Request, docid: str, expand: bool):
    """Re-render one search result in its expanded or collapsed form.

    The full body text is posted back from the hidden ``{docid}-full`` input
    that accompanies each result.

    Args:
        request: The incoming request; its form data carries the body text.
        docid: Id of the result element to re-render.
        expand: True to show the full body, False to show the preview.

    Returns:
        A one-element tuple holding the replacement ``Div`` for ``#{docid}``.
    """
    print(f"Expanding {docid}")
    form_data = await request.form()
    # Fall back to an empty body when the field is missing, instead of
    # failing on `None[:300]`.
    result = form_data.get(f"{docid}-full") or ""
    if not expand:
        result = result[:300] + "..."
    return (
        Div(
            P(
                *replace_hi_with_strong(result),  # Full body or preview
            ),
            Div(
                # Toggle button - center align
                Button(
                    "Show less" if expand else "Show more",
                    hx_post=f"/expand/{docid}?expand="
                    + ("false" if expand else "true"),
                    hx_target=f"#{docid}",
                    hx_include=f"#{docid}-full",
                    cls="outline secondary",
                    # Style to fill whole width of parent div
                    style="width: 100%;",
                ),
                style="text-align: center;",
            ),
            id=docid,
        ),
    )
|
564 |
+
|
565 |
+
|
566 |
+
# Returns tuple of (yql, body(dict)) based on the ranking profile
|
567 |
+
def get_yql(ranking: RankProfile, userquery: str) -> T[str, dict]:
    """Return the YQL statement and extra request body for a rank profile.

    Args:
        ranking: A ``RankProfile`` member or its string value.
        userquery: The user's search text, embedded for the vector profiles.

    Returns:
        A ``(yql, body)`` tuple; ``body`` is empty for BM25.

    Raises:
        ValueError: If *ranking* is not a known rank profile.
    """
    # Coercing up front rejects an unknown value with a clear ValueError
    # instead of an UnboundLocalError at the return statement.
    profile = RankProfile(ranking)
    if profile is RankProfile.bm25:
        yql = "select * from sources * where userQuery() limit 10"
        body = {}
    elif profile is RankProfile.semantic:
        yql = "select * from sources * where ({targetHits:10}nearestNeighbor(embedding,q)) limit 10"
        body = {"input.query(q)": f"embed({userquery})"}
    else:  # RankProfile.fusion
        yql = "select * from sources * where rank({targetHits:1000}nearestNeighbor(embedding,q), userQuery()) limit 10"
        body = {"input.query(q)": f"embed({userquery})"}
    return yql, body
|
578 |
+
|
579 |
+
|
580 |
+
@app.get("/search")
async def search(userquery: str, ranking: str, sess):
    """Run a search against Vespa and render the results.

    Args:
        userquery: The search text.
        ranking: Name of the ranking profile to use.
        sess: The session mapping, passed on to ``log_query_to_db``.

    Returns:
        A ``Div`` with the spinner, a collapsible dump of the raw JSON
        response, and the rendered hits.
    """
    print(sess)
    # Build the query first, so a request that cannot be served is not logged.
    yql, body = get_yql(ranking, userquery)
    log_query_to_db(userquery, ranking, sess)
    async with vespa_app.asyncio() as session:
        resp = await session.query(
            yql=yql,
            query=userquery,
            hits=10,
            ranking=str(ranking),
            body=body,
        )
    # Keep only the fields the result cards need.
    fields = ["id", "title", "body"]
    records = [{field: hit["fields"][field] for field in fields} for hit in resp.hits]
    results = parse_results(records)
    json_dump = json.dumps(resp.get_json(), indent=4)
    return Div(
        spinner_div(),
        # Accordion (with Details)
        Details(
            Summary("Full JSON response"),
            Div(
                f"""```json\n{json_dump}\n```""",
                cls="marked",
            ),
        ),
        H2(
            "Search Results",
        ),
        Div(
            *results,
            id="all-searchresults",
        ),
    )
|
621 |
+
|
622 |
+
|
623 |
+
@app.get("/download_csv")
def download_csv(auth):
    """Return every logged query as a CSV attachment named ``queries.csv``.

    Args:
        auth: Requested from the framework; not used in the body.

    Returns:
        A ``text/csv`` response with the columns Query, Session ID and
        Timestamp.
    """
    # Local import: `Response` is not among the names imported at file level.
    from starlette.responses import Response

    # Build the CSV in memory and send it directly. Writing it to a
    # `NamedTemporaryFile(delete=False)` first left one orphaned file on disk
    # per download.
    buffer = StringIO()
    writer = csv.writer(buffer)
    writer.writerow(["Query", "Session ID", "Timestamp"])
    for row in db.query("SELECT * FROM queries"):
        writer.writerow([row["query"], row["sess_id"], row["timestamp"]])

    return Response(
        buffer.getvalue(),
        media_type="text/csv",
        headers={"Content-Disposition": 'attachment; filename="queries.csv"'},
    )
|
649 |
+
|
650 |
+
|
651 |
+
@app.get("/admin")
def get_admin(auth, page: int = 1):
    """Render the admin page: a paginated table of logged queries.

    Args:
        auth: The auth marker from the request scope, used for the navbar.
        page: 1-based page number; values outside the valid range are
            clamped to the first or last page.

    Returns:
        A ``(Title, Nav, Main)`` tuple for FastHTML to render.
    """
    limit = 15
    total_queries = list(db.query("SELECT COUNT(*) AS count FROM queries"))[0]["count"]
    # Ceiling division; always at least one page so that an empty table
    # renders "Page 1 of 1" rather than "Page 1 of 0".
    total_pages = max(1, (total_queries + limit - 1) // limit)
    page = min(max(1, page), total_pages)
    offset = (page - 1) * limit
    # Newest first. The ORDER BY belongs on the page query; on the COUNT
    # query it had no effect. `limit` and `offset` are ints computed above.
    rows = list(
        db.query(
            f"SELECT * FROM queries ORDER BY timestamp DESC LIMIT {limit} OFFSET {offset}"
        )
    )
    page_queries = [Query(**row) for row in rows]

    # Range of page links to display: a window of `page_window` pages around
    # the current one, shifted back when it would run past the last page.
    page_window = 5
    start_page = max(1, page - page_window // 2)
    end_page = min(total_pages, start_page + page_window - 1)
    start_page = max(1, end_page - page_window + 1)

    link_style = "margin: 5px;"
    current_style = "margin: 5px; font-weight: bold;"
    disabled_style = "margin: 5px; color: grey; pointer-events: none;"
    back_style = link_style if page > 1 else disabled_style
    forward_style = link_style if page < total_pages else disabled_style

    # Pagination controls with "First", "Previous", "Next", and "Last"
    pagination_controls = Div(
        A("First", href="/admin?page=1", style=back_style),
        A("Previous", href=f"/admin?page={page - 1}", style=back_style),
        *[
            A(
                f"{i}",
                href=f"/admin?page={i}",
                style=link_style if i != page else current_style,
            )
            for i in range(start_page, end_page + 1)
        ],
        A("Next", href=f"/admin?page={page + 1}", style=forward_style),
        A("Last", href=f"/admin?page={total_pages}", style=forward_style),
        style="text-align: center; margin: 20px;",
    )

    # Total pages indication
    total_pages_indicator = Div(
        f"Page {page} of {total_pages}",
        style="text-align: center; margin: 10px;",
    )

    return (
        Title("Admin"),
        get_navbar(auth),
        Main(
            Div(
                A(
                    I(cls="fa fa-arrow-left"),
                    "Back",
                    href="/",
                    title="Back to main page",
                    style="margin: 10px;",
                ),
                style="margin: 10px;",
            ),
            H2("Queries"),
            # Table of the queries on the current page
            Table(
                Thead(
                    Tr(
                        Th("Query"),
                        Th("Session ID"),
                        Th("Datetime"),
                    )
                ),
                Tbody(
                    *[
                        Tr(
                            Td(query.query),
                            Td(query.sess_id),
                            Td(query.get_datetime()),
                        )
                        for query in page_queries
                    ],
                ),
                cls="striped",
            ),
            total_pages_indicator,
            pagination_controls,
            Div(
                A(
                    I(cls="fa fa-download fa-2x"),
                    " Download CSV",
                    href="/download_csv",
                    style="margin: 10px; float: right;",
                    title="Download queries as CSV",
                ),
                style="text-align: right; margin: 20px;",
            ),
            style="width: 80%; margin: 40px auto;",
        ),
    )
|
777 |
+
|
778 |
+
|
779 |
+
@app.get("/source")
def get_source(auth, sess):
    """Render the app's own source code (``SOURCE``) as highlighted markdown.

    Args:
        auth: The auth marker from the request scope, if any.
        sess: The session mapping; its ``"auth"`` entry is the fallback.

    Returns:
        A ``(Title, Nav, Main)`` tuple for FastHTML to render.
    """
    # "/source" is on the beforeware skip list, so `user_auth_before` does not
    # populate `auth` for this route; fall back to the session, like the
    # landing page does.
    is_admin = auth or sess.get("auth", False)
    # Back icon to go back to main page in top left corner
    return (
        Title("Source code"),
        get_navbar(is_admin),
        Main(
            Div(
                A(
                    I(cls="fa fa-arrow-left"),
                    "Back",
                    href="/",
                    title="Back to main page",
                    style="margin: 10px;",
                ),
                Div(
                    f"""### `main.py`\n### This is the complete source code for this app \n```python\n{SOURCE}\n```""",
                    cls="marked",
                    style="margin: 10px;",
                ),
                style="width: 80%; margin: 40px auto;",
            ),
        ),
    )
|
803 |
+
|
804 |
+
|
805 |
+
@app.get("/about")
def get_about(auth, sess):
    """Render the README as the "About" page.

    Args:
        auth: The auth marker from the request scope, if any.
        sess: The session mapping; its ``"auth"`` entry is the fallback.

    Returns:
        A ``(Title, Nav, Main)`` tuple for FastHTML to render.
    """
    # "/about" is on the beforeware skip list, so `user_auth_before` does not
    # populate `auth` for this route; fall back to the session, like the
    # landing page does.
    is_admin = auth or sess.get("auth", False)
    # Strip everything before the "# FastHTML Vespa frontend" heading
    stripped_readme = re.sub(
        r"^.*?(?=# FastHTML Vespa frontend)", "", README, flags=re.DOTALL
    )

    return (
        Title("About this app"),
        get_navbar(is_admin),
        Main(
            Div(
                A(
                    I(cls="fa fa-arrow-left"),
                    "Back",
                    href="/",
                    title="Back to main page",
                    style="margin: 10px;",
                ),
                Div(
                    stripped_readme,
                    cls="marked",
                    style="margin: 10px;",
                ),
                style="width: 80%; margin: 40px auto;",
            ),
        ),
    )
|
833 |
+
|
834 |
+
|
835 |
+
@app.get("/document/{docid}")
def get_document(docid: str, sess):
    """Render a single document with a link back to the last search.

    Args:
        docid: Id of the document to fetch from Vespa.
        sess: The session mapping; the last entry of ``"queries"`` supplies
            the back-link.

    Returns:
        A ``Main`` element with the back-link, the title and the body.
    """
    resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
    doc = resp.json
    # A missing document has no "fields" entry; show a placeholder instead of
    # failing with a KeyError.
    fields = doc.get("fields") or {}

    # Link with Back to search results at top of page.
    # `or [{}]` also covers an empty history list, where `[-1]` would raise.
    last_entry = (sess.get('queries') or [{}])[-1]
    last_query = last_entry.get('query', '')
    # /search requires `ranking`; fall back to the profile that is pre-selected
    # on the landing page when none was recorded.
    last_ranking = last_entry.get('ranking') or "fusion"
    # Percent-encode both values: the query text is free-form user input.
    back_url = f"/search?userquery={quote(last_query)}&ranking={quote(str(last_ranking))}"
    return Main(
        Div(
            A(
                I(cls="fa fa-arrow-left"),
                "Back to search results",
                hx_get=back_url,
                hx_target="#results",
                style="margin: 10px;",
            ),
            H2(fields.get("title", "Document not found"), style="margin: 10px;"),
            P(fields.get("body", ""), cls="marked"),
        ),
    )
|
854 |
+
|
855 |
+
|
856 |
+
# Outside development, try to set up the Hugging Face backup. A failure is
# reported but does not prevent the server from starting.
if not DEV_MODE:
    try:
        setup_hf_backup(app)
    except Exception as e:
        print(f"Error setting up hf backup: {e}")
serve()
|
main.py
CHANGED
@@ -57,7 +57,6 @@ import tempfile
|
|
57 |
from enum import Enum
|
58 |
from typing import Tuple as T
|
59 |
from urllib.parse import quote
|
60 |
-
import uuid
|
61 |
|
62 |
DEV_MODE = False
|
63 |
|
@@ -165,32 +164,14 @@ class XFrameOptionsMiddleware(BaseHTTPMiddleware):
|
|
165 |
response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
|
166 |
return response
|
167 |
|
168 |
-
class SessionLoggingMiddleware(BaseHTTPMiddleware):
|
169 |
-
async def dispatch(self, request, call_next):
|
170 |
-
print(f"Before request: Session data: {request.session}")
|
171 |
-
response = await call_next(request)
|
172 |
-
print(f"After request: Session data: {request.session}")
|
173 |
-
return response
|
174 |
-
|
175 |
-
class DebugSessionMiddleware(SessionMiddleware):
|
176 |
-
async def __call__(self, scope, receive, send):
|
177 |
-
print(f"DebugSessionMiddleware: Before processing - Scope: {scope}")
|
178 |
-
await super().__call__(scope, receive, send)
|
179 |
-
print(f"DebugSessionMiddleware: After processing - Scope: {scope}")
|
180 |
-
|
181 |
-
from starlette.middleware.cors import CORSMiddleware
|
182 |
|
183 |
middlewares = [
|
184 |
Middleware(
|
185 |
SessionMiddleware,
|
186 |
secret_key=get_key(fname=sess_key_path),
|
187 |
max_age=3600,
|
188 |
-
#same_site='lax',
|
189 |
),
|
190 |
-
Middleware(CORSMiddleware, allow_origins=['*']),
|
191 |
Middleware(XFrameOptionsMiddleware),
|
192 |
-
Middleware(SessionLoggingMiddleware),
|
193 |
-
#Middleware(DebugSessionMiddleware, secret_key=get_key(fname=sess_key_path)),
|
194 |
]
|
195 |
bware = Beforeware(
|
196 |
user_auth_before,
|
@@ -314,6 +295,7 @@ def get(sess):
|
|
314 |
queries = [
|
315 |
"Breast Cancer Cells Feed on Cholesterol",
|
316 |
"Treating Asthma With Plants vs. Pills",
|
|
|
317 |
"Testing Turmeric on Smokers",
|
318 |
"The Role of Pesticides in Parkinson's Disease",
|
319 |
]
|
@@ -442,10 +424,9 @@ def post(login: Login, sess):
|
|
442 |
if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
|
443 |
# Incorrect password - add error message
|
444 |
return RedirectResponse("/login?error=True", status_code=303)
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
return response
|
449 |
|
450 |
|
451 |
@app.get("/logout")
|
@@ -471,26 +452,9 @@ def replace_hi_with_strong(text):
|
|
471 |
|
472 |
|
473 |
def log_query_to_db(query, ranking, sess):
|
474 |
-
queries.insert(
|
475 |
Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
|
476 |
)
|
477 |
-
if 'user_id' not in sess:
|
478 |
-
sess['user_id'] = str(uuid.uuid4())
|
479 |
-
|
480 |
-
if 'queries' not in sess:
|
481 |
-
sess['queries'] = []
|
482 |
-
|
483 |
-
query_data = {
|
484 |
-
'query': query,
|
485 |
-
'ranking': ranking,
|
486 |
-
'timestamp': int(time.time())
|
487 |
-
}
|
488 |
-
sess['queries'].append(query_data)
|
489 |
-
|
490 |
-
# Limit the number of queries stored in the session to prevent it from growing too large
|
491 |
-
sess['queries'] = sess['queries'][-100:] # Keep only the last 100 queries
|
492 |
-
|
493 |
-
return query_data
|
494 |
|
495 |
|
496 |
def parse_results(records):
|
@@ -580,7 +544,12 @@ def get_yql(ranking: RankProfile, userquery: str) -> T[str, dict]:
|
|
580 |
@app.get("/search")
|
581 |
async def search(userquery: str, ranking: str, sess):
|
582 |
print(sess)
|
|
|
|
|
583 |
quoted = quote(userquery) + "&ranking=" + ranking
|
|
|
|
|
|
|
584 |
log_query_to_db(userquery, ranking, sess)
|
585 |
yql, body = get_yql(ranking, userquery)
|
586 |
async with vespa_app.asyncio() as session:
|
@@ -837,13 +806,12 @@ def get_document(docid: str, sess):
|
|
837 |
resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
|
838 |
doc = resp.json
|
839 |
# Link with Back to search results at top of page
|
840 |
-
last_query = sess.get('queries', [{}])[-1].get('query', '')
|
841 |
return Main(
|
842 |
Div(
|
843 |
A(
|
844 |
I(cls="fa fa-arrow-left"),
|
845 |
"Back to search results",
|
846 |
-
hx_get=f"/search?userquery={
|
847 |
hx_target="#results",
|
848 |
style="margin: 10px;",
|
849 |
),
|
|
|
57 |
from enum import Enum
|
58 |
from typing import Tuple as T
|
59 |
from urllib.parse import quote
|
|
|
60 |
|
61 |
DEV_MODE = False
|
62 |
|
|
|
164 |
response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
|
165 |
return response
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
middlewares = [
|
169 |
Middleware(
|
170 |
SessionMiddleware,
|
171 |
secret_key=get_key(fname=sess_key_path),
|
172 |
max_age=3600,
|
|
|
173 |
),
|
|
|
174 |
Middleware(XFrameOptionsMiddleware),
|
|
|
|
|
175 |
]
|
176 |
bware = Beforeware(
|
177 |
user_auth_before,
|
|
|
295 |
queries = [
|
296 |
"Breast Cancer Cells Feed on Cholesterol",
|
297 |
"Treating Asthma With Plants vs. Pills",
|
298 |
+
"Alkylphenol Endocrine Disruptors",
|
299 |
"Testing Turmeric on Smokers",
|
300 |
"The Role of Pesticides in Parkinson's Disease",
|
301 |
]
|
|
|
424 |
if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
|
425 |
# Incorrect password - add error message
|
426 |
return RedirectResponse("/login?error=True", status_code=303)
|
427 |
+
sess["auth"] = True
|
428 |
+
print(f"Sess after login: {sess}")
|
429 |
+
return RedirectResponse("/admin", status_code=303)
|
|
|
430 |
|
431 |
|
432 |
@app.get("/logout")
|
|
|
452 |
|
453 |
|
454 |
def log_query_to_db(query, ranking, sess):
|
455 |
+
return queries.insert(
|
456 |
Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
|
457 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
458 |
|
459 |
|
460 |
def parse_results(records):
|
|
|
544 |
@app.get("/search")
|
545 |
async def search(userquery: str, ranking: str, sess):
|
546 |
print(sess)
|
547 |
+
if "queries" not in sess:
|
548 |
+
sess["queries"] = []
|
549 |
quoted = quote(userquery) + "&ranking=" + ranking
|
550 |
+
sess["queries"].append(quoted)
|
551 |
+
print(f"Searching for: {userquery}")
|
552 |
+
print(f"Ranking: {ranking}")
|
553 |
log_query_to_db(userquery, ranking, sess)
|
554 |
yql, body = get_yql(ranking, userquery)
|
555 |
async with vespa_app.asyncio() as session:
|
|
|
806 |
resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
|
807 |
doc = resp.json
|
808 |
# Link with Back to search results at top of page
|
|
|
809 |
return Main(
|
810 |
Div(
|
811 |
A(
|
812 |
I(cls="fa fa-arrow-left"),
|
813 |
"Back to search results",
|
814 |
+
hx_get=f"/search?userquery={sess['queries'][-1]}",
|
815 |
hx_target="#results",
|
816 |
style="margin: 10px;",
|
817 |
),
|