Spaces: Running

Abhipsha Das committed: initial spaces deploy

Files changed:
- README.md +12 -7
- app.py +14 -0
- config.py +445 -0
- requirements.txt +8 -0
- scripts/__init__.py +6 -0
- scripts/__pycache__/__init__.cpython-311.pyc +0 -0
- scripts/__pycache__/create_db.cpython-311.pyc +0 -0
- scripts/__pycache__/run_db_interface.cpython-311.pyc +0 -0
- scripts/__pycache__/run_db_interface_improved.cpython-311.pyc +0 -0
- scripts/create_db.py +246 -0
- scripts/run_db_interface.py +704 -0
- scripts/run_db_interface_basic.py +361 -0
- scripts/run_db_interface_js.py +0 -0
README.md
CHANGED
@@ -1,14 +1,19 @@
 ---
-title: Surveyor
-emoji:
-colorFrom:
+title: Surveyor
+emoji: π
+colorFrom: blue
 colorTo: green
 sdk: gradio
-sdk_version:
+sdk_version: 4.40.0
 app_file: app.py
 pinned: false
-license: openrail
-short_description: Interface for exploring scientific concepts with KGs
 ---
 
-
+# Surveyor
+
+An interactive interface for querying and visualizing scientific paper databases with concept co-occurrence graphs.
+## Features
+- Interactive concept co-occurrence graphs
+- SQL query interface with pre-built queries
+- Support for multiple scientific domains
+- Graph filtering and highlighting
app.py
ADDED
@@ -0,0 +1,14 @@
import os
import sys

# Add the project root directory to Python path
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

from scripts.run_db_interface import create_demo

demo = create_demo()

if __name__ == "__main__":
    demo.launch()
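For local development, the same entry point can be launched directly. A minimal sketch, assuming the dependencies pinned in requirements.txt are installed; the server_name/server_port values are conventional Spaces defaults, not taken from this commit:

```python
# Launch Surveyor locally; on Spaces this is unnecessary, since the
# `app_file: app.py` front-matter in README.md starts the app.
from app import demo

demo.launch(server_name="0.0.0.0", server_port=7860)  # assumed local settings
```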
config.py
ADDED
@@ -0,0 +1,445 @@
DEFAULT_MODEL_ID = "Meta-Llama-3-70B-Instruct"
DEFAULT_INTERFACE_MODEL_ID = "NumbersStation/nsql-llama-2-7B"
DEFAULT_KIND = "json"
DEFAULT_TEMPERATURE = 0.6
DEFAULT_TOP_P = 0.95
DEFAULT_FEW_SHOT_NUM = 3
DEFAULT_FEW_SHOT_SELECTION = "random"
DEFAULT_SAVE_INTERVAL = 3
DEFAULT_RES_DIR = "data/results"
DEFAULT_LOG_DIR = "logs"
DEFAULT_TABLES_DIR = "data/databases"

COOCCURRENCE_QUERY = """
WITH concept_pairs AS (
    SELECT p1.concept AS concept1, p2.concept AS concept2, p1.paper_id, p1.tag_type
    FROM predictions p1
    JOIN predictions p2 ON p1.paper_id = p2.paper_id AND p1.concept < p2.concept
    WHERE p1.tag_type = p2.tag_type
)
SELECT concept1, concept2, tag_type, COUNT(DISTINCT paper_id) AS co_occurrences
FROM concept_pairs
GROUP BY concept1, concept2, tag_type
HAVING co_occurrences > 5
ORDER BY co_occurrences DESC;
"""

canned_queries = [
    (
        "Modalities in Physics and Astronomy papers",
        """
        SELECT DISTINCT LOWER(concept) AS concept
        FROM predictions
        JOIN (
            SELECT paper_id, url
            FROM papers
            WHERE primary_category LIKE '%physics.space-ph%'
            OR primary_category LIKE '%astro-ph.%'
        ) AS paper_ids
        ON predictions.paper_id = paper_ids.paper_id
        WHERE predictions.tag_type = 'modality'
        """,
    ),
    (
        "Datasets in Evolutionary Biology that use PDEs",
        """
        WITH pde_predictions AS (
            SELECT paper_id, concept AS pde_concept, tag_type AS pde_tag_type
            FROM predictions
            WHERE tag_type IN ('method', 'model')
            AND (
                LOWER(concept) LIKE '%pde%'
                OR LOWER(concept) LIKE '%partial differential equation%'
            )
        )
        SELECT DISTINCT
            papers.paper_id,
            papers.url,
            LOWER(p_dataset.concept) AS dataset,
            pde_predictions.pde_concept AS pde_related_concept,
            pde_predictions.pde_tag_type AS pde_related_type
        FROM papers
        JOIN pde_predictions ON papers.paper_id = pde_predictions.paper_id
        LEFT JOIN predictions p_dataset ON papers.paper_id = p_dataset.paper_id
        WHERE papers.primary_category LIKE '%q-bio.PE%'
        AND (p_dataset.tag_type = 'dataset' OR p_dataset.tag_type IS NULL)
        ORDER BY papers.paper_id, dataset, pde_related_concept;
        """,
    ),
    (
        "Trends in objects of study in Cosmology since 2019",
        """
        SELECT
            substr(papers.updated_on, 2, 4) as year,
            predictions.concept as object,
            COUNT(DISTINCT papers.paper_id) as paper_count
        FROM
            papers
        JOIN
            predictions ON papers.paper_id = predictions.paper_id
        WHERE
            predictions.tag_type = 'object'
            AND CAST(SUBSTR(papers.updated_on, 2, 4) AS INTEGER) >= 2019
        GROUP BY
            year, object
        ORDER BY
            year DESC, paper_count DESC;
        """,
    ),
    (
        "New datasets in fluid dynamics since 2020",
        """
        WITH ranked_datasets AS (
            SELECT
                p.paper_id,
                p.url,
                pred.concept AS dataset,
                p.updated_on,
                ROW_NUMBER() OVER (PARTITION BY pred.concept ORDER BY p.updated_on ASC) AS rn
            FROM
                papers p
            JOIN
                predictions pred ON p.paper_id = pred.paper_id
            WHERE
                pred.tag_type = 'dataset'
                AND p.primary_category LIKE '%physics.flu-dyn%'
                AND CAST(SUBSTR(p.updated_on, 2, 4) AS INTEGER) >= 2020
        )
        SELECT
            paper_id,
            url,
            dataset,
            updated_on
        FROM
            ranked_datasets
        WHERE
            rn = 1
        ORDER BY
            updated_on ASC
        """,
    ),
    (
        "Evolutionary biology datasets that use spatiotemporal dynamics",
        """
        WITH evo_bio_papers AS (
            SELECT paper_id
            FROM papers
            WHERE primary_category LIKE '%q-bio.PE%'
        ),
        spatiotemporal_keywords AS (
            SELECT 'spatio-temporal' AS keyword
            UNION SELECT 'spatiotemporal'
            UNION SELECT 'spatio-temporal'
            UNION SELECT 'spatial and temporal'
            UNION SELECT 'space-time'
            UNION SELECT 'geographic distribution'
            UNION SELECT 'phylogeograph'
            UNION SELECT 'biogeograph'
            UNION SELECT 'dispersal'
            UNION SELECT 'migration'
            UNION SELECT 'range expansion'
            UNION SELECT 'population dynamics'
        )
        SELECT DISTINCT
            p.paper_id,
            p.updated_on,
            p.abstract,
            d.concept AS dataset,
            GROUP_CONCAT(DISTINCT stk.keyword) AS spatiotemporal_keywords_found
        FROM
            evo_bio_papers ebp
        JOIN
            papers p ON ebp.paper_id = p.paper_id
        JOIN
            predictions d ON p.paper_id = d.paper_id
        JOIN
            predictions st ON p.paper_id = st.paper_id
        JOIN
            spatiotemporal_keywords stk
        WHERE
            d.tag_type = 'dataset'
            AND st.tag_type = 'modality'
            AND LOWER(st.concept) LIKE '%' || stk.keyword || '%'
        GROUP BY
            p.paper_id, p.updated_on, p.abstract, d.concept
        ORDER BY
            p.updated_on DESC
        """,
    ),
    (
        "What percentage of papers use only galaxy or spectra, or both or neither?",
        """
        WITH paper_modalities AS (
            SELECT
                p.paper_id,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%imag%' THEN 1 ELSE 0 END) AS uses_galaxy_images,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%spectr%' THEN 1 ELSE 0 END) AS uses_spectra
            FROM
                papers p
            LEFT JOIN
                predictions pred ON p.paper_id = pred.paper_id
            WHERE
                p.primary_category LIKE '%astro-ph%'
                AND pred.tag_type = 'modality'
            GROUP BY
                p.paper_id
        ),
        categorized_papers AS (
            SELECT
                CASE
                    WHEN uses_galaxy_images = 1 AND uses_spectra = 1 THEN 'Both'
                    WHEN uses_galaxy_images = 1 THEN 'Only Galaxy Images'
                    WHEN uses_spectra = 1 THEN 'Only Spectra'
                    ELSE 'Neither'
                END AS category,
                COUNT(*) AS paper_count
            FROM
                paper_modalities
            GROUP BY
                CASE
                    WHEN uses_galaxy_images = 1 AND uses_spectra = 1 THEN 'Both'
                    WHEN uses_galaxy_images = 1 THEN 'Only Galaxy Images'
                    WHEN uses_spectra = 1 THEN 'Only Spectra'
                    ELSE 'Neither'
                END
        )
        SELECT
            category,
            paper_count,
            ROUND(CAST(paper_count AS FLOAT) / (SELECT SUM(paper_count) FROM categorized_papers) * 100, 2) AS percentage
        FROM
            categorized_papers
        ORDER BY
            paper_count DESC
        """,
    ),
    (
        "What are all the next highest data modalities after images and spectra?",
        """
        SELECT
            LOWER(concept) AS modality,
            COUNT(DISTINCT paper_id) AS usage_count
        FROM
            predictions
        WHERE
            tag_type = 'modality'
            AND LOWER(concept) NOT LIKE '%imag%'
            AND LOWER(concept) NOT LIKE '%spectr%'
        GROUP BY
            LOWER(concept)
        ORDER BY
            usage_count DESC
        """,
    ),
    (
        "If we include the next biggest data modality, how much does coverage change?",
        """
        WITH modality_counts AS (
            SELECT
                LOWER(concept) AS modality,
                COUNT(DISTINCT paper_id) AS usage_count
            FROM
                predictions
            WHERE
                tag_type = 'modality'
                AND LOWER(concept) NOT LIKE '%imag%'
                AND LOWER(concept) NOT LIKE '%spectr%'
            GROUP BY
                LOWER(concept)
            ORDER BY
                usage_count DESC
            LIMIT 1
        ),
        paper_modalities AS (
            SELECT
                p.paper_id,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%imag%' THEN 1 ELSE 0 END) AS uses_galaxy_images,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%spectr%' THEN 1 ELSE 0 END) AS uses_spectra,
                MAX(CASE WHEN LOWER(pred.concept) LIKE (SELECT '%' || modality || '%' FROM modality_counts) THEN 1 ELSE 0 END) AS uses_third_modality
            FROM
                papers p
            LEFT JOIN
                predictions pred ON p.paper_id = pred.paper_id
            WHERE
                p.primary_category LIKE '%astro-ph%'
                AND pred.tag_type = 'modality'
            GROUP BY
                p.paper_id
        ),
        coverage_before AS (
            SELECT
                SUM(CASE WHEN uses_galaxy_images = 1 OR uses_spectra = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities
        ),
        coverage_after AS (
            SELECT
                SUM(CASE WHEN uses_galaxy_images = 1 OR uses_spectra = 1 OR uses_third_modality = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities
        )
        SELECT
            (SELECT modality FROM modality_counts) AS third_modality,
            ROUND(CAST(covered_papers AS FLOAT) / total_papers * 100, 2) AS coverage_before_percent,
            ROUND(CAST((SELECT covered_papers FROM coverage_after) AS FLOAT) / total_papers * 100, 2) AS coverage_after_percent,
            ROUND(CAST((SELECT covered_papers FROM coverage_after) AS FLOAT) / total_papers * 100, 2) -
            ROUND(CAST(covered_papers AS FLOAT) / total_papers * 100, 2) AS coverage_increase_percent
        FROM
            coverage_before
        """,
    ),
    (
        "Coverage if we select the next 5 highest modalities?",
        """
        WITH ranked_modalities AS (
            SELECT
                LOWER(concept) AS modality,
                COUNT(DISTINCT paper_id) AS usage_count,
                ROW_NUMBER() OVER (ORDER BY COUNT(DISTINCT paper_id) DESC) AS rank
            FROM
                predictions
            WHERE
                tag_type = 'modality'
                AND LOWER(concept) NOT LIKE '%imag%'
                AND LOWER(concept) NOT LIKE '%spectr%'
            GROUP BY
                LOWER(concept)
        ),
        paper_modalities AS (
            SELECT
                p.paper_id,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%imag%' THEN 1 ELSE 0 END) AS uses_images,
                MAX(CASE WHEN LOWER(pred.concept) LIKE '%spectr%' THEN 1 ELSE 0 END) AS uses_spectra,
                MAX(CASE WHEN rm.rank = 1 THEN 1 ELSE 0 END) AS uses_modality_1,
                MAX(CASE WHEN rm.rank = 2 THEN 1 ELSE 0 END) AS uses_modality_2,
                MAX(CASE WHEN rm.rank = 3 THEN 1 ELSE 0 END) AS uses_modality_3,
                MAX(CASE WHEN rm.rank = 4 THEN 1 ELSE 0 END) AS uses_modality_4,
                MAX(CASE WHEN rm.rank = 5 THEN 1 ELSE 0 END) AS uses_modality_5
            FROM
                papers p
            LEFT JOIN
                predictions pred ON p.paper_id = pred.paper_id
            LEFT JOIN
                ranked_modalities rm ON LOWER(pred.concept) = rm.modality
            WHERE
                p.primary_category LIKE '%astro-ph%'
                AND pred.tag_type = 'modality'
            GROUP BY
                p.paper_id
        ),
        cumulative_coverage AS (
            SELECT
                'Images and Spectra' AS modalities,
                0 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, and Modality 1' AS modalities,
                1 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, Modality 1, and 2' AS modalities,
                2 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, Modality 1, 2, and 3' AS modalities,
                3 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 OR uses_modality_3 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, Modality 1, 2, 3, and 4' AS modalities,
                4 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 OR uses_modality_3 = 1 OR uses_modality_4 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities

            UNION ALL

            SELECT
                'Images, Spectra, Modality 1, 2, 3, 4, and 5' AS modalities,
                5 AS added_modality_rank,
                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 OR uses_modality_3 = 1 OR uses_modality_4 = 1 OR uses_modality_5 = 1 THEN 1 ELSE 0 END) AS covered_papers,
                COUNT(*) AS total_papers
            FROM
                paper_modalities
        )
        SELECT
            cc.modalities,
            COALESCE(rm.modality, 'N/A') AS added_modality,
            rm.usage_count AS added_modality_usage,
            ROUND(CAST(cc.covered_papers AS FLOAT) / cc.total_papers * 100, 2) AS coverage_percent,
            ROUND(CAST(cc.covered_papers AS FLOAT) / cc.total_papers * 100, 2) -
            LAG(ROUND(CAST(cc.covered_papers AS FLOAT) / cc.total_papers * 100, 2), 1, 0) OVER (ORDER BY cc.added_modality_rank) AS coverage_increase_percent
        FROM
            cumulative_coverage cc
        LEFT JOIN
            ranked_modalities rm ON cc.added_modality_rank = rm.rank
        ORDER BY
            cc.added_modality_rank
        """,
    ),
    (
        "List all papers",
        "SELECT paper_id, abstract AS abstract_preview, authors, primary_category FROM papers",
    ),
    (
        "Count papers by category",
        "SELECT primary_category, COUNT(*) as paper_count FROM papers GROUP BY primary_category ORDER BY paper_count DESC",
    ),
    (
        "Top authors with most papers",
        """
        WITH author_papers AS (
            SELECT json_each.value AS author
            FROM papers, json_each(papers.authors)
        )
        SELECT author, COUNT(*) as paper_count
        FROM author_papers
        GROUP BY author
        ORDER BY paper_count DESC
        """,
    ),
    (
        "Papers with 'quantum' in abstract",
        "SELECT paper_id, abstract AS abstract_preview FROM papers WHERE abstract LIKE '%quantum%'",
    ),
    (
        "Most common concepts",
        "SELECT concept, COUNT(*) as concept_count FROM predictions GROUP BY concept ORDER BY concept_count DESC",
    ),
    (
        "Papers with multiple authors",
        """
        SELECT paper_id, json_array_length(authors) as author_count, authors
        FROM papers
        WHERE json_array_length(authors) > 1
        ORDER BY author_count DESC
        """,
    ),
]
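The canned queries above are plain (description, SQL) pairs, so they can also be run outside the interface. A minimal sketch, assuming a populated SQLite file at data/databases/arxiv.db (a hypothetical path) with the papers/predictions schema that scripts/create_db.py creates:

```python
import sqlite3

import pandas as pd

from config import COOCCURRENCE_QUERY, canned_queries

# Hypothetical path; any database built by scripts/create_db.py works here.
conn = sqlite3.connect("data/databases/arxiv.db")

# Concept pairs co-occurring in more than five papers, grouped by tag type.
edges = pd.read_sql_query(COOCCURRENCE_QUERY, conn)
print(edges.head())

# canned_queries maps cleanly onto a dict keyed by description.
sql = dict(canned_queries)["Count papers by category"]
print(pd.read_sql_query(sql, conn))

conn.close()
```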
requirements.txt
ADDED
@@ -0,0 +1,8 @@
gradio==4.40.0
networkx==3.3
pandas==2.2.2
plotly==5.23.0
tabulate==0.9.0
fastapi==0.104.1
pydantic==2.5.3
uvicorn==0.27.1
scripts/__init__.py
ADDED
@@ -0,0 +1,6 @@
import os
import sys

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
scripts/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (566 Bytes)

scripts/__pycache__/create_db.cpython-311.pyc
ADDED
Binary file (14.6 kB)

scripts/__pycache__/run_db_interface.cpython-311.pyc
ADDED
Binary file (29 kB)

scripts/__pycache__/run_db_interface_improved.cpython-311.pyc
ADDED
Binary file (29.2 kB)
scripts/create_db.py
ADDED
@@ -0,0 +1,246 @@
import click
import json
import os
import sqlite3
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from config import DEFAULT_TABLES_DIR, DEFAULT_MODEL_ID, DEFAULT_INTERFACE_MODEL_ID
from src.processing.generate import get_sentences, generate_prediction
from src.utils.utils import load_model_and_tokenizer


class ArxivDatabase:
    def __init__(self, db_path, model_id=None):
        self.conn = None
        self.cursor = None
        self.db_path = db_path
        self.model_id = model_id if model_id else DEFAULT_INTERFACE_MODEL_ID
        self.model = None
        self.tokenizer = None
        self.is_db_empty = True
        self.paper_table = """CREATE TABLE IF NOT EXISTS papers
            (paper_id TEXT PRIMARY KEY, abstract TEXT, authors TEXT,
            primary_category TEXT, url TEXT, updated_on TEXT, sentence_count INTEGER)"""
        self.pred_table = """CREATE TABLE IF NOT EXISTS predictions
            (id INTEGER PRIMARY KEY AUTOINCREMENT, paper_id TEXT, sentence_index INTEGER,
            tag_type TEXT, concept TEXT,
            FOREIGN KEY (paper_id) REFERENCES papers(paper_id))"""

    # def init_db(self):
    #     self.cursor.execute(self.paper_table)
    #     self.cursor.execute(self.pred_table)

    #     print("Database and tables created successfully.")
    #     self.is_db_empty = self.is_empty()

    def init_db(self):
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute(self.paper_table)
        self.cursor.execute(self.pred_table)
        self.conn.commit()
        self.is_db_empty = self.is_empty()
        if not self.is_db_empty:
            print("Database already contains data.")
        else:
            print("Database and tables created successfully.")

    def is_empty(self):
        try:
            self.cursor.execute("SELECT COUNT(*) FROM papers")
            count = self.cursor.fetchone()[0]
            return count == 0
        except sqlite3.OperationalError:
            return True

    def get_connection(self):
        # Open a fresh connection to the same database file
        # (fixed: sqlite3.Connection has no .path attribute).
        return sqlite3.connect(self.db_path)

    def populate_db(self, data_path, pred_path):
        papers_info = self._insert_papers(data_path)
        self._insert_predictions(pred_path, papers_info)
        print("Database population completed.")

    def _insert_papers(self, data_path):
        papers_info = []
        seen_papers = set()
        with open(data_path, "r") as f:
            for line in f:
                paper = json.loads(line)
                if paper["id"] in seen_papers:
                    continue
                seen_papers.add(paper["id"])
                sentence_count = len(get_sentences(paper["id"])) + len(
                    get_sentences(paper["abstract"])
                )
                papers_info.append((paper["id"], sentence_count))
                self.cursor.execute(
                    """INSERT OR REPLACE INTO papers VALUES (?, ?, ?, ?, ?, ?, ?)""",
                    (
                        paper["id"],
                        paper["abstract"],
                        json.dumps(paper["authors"]),
                        json.dumps(paper["primary_category"]),
                        json.dumps(paper["url"]),
                        json.dumps(paper["updated"]),
                        sentence_count,
                    ),
                )
        print(f"Inserted {len(papers_info)} papers.")
        return papers_info

    def _insert_predictions(self, pred_path, papers_info):
        with open(pred_path, "r") as f:
            predictions = json.load(f)
        predicted_tags = predictions["predicted_tags"]

        k = 0
        papers_with_predictions = set()
        papers_without_predictions = []
        for paper_id, sentence_count in papers_info:
            paper_predictions = predicted_tags[k : k + sentence_count]

            has_predictions = False
            for sentence_index, pred in enumerate(paper_predictions):
                if pred:  # If the prediction is not an empty dictionary
                    has_predictions = True
                    for tag_type, concepts in pred.items():
                        for concept in concepts:
                            self.cursor.execute(
                                """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
                                VALUES (?, ?, ?, ?)""",
                                (paper_id, sentence_index, tag_type, concept),
                            )
                else:
                    # Insert a null prediction to ensure the paper is counted
                    self.cursor.execute(
                        """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
                        VALUES (?, ?, ?, ?)""",
                        (paper_id, sentence_index, "null", "null"),
                    )

            if has_predictions:
                papers_with_predictions.add(paper_id)
            else:
                papers_without_predictions.append(paper_id)

            k += sentence_count

        print(f"Inserted predictions for {len(papers_with_predictions)} papers.")
        print(f"Papers without any predictions: {len(papers_without_predictions)}")

        if k < len(predicted_tags):
            print(f"Warning: {len(predicted_tags) - k} predictions were not inserted.")

    def load_model(self):
        if self.model is None:
            try:
                self.model, self.tokenizer = load_model_and_tokenizer(self.model_id)
                return f"Model {self.model_id} loaded successfully."
            except Exception as e:
                return f"Error loading model: {str(e)}"
        else:
            return "Model is already loaded."

    def natural_language_to_sql(self, question):
        system_prompt = "You are an assistant who converts natural language questions to SQL queries to query a database of scientific papers."
        table = self.paper_table + "; " + self.pred_table
        prefix = (
            f"[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using "
            f"```: Schema: {table} Question: {question}[/INST] Here is the SQLite query to answer to the question: {question}: ``` "
        )

        sql_query = generate_prediction(
            self.model, self.tokenizer, prefix, question, "sql", system_prompt
        )

        sql_query = sql_query.split("```")[1]

        return sql_query

    def execute_query(self, sql_query):
        try:
            self.cursor.execute(sql_query)
            results = self.cursor.fetchall()
            return results if results else []
        except sqlite3.Error as e:
            return [(f"An error occurred: {e}",)]

    def query_db(self, question, is_sql):
        if self.is_db_empty:
            return "The database is empty. Please populate it with data first."

        try:
            if is_sql:
                sql_query = question.strip()
            else:
                nl_to_sql = self.natural_language_to_sql(question)
                sql_query = nl_to_sql.replace("```sql", "").replace("```", "").strip()

            results = self.execute_query(sql_query)

            output = f"SQL Query: {sql_query}\n\nResults:\n"
            if isinstance(results, list):
                if len(results) > 0:
                    for row in results:
                        output += str(row) + "\n"
                else:
                    output += "No results found."
            else:
                output += str(results)  # In case of an error message

            return output
        except Exception as e:
            return f"An error occurred: {str(e)}"

    def close(self):
        self.conn.commit()
        self.conn.close()


def check_db_exists(db_path):
    return os.path.exists(db_path) and os.path.getsize(db_path) > 0


@click.command()
@click.option(
    "--data_path", help="Path to the data file containing the papers information."
)
@click.option("--pred_path", help="Path to the predictions file.")
@click.option("--db_name", default="arxiv.db", help="Name of the database to create.")
@click.option(
    "--force", is_flag=True, help="Force overwrite if database already exists"
)
def main(data_path, pred_path, db_name, force):
    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
    os.makedirs(tables_dir, exist_ok=True)
    db_path = os.path.join(tables_dir, db_name)

    db_exists = check_db_exists(db_path)

    db = ArxivDatabase(db_path)
    db.init_db()

    if db_exists and not db.is_db_empty:
        if not force:
            print(f"Warning: The database '{db_name}' already exists and is not empty.")
            overwrite = input("Do you want to overwrite it? (y/N): ").lower().strip()
            if overwrite != "y":
                print("Operation cancelled.")
                db.close()
                return
        else:
            print(
                f"Warning: Overwriting existing database '{db_name}' due to --force flag."
            )

    db.populate_db(data_path, pred_path)
    db.close()

    print(f"Database created and populated at: {db_path}")


if __name__ == "__main__":
    main()
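A minimal programmatic sketch of the same flow main() drives; the input paths are hypothetical, and populate_db expects a JSON-lines papers file plus a predictions JSON containing a predicted_tags list:

```python
from scripts.create_db import ArxivDatabase

# Hypothetical inputs; the CLI equivalent is:
#   python scripts/create_db.py --data_path papers.jsonl --pred_path preds.json
db = ArxivDatabase("data/databases/arxiv.db")
db.init_db()  # creates the papers/predictions tables if they do not exist
if db.is_db_empty:
    db.populate_db("papers.jsonl", "preds.json")
db.close()
```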
scripts/run_db_interface.py
ADDED
@@ -0,0 +1,704 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import json
|
4 |
+
import networkx as nx
|
5 |
+
import pandas as pd
|
6 |
+
import plotly.graph_objects as go
|
7 |
+
import re
|
8 |
+
import sys
|
9 |
+
import sqlite3
|
10 |
+
import tempfile
|
11 |
+
import time
|
12 |
+
import uvicorn
|
13 |
+
|
14 |
+
from contextlib import contextmanager
|
15 |
+
from fastapi import FastAPI, Request
|
16 |
+
from fastapi.middleware.cors import CORSMiddleware
|
17 |
+
from gradio.routes import mount_gradio_app
|
18 |
+
from plotly.subplots import make_subplots
|
19 |
+
from tabulate import tabulate
|
20 |
+
from typing import Optional
|
21 |
+
|
22 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
23 |
+
if ROOT_DIR not in sys.path:
|
24 |
+
sys.path.insert(0, ROOT_DIR)
|
25 |
+
|
26 |
+
from scripts.create_db import ArxivDatabase
|
27 |
+
from config import (
|
28 |
+
DEFAULT_TABLES_DIR,
|
29 |
+
DEFAULT_INTERFACE_MODEL_ID,
|
30 |
+
COOCCURRENCE_QUERY,
|
31 |
+
canned_queries,
|
32 |
+
)
|
33 |
+
|
34 |
+
app = FastAPI()
|
35 |
+
|
36 |
+
# Add CORS middleware
|
37 |
+
app.add_middleware(
|
38 |
+
CORSMiddleware,
|
39 |
+
allow_origins=["*"],
|
40 |
+
allow_credentials=True,
|
41 |
+
allow_methods=["*"],
|
42 |
+
allow_headers=["*"],
|
43 |
+
)
|
44 |
+
|
45 |
+
db: Optional[ArxivDatabase] = None
|
46 |
+
|
47 |
+
last_update_time = 0
|
48 |
+
update_delay = 0.5 # Delay in seconds
|
49 |
+
|
50 |
+
|
51 |
+
def truncate_or_wrap_text(text, max_length=50, wrap=False):
|
52 |
+
"""Truncate text to a maximum length, adding ellipsis if truncated, or wrap if specified."""
|
53 |
+
if wrap:
|
54 |
+
return "\n".join(
|
55 |
+
text[i : i + max_length] for i in range(0, len(text), max_length)
|
56 |
+
)
|
57 |
+
return text[:max_length] + "..." if len(text) > max_length else text
|
58 |
+
|
59 |
+
|
60 |
+
def format_url(url):
|
61 |
+
"""Format URL to be more compact in the table."""
|
62 |
+
return url.split("/")[-1] if url.startswith("http") else url
|
63 |
+
|
64 |
+
|
65 |
+
def get_db_path():
|
66 |
+
"""Get the database directory path based on environment"""
|
67 |
+
# First try local path
|
68 |
+
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
69 |
+
tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
|
70 |
+
|
71 |
+
if not os.path.exists(tables_dir):
|
72 |
+
# If running on Spaces, try the root directory
|
73 |
+
tables_dir = os.path.join(ROOT, "data", "databases")
|
74 |
+
if not os.path.exists(tables_dir):
|
75 |
+
print(f"No database directory found")
|
76 |
+
return None
|
77 |
+
|
78 |
+
print(f"Using database directory: {tables_dir}")
|
79 |
+
return tables_dir
|
80 |
+
|
81 |
+
|
82 |
+
def get_available_databases():
|
83 |
+
"""Get available databases from either local path or Hugging Face cache."""
|
84 |
+
tables_dir = get_db_path()
|
85 |
+
if not tables_dir:
|
86 |
+
return []
|
87 |
+
|
88 |
+
files = os.listdir(tables_dir)
|
89 |
+
print(f"All files found: {files}")
|
90 |
+
|
91 |
+
# Include all files except .md files
|
92 |
+
databases = [f for f in files if not f.endswith(".md")]
|
93 |
+
print(f"Database files: {databases}")
|
94 |
+
|
95 |
+
return databases
|
96 |
+
|
97 |
+
|
98 |
+
def query_db(query, is_sql, limit=None, wrap=False):
|
99 |
+
global db
|
100 |
+
if db is None:
|
101 |
+
return pd.DataFrame({"Error": ["Please load a database first."]})
|
102 |
+
|
103 |
+
try:
|
104 |
+
with sqlite3.connect(db.db_path) as conn:
|
105 |
+
cursor = conn.cursor()
|
106 |
+
|
107 |
+
query = " ".join(query.strip().split("\n")).rstrip(";")
|
108 |
+
|
109 |
+
if limit is not None:
|
110 |
+
if "LIMIT" in query.upper():
|
111 |
+
# Replace existing LIMIT clause
|
112 |
+
query = re.sub(
|
113 |
+
r"LIMIT\s+\d+", f"LIMIT {limit}", query, flags=re.IGNORECASE
|
114 |
+
)
|
115 |
+
else:
|
116 |
+
query += f" LIMIT {limit}"
|
117 |
+
|
118 |
+
cursor.execute(query)
|
119 |
+
|
120 |
+
column_names = [description[0] for description in cursor.description]
|
121 |
+
|
122 |
+
results = cursor.fetchall()
|
123 |
+
|
124 |
+
df = pd.DataFrame(results, columns=column_names)
|
125 |
+
|
126 |
+
for column in df.columns:
|
127 |
+
if df[column].dtype == "object":
|
128 |
+
df[column] = df[column].apply(
|
129 |
+
lambda x: (
|
130 |
+
format_url(x)
|
131 |
+
if column == "url"
|
132 |
+
else truncate_or_wrap_text(x, wrap=wrap)
|
133 |
+
)
|
134 |
+
)
|
135 |
+
|
136 |
+
return df
|
137 |
+
|
138 |
+
except sqlite3.Error as e:
|
139 |
+
return pd.DataFrame({"Error": [f"Database error: {str(e)}"]})
|
140 |
+
except Exception as e:
|
141 |
+
return pd.DataFrame({"Error": [f"An unexpected error occurred: {str(e)}"]})
|
142 |
+
|
143 |
+
|
144 |
+
def generate_concept_cooccurrence_graph(db_path, tag_type=None):
|
145 |
+
conn = sqlite3.connect(db_path)
|
146 |
+
|
147 |
+
query = COOCCURRENCE_QUERY
|
148 |
+
if tag_type and tag_type != "All":
|
149 |
+
query = query.replace(
|
150 |
+
"WHERE p1.tag_type = p2.tag_type",
|
151 |
+
f"WHERE p1.tag_type = p2.tag_type AND p1.tag_type = '{tag_type}'",
|
152 |
+
)
|
153 |
+
|
154 |
+
df = pd.read_sql_query(query, conn)
|
155 |
+
conn.close()
|
156 |
+
|
157 |
+
G = nx.from_pandas_edgelist(df, "concept1", "concept2", "co_occurrences")
|
158 |
+
pos = nx.spring_layout(G, k=0.5, iterations=50)
|
159 |
+
|
160 |
+
edge_trace = go.Scatter(
|
161 |
+
x=[], y=[], line=dict(width=0.5, color="#888"), hoverinfo="none", mode="lines"
|
162 |
+
)
|
163 |
+
|
164 |
+
node_trace = go.Scatter(
|
165 |
+
x=[],
|
166 |
+
y=[],
|
167 |
+
mode="markers",
|
168 |
+
hoverinfo="text",
|
169 |
+
marker=dict(
|
170 |
+
showscale=True,
|
171 |
+
colorscale="YlGnBu",
|
172 |
+
size=10,
|
173 |
+
colorbar=dict(
|
174 |
+
thickness=15,
|
175 |
+
title="Node Connections",
|
176 |
+
xanchor="left",
|
177 |
+
titleside="right",
|
178 |
+
),
|
179 |
+
),
|
180 |
+
)
|
181 |
+
|
182 |
+
def update_traces(selected_node=None, depth=0):
|
183 |
+
nonlocal edge_trace, node_trace
|
184 |
+
|
185 |
+
if selected_node and depth > 0:
|
186 |
+
nodes_to_show = set([selected_node])
|
187 |
+
frontier = set([selected_node])
|
188 |
+
for _ in range(depth):
|
189 |
+
new_frontier = set()
|
190 |
+
for node in frontier:
|
191 |
+
new_frontier.update(G.neighbors(node))
|
192 |
+
nodes_to_show.update(new_frontier)
|
193 |
+
frontier = new_frontier
|
194 |
+
sub_G = G.subgraph(nodes_to_show)
|
195 |
+
else:
|
196 |
+
sub_G = G
|
197 |
+
|
198 |
+
edge_x, edge_y = [], []
|
199 |
+
for edge in sub_G.edges():
|
200 |
+
x0, y0 = pos[edge[0]]
|
201 |
+
x1, y1 = pos[edge[1]]
|
202 |
+
edge_x.extend([x0, x1, None])
|
203 |
+
edge_y.extend([y0, y1, None])
|
204 |
+
|
205 |
+
edge_trace.x = edge_x
|
206 |
+
edge_trace.y = edge_y
|
207 |
+
|
208 |
+
node_x, node_y = [], []
|
209 |
+
for node in sub_G.nodes():
|
210 |
+
x, y = pos[node]
|
211 |
+
node_x.append(x)
|
212 |
+
node_y.append(y)
|
213 |
+
|
214 |
+
node_trace.x = node_x
|
215 |
+
node_trace.y = node_y
|
216 |
+
|
217 |
+
node_adjacencies = []
|
218 |
+
node_text = []
|
219 |
+
for node in sub_G.nodes():
|
220 |
+
adjacencies = list(G.adj[node])
|
221 |
+
node_adjacencies.append(len(adjacencies))
|
222 |
+
node_text.append(f"{node}<br># of connections: {len(adjacencies)}")
|
223 |
+
|
224 |
+
node_trace.marker.color = node_adjacencies
|
225 |
+
node_trace.text = node_text
|
226 |
+
|
227 |
+
update_traces()
|
228 |
+
|
229 |
+
fig = go.Figure(
|
230 |
+
data=[edge_trace, node_trace],
|
231 |
+
layout=go.Layout(
|
232 |
+
title=f'Concept Co-occurrence Network {f"({tag_type})" if tag_type and tag_type != "All" else ""}',
|
233 |
+
titlefont_size=16,
|
234 |
+
showlegend=False,
|
235 |
+
hovermode="closest",
|
236 |
+
margin=dict(b=20, l=5, r=5, t=40),
|
237 |
+
annotations=[
|
238 |
+
dict(
|
239 |
+
text="",
|
240 |
+
showarrow=False,
|
241 |
+
xref="paper",
|
242 |
+
yref="paper",
|
243 |
+
x=0.005,
|
244 |
+
y=-0.002,
|
245 |
+
)
|
246 |
+
],
|
247 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
248 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
249 |
+
),
|
250 |
+
)
|
251 |
+
|
252 |
+
fig.update_layout(
|
253 |
+
updatemenus=[
|
254 |
+
dict(
|
255 |
+
type="buttons",
|
256 |
+
direction="left",
|
257 |
+
buttons=[
|
258 |
+
dict(
|
259 |
+
args=[{"visible": [True, True]}],
|
260 |
+
label="Full Graph",
|
261 |
+
method="update",
|
262 |
+
),
|
263 |
+
dict(
|
264 |
+
args=[
|
265 |
+
{
|
266 |
+
"visible": [True, True],
|
267 |
+
"xaxis.range": [-1, 1],
|
268 |
+
"yaxis.range": [-1, 1],
|
269 |
+
}
|
270 |
+
],
|
271 |
+
label="Core View",
|
272 |
+
method="relayout",
|
273 |
+
),
|
274 |
+
dict(
|
275 |
+
args=[
|
276 |
+
{
|
277 |
+
"visible": [True, True],
|
278 |
+
"xaxis.range": [-0.2, 0.2],
|
279 |
+
"yaxis.range": [-0.2, 0.2],
|
280 |
+
}
|
281 |
+
],
|
282 |
+
label="Detailed View",
|
283 |
+
method="relayout",
|
284 |
+
),
|
285 |
+
],
|
286 |
+
pad={"r": 10, "t": 10},
|
287 |
+
showactive=True,
|
288 |
+
x=0.11,
|
289 |
+
xanchor="left",
|
290 |
+
y=1.1,
|
291 |
+
yanchor="top",
|
292 |
+
),
|
293 |
+
]
|
294 |
+
)
|
295 |
+
|
296 |
+
return fig, G, pos, update_traces
|
297 |
+
|
298 |
+
|
299 |
+
def load_database_with_graphs(db_name):
|
300 |
+
"""Load database from either local path or Hugging Face cache."""
|
301 |
+
global db
|
302 |
+
tables_dir = get_db_path()
|
303 |
+
if not tables_dir:
|
304 |
+
return f"No database directory found.", None
|
305 |
+
|
306 |
+
db_path = os.path.join(tables_dir, db_name)
|
307 |
+
if not os.path.exists(db_path):
|
308 |
+
return f"Database {db_name} does not exist.", None
|
309 |
+
|
310 |
+
db = ArxivDatabase(db_path)
|
311 |
+
db.init_db()
|
312 |
+
|
313 |
+
if db.is_db_empty:
|
314 |
+
return (
|
315 |
+
f"Database loaded from {db_path}, but it is empty. Please populate it with data.",
|
316 |
+
None,
|
317 |
+
)
|
318 |
+
|
319 |
+
graph, _, _, _ = generate_concept_cooccurrence_graph(db_path)
|
320 |
+
return f"Database loaded from {db_path}", graph
|
321 |
+
|
322 |
+
|
323 |
+
css = """
|
324 |
+
#selected-query {
|
325 |
+
max-height: 100px;
|
326 |
+
overflow-y: auto;
|
327 |
+
white-space: pre-wrap;
|
328 |
+
word-break: break-word;
|
329 |
+
}
|
330 |
+
"""
|
331 |
+
|
332 |
+
|
333 |
+
def create_demo():
|
334 |
+
with gr.Blocks() as demo:
|
335 |
+
gr.Markdown("# ArXiv Database Query Interface")
|
336 |
+
|
337 |
+
with gr.Row():
|
338 |
+
db_dropdown = gr.Dropdown(
|
339 |
+
choices=get_available_databases(),
|
340 |
+
label="Select Database",
|
341 |
+
value=get_available_databases(),
|
342 |
+
)
|
343 |
+
# load_db_btn = gr.Button("Load Database", size="sm")
|
344 |
+
status = gr.Textbox(label="Status")
|
345 |
+
|
346 |
+
with gr.Row():
|
347 |
+
graph_output = gr.Plot(label="Concept Co-occurrence Graph")
|
348 |
+
|
349 |
+
with gr.Row():
|
350 |
+
tag_type_dropdown = gr.Dropdown(
|
351 |
+
choices=[
|
352 |
+
"All",
|
353 |
+
"model",
|
354 |
+
"task",
|
355 |
+
"dataset",
|
356 |
+
"field",
|
357 |
+
"modality",
|
358 |
+
"method",
|
359 |
+
"object",
|
360 |
+
"property",
|
361 |
+
"instrument",
|
362 |
+
],
|
363 |
+
label="Select Tag Type",
|
364 |
+
value="All",
|
365 |
+
)
|
366 |
+
highlight_input = gr.Textbox(label="Highlight Concepts (comma-separated)")
|
367 |
+
|
368 |
+
with gr.Row():
|
369 |
+
node_dropdown = gr.Dropdown(label="Select Node", choices=[])
|
370 |
+
depth_slider = gr.Slider(
|
371 |
+
minimum=0, maximum=5, step=1, value=0, label="Connection Depth"
|
372 |
+
)
|
373 |
+
update_graph_button = gr.Button("Update Graph")
|
374 |
+
|
375 |
+
with gr.Row():
|
376 |
+
wrap_checkbox = gr.Checkbox(label="Wrap long text", value=False)
|
377 |
+
canned_query_dropdown = gr.Dropdown(
|
378 |
+
choices=[q[0] for q in canned_queries], label="Select Query", scale=3
|
379 |
+
)
|
380 |
+
limit_input = gr.Number(
|
381 |
+
label="Limit", value=10000, step=1, minimum=1, scale=1
|
382 |
+
)
|
383 |
+
selected_query = gr.Textbox(
|
384 |
+
label="Selected Query",
|
385 |
+
interactive=False,
|
386 |
+
scale=2,
|
387 |
+
show_label=True,
|
388 |
+
show_copy_button=True,
|
389 |
+
elem_id="selected-query",
|
390 |
+
)
|
391 |
+
canned_query_submit = gr.Button("Submit Query", size="sm", scale=1)
|
392 |
+
|
393 |
+
with gr.Row():
|
394 |
+
sql_input = gr.Textbox(label="Custom SQL Query", lines=3, scale=4)
|
395 |
+
sql_submit = gr.Button("Submit Custom SQL", size="sm", scale=1)
|
396 |
+
|
397 |
+
# with gr.Row():
|
398 |
+
# nl_query_input = gr.Textbox(
|
399 |
+
# label="Natural Language Query", lines=2, scale=4
|
400 |
+
# )
|
401 |
+
# nl_query_submit = gr.Button("Convert to SQL", size="sm", scale=1)
|
402 |
+
|
403 |
+
output = gr.DataFrame(label="Results", wrap=True)
|
404 |
+
|
405 |
+
with gr.Row():
|
406 |
+
copy_button = gr.Button("Copy as Markdown")
|
407 |
+
download_button = gr.Button("Download as CSV")
|
408 |
+
|
409 |
+
def debounced_update_graph(
|
410 |
+
db_name, tag_type, highlight_concepts, selected_node, depth
|
411 |
+
):
|
412 |
+
global last_update_time
|
413 |
+
|
414 |
+
current_time = time.time()
|
415 |
+
if current_time - last_update_time < update_delay:
|
416 |
+
return None, [] # Return early if not enough time has passed
|
417 |
+
|
418 |
+
last_update_time = current_time
|
419 |
+
|
420 |
+
if not db_name:
|
421 |
+
return None, []
|
422 |
+
|
423 |
+
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
424 |
+
db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db_name)
|
425 |
+
fig, G, pos, update_traces = generate_concept_cooccurrence_graph(
|
426 |
+
db_path, tag_type
|
427 |
+
)
|
428 |
+
|
429 |
+
if isinstance(selected_node, list):
|
430 |
+
selected_node = selected_node[0] if selected_node else None
|
431 |
+
|
432 |
+
highlight_nodes = (
|
433 |
+
[node.strip() for node in highlight_concepts.split(",")]
|
434 |
+
if highlight_concepts
|
435 |
+
else []
|
436 |
+
)
|
437 |
+
primary_node = highlight_nodes[0] if highlight_nodes else None
|
438 |
+
|
439 |
+
if primary_node and primary_node in G.nodes():
|
440 |
+
# Apply node selection and depth filter
|
441 |
+
nodes_to_show = set([primary_node])
|
442 |
+
if depth > 0:
|
443 |
+
frontier = set([primary_node])
|
444 |
+
for _ in range(depth):
|
445 |
+
new_frontier = set()
|
446 |
+
for node in frontier:
|
447 |
+
new_frontier.update(G.neighbors(node))
|
448 |
+
nodes_to_show.update(new_frontier)
|
449 |
+
frontier = new_frontier
|
450 |
+
|
451 |
+
sub_G = G.subgraph(nodes_to_show)
|
452 |
+
|
453 |
+
# Update traces with the filtered graph
|
454 |
+
edge_x, edge_y = [], []
|
455 |
+
for edge in sub_G.edges():
|
456 |
+
x0, y0 = pos[edge[0]]
|
457 |
+
x1, y1 = pos[edge[1]]
|
458 |
+
edge_x.extend([x0, x1, None])
|
459 |
+
edge_y.extend([y0, y1, None])
|
460 |
+
|
461 |
+
fig.data[0].x = edge_x
|
462 |
+
fig.data[0].y = edge_y
|
463 |
+
|
464 |
+
node_x, node_y = [], []
|
465 |
+
for node in sub_G.nodes():
|
466 |
+
x, y = pos[node]
|
467 |
+
node_x.append(x)
|
468 |
+
node_y.append(y)
|
469 |
+
|
470 |
+
fig.data[1].x = node_x
|
471 |
+
fig.data[1].y = node_y
|
472 |
+
|
473 |
+
# Color nodes based on their distance from the primary node and highlight status
|
474 |
+
node_colors = []
|
475 |
+
node_sizes = []
|
476 |
+
for node in sub_G.nodes():
|
477 |
+
if node in highlight_nodes:
|
478 |
+
node_colors.append(
|
479 |
+
"rgba(255,0,0,1)"
|
480 |
+
) # Red for highlighted nodes
|
481 |
+
node_sizes.append(15)
|
482 |
+
else:
|
483 |
+
distance = nx.shortest_path_length(
|
484 |
+
sub_G, source=primary_node, target=node
|
485 |
+
)
|
486 |
+
intensity = max(0, 1 - (distance / (depth + 1)))
|
487 |
+
node_colors.append(f"rgba(0,0,255,{intensity})")
|
488 |
+
node_sizes.append(10)
|
489 |
+
|
490 |
+
fig.data[1].marker.color = node_colors
|
491 |
+
fig.data[1].marker.size = node_sizes
|
492 |
+
|
493 |
+
# Update node text
|
494 |
+
node_text = [
|
495 |
+
f"{node}<br># of connections: {len(list(G.neighbors(node)))}"
|
496 |
+
for node in sub_G.nodes()
|
497 |
+
]
|
498 |
+
fig.data[1].text = node_text
|
499 |
+
|
500 |
+
# Get connected nodes for dropdown
|
501 |
+
connected_nodes = sorted(list(G.neighbors(primary_node)))
|
502 |
+
else:
|
503 |
+
# If no primary node or it's not in the graph, show the full graph
|
504 |
+
connected_nodes = sorted(list(G.nodes()))
|
505 |
+
|
506 |
+
return fig, connected_nodes
|
507 |
+
|
508 |
+
def update_node_dropdown(highlight_concepts):
|
509 |
+
if not highlight_concepts or not db:
|
510 |
+
return gr.Dropdown(choices=[])
|
511 |
+
|
512 |
+
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
513 |
+
db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db.db_path)
|
514 |
+
_, G, _, _ = generate_concept_cooccurrence_graph(db_path)
|
515 |
+
|
516 |
+
primary_node = highlight_concepts.split(",")[0].strip()
|
517 |
+
if primary_node in G.nodes():
|
518 |
+
connected_nodes = sorted(list(G.neighbors(primary_node)))
|
519 |
+
return gr.Dropdown(choices=connected_nodes)
|
520 |
+
else:
|
521 |
+
return gr.Dropdown(choices=[])
|
522 |
+
|
523 |
+
def update_selected_query(query_description):
|
524 |
+
for desc, sql in canned_queries:
|
525 |
+
if desc == query_description:
|
526 |
+
return sql
|
527 |
+
return ""
|
528 |
+
|
529 |
+
def submit_canned_query(query_description, limit, wrap):
|
530 |
+
for desc, sql in canned_queries:
|
531 |
+
if desc == query_description:
|
532 |
+
return query_db(sql, True, limit, wrap)
|
533 |
+
return pd.DataFrame({"Error": ["Selected query not found."]})
|
534 |
+
|
535 |
+
def copy_as_markdown(df):
|
536 |
+
return df.to_markdown()
|
537 |
+
|
538 |
+
def download_as_csv(df):
|
539 |
+
if df is None or df.empty:
|
540 |
+
return None
|
541 |
+
|
542 |
+
with tempfile.NamedTemporaryFile(
|
543 |
+
mode="w", delete=False, suffix=".csv"
|
544 |
+
) as temp_file:
|
545 |
+
df.to_csv(temp_file.name, index=False)
|
546 |
+
temp_file_path = temp_file.name
|
547 |
+
|
548 |
+
return temp_file_path
|
549 |
+
|
550 |
+
# def nl_to_sql(nl_query):
|
551 |
+
# # Placeholder function for natural language to SQL conversion
|
552 |
+
# return f"SELECT * FROM papers WHERE abstract LIKE '%{nl_query}%' LIMIT 10;"
|
553 |
+
|
554 |
+
db_dropdown.change(
|
555 |
+
load_database_with_graphs,
|
556 |
+
inputs=[db_dropdown],
|
557 |
+
outputs=[status, graph_output],
|
558 |
+
)
|
559 |
+
|
560 |
+
# db_dropdown.change(
|
561 |
+
# debounced_update_graph,
|
562 |
+
# inputs=[db_dropdown, tag_type_dropdown, highlight_input, node_dropdown, depth_slider],
|
563 |
+
# outputs=[graph_output, node_dropdown],
|
564 |
+
# )
|
565 |
+
|
566 |
+
tag_type_dropdown.change(
|
567 |
+
debounced_update_graph,
|
568 |
+
inputs=[
|
569 |
+
db_dropdown,
|
570 |
+
tag_type_dropdown,
|
571 |
+
highlight_input,
|
572 |
+
node_dropdown,
|
573 |
+
depth_slider,
|
574 |
+
],
|
575 |
+
outputs=[graph_output, node_dropdown],
|
576 |
+
)
|
577 |
+
|
578 |
+
highlight_input.change(
|
579 |
+
update_node_dropdown,
|
580 |
+
inputs=[highlight_input],
|
581 |
+
outputs=[node_dropdown],
|
582 |
+
)
|
583 |
+
# node_dropdown.change(
|
584 |
+
# debounced_update_graph,
|
585 |
+
# inputs=[db_dropdown, tag_type_dropdown, highlight_input, node_dropdown, depth_slider],
|
586 |
+
# outputs=[graph_output, node_dropdown],
|
587 |
+
# )
|
588 |
+
|
589 |
+
# depth_slider.change(
|
590 |
+
# debounced_update_graph,
|
591 |
+
# inputs=[db_dropdown, tag_type_dropdown, highlight_input, node_dropdown, depth_slider],
|
592 |
+
# outputs=[graph_output, node_dropdown],
|
593 |
+
# )
|
594 |
+
update_graph_button.click(
|
595 |
+
debounced_update_graph,
|
596 |
+
inputs=[
|
597 |
+
db_dropdown,
|
598 |
+
tag_type_dropdown,
|
599 |
+
highlight_input,
|
600 |
+
node_dropdown,
|
601 |
+
depth_slider,
|
602 |
+
],
|
603 |
+
outputs=[graph_output, node_dropdown],
|
604 |
+
)
|
605 |
+
canned_query_dropdown.change(
|
606 |
+
update_selected_query,
|
607 |
+
inputs=[canned_query_dropdown],
|
608 |
+
outputs=[selected_query],
|
609 |
+
)
|
610 |
+
canned_query_submit.click(
|
611 |
+
submit_canned_query,
|
612 |
+
inputs=[canned_query_dropdown, limit_input, wrap_checkbox],
|
613 |
+
outputs=output,
|
614 |
+
)
|
615 |
+
sql_submit.click(
|
616 |
+
query_db,
|
617 |
+
inputs=[sql_input, gr.Checkbox(value=True), limit_input, wrap_checkbox],
|
618 |
+
outputs=output,
|
619 |
+
)
|
620 |
+
copy_button.click(
|
621 |
+
copy_as_markdown,
|
622 |
+
inputs=[output],
|
623 |
+
outputs=[gr.Textbox(label="Markdown Output", show_copy_button=True)],
|
624 |
+
)
|
625 |
+
download_button.click(
|
626 |
+
download_as_csv, inputs=[output], outputs=[gr.File(label="CSV Output")]
|
627 |
+
)
|
628 |
+
# nl_query_submit.click(nl_to_sql, inputs=[nl_query_input], outputs=[sql_input])
|
629 |
+
|
630 |
+
return demo
|
631 |
+
|
632 |
+
|
633 |
+
demo = create_demo()
|
634 |
+
|
635 |
+
def close_db():
|
636 |
+
global db
|
637 |
+
if db is not None:
|
638 |
+
db.close()
|
639 |
+
db = None
|
640 |
+
|
641 |
+
|
642 |
+
def launch():
    print("Launching Gradio app...", flush=True)
    shared_demo = demo.launch(share=True, prevent_thread_lock=True)

    # In Gradio 4.x, launch() returns (app, local_url, share_url), so the
    # URLs sit at indices 1 and 2; fall back gracefully for other shapes.
    if isinstance(shared_demo, tuple):
        if len(shared_demo) >= 3:
            local_url, share_url = shared_demo[1], shared_demo[2]
        else:
            local_url, share_url = shared_demo[0], "N/A"
    else:
        local_url = getattr(shared_demo, "local_url", "N/A")
        share_url = getattr(shared_demo, "share_url", "N/A")

    print(f"Local URL: {local_url}", flush=True)
    print(f"Shareable link: {share_url}", flush=True)

    print("Gradio app launched.", flush=True)

    # Keep the script running
    demo.block_thread()


if __name__ == "__main__":
    launch()

# Mount the Gradio app
# app = mount_gradio_app(app, demo, path="/")

# print(f"Shareable link: {demo.share_url}")

# @app.exception_handler(Exception)
# async def exception_handler(request: Request, exc: Exception):
#     print(f"An error occurred: {str(exc)}")
#     return {"error": str(exc)}

# @contextmanager
# def get_db_connection():
#     global db
#     conn = db.conn.cursor().connection
#     try:
#         yield conn
#     finally:
#         conn.close()

# @app.on_event("startup")
# async def startup_event():
#     global db
#     ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
#     db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, get_available_databases()[0])  # Use the first available database
#     db = ArxivDatabase(db_path)
#     db.init_db()

# @app.on_event("shutdown")
# async def shutdown_event():
#     if db is not None:
#         db.close()


# if __name__ == "__main__":
#     uvicorn.run(app, host="0.0.0.0", port=7860)
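All of the graph controls above are routed through debounced_update_graph, which is defined earlier in this file and not shown in this hunk. As a minimal sketch of one way such a wrapper can be built (the 0.5 s window and the wrapped update_graph name are assumptions, not the file's actual values):

import time

def debounce(wait_seconds):
    """Drop calls arriving within wait_seconds of the last accepted call,
    re-serving the previous result instead. Minimal sketch only; the real
    debounced_update_graph in this file may work differently."""
    def decorator(fn):
        state = {"last_call": 0.0, "last_result": None}
        def wrapped(*args, **kwargs):
            now = time.monotonic()
            if now - state["last_call"] < wait_seconds:
                return state["last_result"]
            state["last_call"] = now
            state["last_result"] = fn(*args, **kwargs)
            return state["last_result"]
        return wrapped
    return decorator

# Hypothetical usage: debounced_update_graph = debounce(0.5)(update_graph)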
scripts/run_db_interface_basic.py
ADDED
@@ -0,0 +1,361 @@
import gradio as gr
import os
import json
import networkx as nx
import pandas as pd
import plotly.graph_objects as go
import re
import sys
import sqlite3
import time
import uvicorn

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from gradio.routes import mount_gradio_app
from plotly.subplots import make_subplots
from tabulate import tabulate
from typing import Optional


ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)

from scripts.create_db import ArxivDatabase
from config import (
    DEFAULT_TABLES_DIR,
    DEFAULT_INTERFACE_MODEL_ID,
    COOCCURRENCE_QUERY,
    canned_queries,
)

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

db: Optional[ArxivDatabase] = None


def truncate_or_wrap_text(text, max_length=50, wrap=False):
    """Truncate text to a maximum length, adding ellipsis if truncated, or wrap if specified."""
    if wrap:
        return "\n".join(
            text[i : i + max_length] for i in range(0, len(text), max_length)
        )
    return text[:max_length] + "..." if len(text) > max_length else text


def format_url(url):
    """Format URL to be more compact in the table."""
    return url.split("/")[-1] if url.startswith("http") else url
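# A quick, self-contained check of the two helpers above; illustrative only
# (the sample strings and arXiv ID are made up), never called by the app.
def _format_helpers_demo():
    assert truncate_or_wrap_text("x" * 120) == "x" * 50 + "..."
    assert truncate_or_wrap_text("x" * 120, wrap=True).count("\n") == 2
    assert format_url("https://arxiv.org/abs/2401.00001") == "2401.00001"
    assert format_url("not a url") == "not a url"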
def get_available_databases():
    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
    return [f for f in os.listdir(tables_dir) if f.endswith(".db")]


def query_db(query, is_sql, limit=None, wrap=False):
    global db
    if db is None:
        return pd.DataFrame({"Error": ["Please load a database first."]})

    try:
        cursor = db.conn.cursor()

        query = " ".join(query.strip().split("\n")).rstrip(";")

        if limit is not None:
            if "LIMIT" in query.upper():
                # Replace existing LIMIT clause
                query = re.sub(
                    r"LIMIT\s+\d+", f"LIMIT {limit}", query, flags=re.IGNORECASE
                )
            else:
                query += f" LIMIT {limit}"

        cursor.execute(query)

        column_names = [description[0] for description in cursor.description]

        results = cursor.fetchall()

        df = pd.DataFrame(results, columns=column_names)

        for column in df.columns:
            if df[column].dtype == "object":
                df[column] = df[column].apply(
                    lambda x: (
                        format_url(x)
                        if column == "url"
                        else truncate_or_wrap_text(x, wrap=wrap)
                    )
                )

        return df

    except sqlite3.Error as e:
        return pd.DataFrame({"Error": [f"Database error: {str(e)}"]})
    except Exception as e:
        return pd.DataFrame({"Error": [f"An unexpected error occurred: {str(e)}"]})
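# Illustrative sketch of query_db's LIMIT handling above; the table name is
# an invented placeholder. Note the regex rewrite assumes LIMIT only appears
# as the clause keyword, not inside a string literal.
def _limit_rewrite_demo():
    q = "SELECT title FROM papers LIMIT 5"
    rewritten = re.sub(r"LIMIT\s+\d+", "LIMIT 100", q, flags=re.IGNORECASE)
    assert rewritten == "SELECT title FROM papers LIMIT 100"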
def generate_concept_cooccurrence_graph(db_path):
    conn = sqlite3.connect(db_path)
    df = pd.read_sql_query(COOCCURRENCE_QUERY, conn)
    conn.close()

    G = nx.from_pandas_edgelist(df, "concept1", "concept2", "co_occurrences")
    pos = nx.spring_layout(G)

    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    node_x = [pos[node][0] for node in G.nodes()]
    node_y = [pos[node][1] for node in G.nodes()]

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers",
        hoverinfo="text",
        marker=dict(
            showscale=True,
            colorscale="YlGnBu",
            size=10,
            colorbar=dict(
                thickness=15,
                title="Node Connections",
                xanchor="left",
                titleside="right",
            ),
        ),
    )

    node_adjacencies = []
    node_text = []
    for node, adjacencies in G.adjacency():
        node_adjacencies.append(len(adjacencies))
        node_text.append(f"{node}<br># of connections: {len(adjacencies)}")

    node_trace.marker.color = node_adjacencies
    node_trace.text = node_text

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            title="Concept Co-occurrence Network",
            titlefont_size=16,
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[
                dict(
                    text="",
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=0.005,
                    y=-0.002,
                )
            ],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )
    return fig
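# Illustrative sketch only: COOCCURRENCE_QUERY is defined in config.py and
# not shown in this hunk. The function above only requires that it return
# concept1 / concept2 / co_occurrences columns; a self-join of roughly this
# shape would satisfy that contract (table and column names below are
# assumptions, not the project's actual schema).
_EXAMPLE_COOCCURRENCE_QUERY = """
SELECT a.concept AS concept1,
       b.concept AS concept2,
       COUNT(*)  AS co_occurrences
FROM paper_concepts AS a
JOIN paper_concepts AS b
  ON a.paper_id = b.paper_id AND a.concept < b.concept
GROUP BY a.concept, b.concept
"""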
# def load_database_with_graphs(db_name):
#     global db
#     ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
#     db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db_name)
#     if not os.path.exists(db_path):
#         return f"Database {db_name} does not exist.", None
#     db = ArxivDatabase(db_path)
#     db.init_db()
#     if db.is_db_empty:
#         return (
#             f"Database loaded from {db_path}, but it is empty. Please populate it with data.",
#             None,
#         )

#     # Generate graph
#     graph = generate_concept_cooccurrence_graph(db_path)

#     return f"Database loaded from {db_path}", graph


def load_database_with_graphs(db_name):
    global db
    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db_name)
    if not os.path.exists(db_path):
        return f"Database {db_name} does not exist.", None

    # Reuse the open connection when the same database is requested again.
    if db is None or db.db_path != db_path:
        db = ArxivDatabase(db_path)
        db.init_db()

    if db.is_db_empty:
        return (
            f"Database loaded from {db_path}, but it is empty. Please populate it with data.",
            None,
        )

    graph = generate_concept_cooccurrence_graph(db_path)
    return f"Database loaded from {db_path}", graph


css = """
#selected-query {
    max-height: 100px;
    overflow-y: auto;
    white-space: pre-wrap;
    word-break: break-word;
}
"""


def create_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# ArXiv Database Query Interface")

        with gr.Row():
            db_dropdown = gr.Dropdown(
                choices=get_available_databases(), label="Select Database"
            )
            load_db_btn = gr.Button("Load Database", size="sm")
            status = gr.Textbox(label="Status")

        with gr.Row():
            graph_output = gr.Plot(label="Concept Co-occurrence Graph")

        with gr.Row():
            wrap_checkbox = gr.Checkbox(label="Wrap long text", value=False)
            canned_query_dropdown = gr.Dropdown(
                choices=[q[0] for q in canned_queries], label="Select Query", scale=3
            )
            limit_input = gr.Number(
                label="Limit", value=10000, step=1, minimum=1, scale=1
            )
            selected_query = gr.Textbox(
                label="Selected Query",
                interactive=False,
                scale=2,
                show_label=True,
                show_copy_button=True,
                elem_id="selected-query",
            )
            canned_query_submit = gr.Button("Submit Query", size="sm", scale=1)

        with gr.Row():
            sql_input = gr.Textbox(label="Custom SQL Query", lines=3, scale=4)
            sql_submit = gr.Button("Submit Custom SQL", size="sm", scale=1)

        output = gr.DataFrame(label="Results", wrap=True)

        def update_selected_query(query_description):
            for desc, sql in canned_queries:
                if desc == query_description:
                    return sql
            return ""

        def submit_canned_query(query_description, limit, wrap):
            for desc, sql in canned_queries:
                if desc == query_description:
                    return query_db(sql, True, limit, wrap)
            return pd.DataFrame({"Error": ["Selected query not found."]})

        load_db_btn.click(
            load_database_with_graphs,
            inputs=[db_dropdown],
            outputs=[status, graph_output],
        )
        canned_query_dropdown.change(
            update_selected_query,
            inputs=[canned_query_dropdown],
            outputs=[selected_query],
        )
        canned_query_submit.click(
            submit_canned_query,
            inputs=[canned_query_dropdown, limit_input, wrap_checkbox],
            outputs=output,
        )
        sql_submit.click(
            query_db,
            inputs=[sql_input, gr.Checkbox(value=True), limit_input, wrap_checkbox],
            outputs=output,
        )

    return demo


demo = create_demo()


def close_db():
    global db
    if db is not None:
        db.close()
        db = None


# def launch():
#     print("Launching Gradio app...", flush=True)
#     demo.launch(share=True)
#     print(
#         "Gradio app launched. If you don't see a URL above, there might be network restrictions.",
#         flush=True,
#     )

#     close_db()

# if __name__ == "__main__":
#     launch()

# Mount the Gradio app
app = mount_gradio_app(app, demo, path="/")


@app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception):
    print(f"An error occurred: {str(exc)}")
    return {"error": str(exc)}


@app.on_event("startup")
async def startup_event():
    # You can initialize the database here if needed
    pass


@app.on_event("shutdown")
async def shutdown_event():
    close_db()


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
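A side note on the FastAPI hooks in this file: @app.on_event("startup") and @app.on_event("shutdown") still work on 4.x-era stacks but are deprecated in newer FastAPI releases in favor of a lifespan handler. A sketch of the equivalent wiring, assuming FastAPI >= 0.93:

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # startup work (e.g. opening the database) goes before the yield
    yield
    # shutdown work runs once the server stops serving
    close_db()

# app = FastAPI(lifespan=lifespan)  # then mount_gradio_app(app, demo, path="/")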
scripts/run_db_interface_js.py
ADDED
File without changes