Abhipsha Das committed
Commit 8fbb714 · unverified · 1 Parent(s): 3e39d34

initial spaces deploy
README.md CHANGED
@@ -1,14 +1,19 @@
 ---
-title: Surveyor 0
-emoji: 👀
-colorFrom: yellow
+title: Surveyor
+emoji: 🔍
+colorFrom: blue
 colorTo: green
 sdk: gradio
-sdk_version: 5.9.0
+sdk_version: 4.40.0
 app_file: app.py
 pinned: false
-license: openrail
-short_description: Interface for exploring scientific concepts with KGs
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Surveyor
+
+An interactive interface for querying and visualizing scientific paper databases with concept co-occurrence graphs.
+## Features
+- Interactive concept co-occurrence graphs
+- SQL query interface with pre-built queries
+- Support for multiple scientific domains
+- Graph filtering and highlighting
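
The feature list above can be exercised without the UI. A minimal sketch of running one of the pre-built queries directly, assuming a populated SQLite database ships under `data/databases/` (the file name `arxiv.db` is only the `scripts/create_db.py` default, not necessarily what is deployed):

```python
# Run a canned query from config.py against the bundled database.
import sqlite3

import pandas as pd

from config import canned_queries

conn = sqlite3.connect("data/databases/arxiv.db")  # assumed path/name
sql = dict(canned_queries)["Count papers by category"]
df = pd.read_sql_query(sql, conn)
print(df.head())
conn.close()
```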
app.py ADDED
@@ -0,0 +1,14 @@
+import os
+import sys
+
+# Add the project root directory to Python path
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+if ROOT_DIR not in sys.path:
+    sys.path.insert(0, ROOT_DIR)
+
+from scripts.run_db_interface import create_demo
+
+demo = create_demo()
+
+if __name__ == "__main__":
+    demo.launch()
config.py ADDED
@@ -0,0 +1,445 @@
+DEFAULT_MODEL_ID = "Meta-Llama-3-70B-Instruct"
+DEFAULT_INTERFACE_MODEL_ID = "NumbersStation/nsql-llama-2-7B"
+DEFAULT_KIND = "json"
+DEFAULT_TEMPERATURE = 0.6
+DEFAULT_TOP_P = 0.95
+DEFAULT_FEW_SHOT_NUM = 3
+DEFAULT_FEW_SHOT_SELECTION = "random"
+DEFAULT_SAVE_INTERVAL = 3
+DEFAULT_RES_DIR = "data/results"
+DEFAULT_LOG_DIR = "logs"
+DEFAULT_TABLES_DIR = "data/databases"
+
+COOCCURRENCE_QUERY = """
+WITH concept_pairs AS (
+    SELECT p1.concept AS concept1, p2.concept AS concept2, p1.paper_id, p1.tag_type
+    FROM predictions p1
+    JOIN predictions p2 ON p1.paper_id = p2.paper_id AND p1.concept < p2.concept
+    WHERE p1.tag_type = p2.tag_type
+)
+SELECT concept1, concept2, tag_type, COUNT(DISTINCT paper_id) AS co_occurrences
+FROM concept_pairs
+GROUP BY concept1, concept2, tag_type
+HAVING co_occurrences > 5
+ORDER BY co_occurrences DESC;
+"""
+
+canned_queries = [
+    (
+        "Modalities in Physics and Astronomy papers",
+        """
+        SELECT DISTINCT LOWER(concept) AS concept
+        FROM predictions
+        JOIN (
+            SELECT paper_id, url
+            FROM papers
+            WHERE primary_category LIKE '%physics.space-ph%'
+               OR primary_category LIKE '%astro-ph.%'
+        ) AS paper_ids
+        ON predictions.paper_id = paper_ids.paper_id
+        WHERE predictions.tag_type = 'modality'
+        """,
+    ),
+    (
+        "Datasets in Evolutionary Biology that use PDEs",
+        """
+        WITH pde_predictions AS (
+            SELECT paper_id, concept AS pde_concept, tag_type AS pde_tag_type
+            FROM predictions
+            WHERE tag_type IN ('method', 'model')
+              AND (
+                  LOWER(concept) LIKE '%pde%'
+                  OR LOWER(concept) LIKE '%partial differential equation%'
+              )
+        )
+        SELECT DISTINCT
+            papers.paper_id,
+            papers.url,
+            LOWER(p_dataset.concept) AS dataset,
+            pde_predictions.pde_concept AS pde_related_concept,
+            pde_predictions.pde_tag_type AS pde_related_type
+        FROM papers
+        JOIN pde_predictions ON papers.paper_id = pde_predictions.paper_id
+        LEFT JOIN predictions p_dataset ON papers.paper_id = p_dataset.paper_id
+        WHERE papers.primary_category LIKE '%q-bio.PE%'
+          AND (p_dataset.tag_type = 'dataset' OR p_dataset.tag_type IS NULL)
+        ORDER BY papers.paper_id, dataset, pde_related_concept;
+        """,
+    ),
+    (
+        "Trends in objects of study in Cosmology since 2019",
+        """
+
+        SELECT
+            substr(papers.updated_on, 2, 4) as year,
+            predictions.concept as object,
+            COUNT(DISTINCT papers.paper_id) as paper_count
+        FROM
+            papers
+        JOIN
+            predictions ON papers.paper_id = predictions.paper_id
+        WHERE
+            predictions.tag_type = 'object'
+            AND CAST(SUBSTR(papers.updated_on, 2, 4) AS INTEGER) >= 2019
+        GROUP BY
+            year, object
+        ORDER BY
+            year DESC, paper_count DESC;
+        """,
+    ),
+    (
+        "New datasets in fluid dynamics since 2020",
+        """
+        WITH ranked_datasets AS (
+            SELECT
+                p.paper_id,
+                p.url,
+                pred.concept AS dataset,
+                p.updated_on,
+                ROW_NUMBER() OVER (PARTITION BY pred.concept ORDER BY p.updated_on ASC) AS rn
+            FROM
+                papers p
+            JOIN
+                predictions pred ON p.paper_id = pred.paper_id
+            WHERE
+                pred.tag_type = 'dataset'
+                AND p.primary_category LIKE '%physics.flu-dyn%'
+                AND CAST(SUBSTR(p.updated_on, 2, 4) AS INTEGER) >= 2020
+        )
+        SELECT
+            paper_id,
+            url,
+            dataset,
+            updated_on
+        FROM
+            ranked_datasets
+        WHERE
+            rn = 1
+        ORDER BY
+            updated_on ASC
+        """,
+    ),
+    (
+        "Evolutionary biology datasets that use spatiotemporal dynamics",
+        """
+        WITH evo_bio_papers AS (
+            SELECT paper_id
+            FROM papers
+            WHERE primary_category LIKE '%q-bio.PE%'
+        ),
+        spatiotemporal_keywords AS (
+            SELECT 'spatio-temporal' AS keyword
+            UNION SELECT 'spatiotemporal'
+            UNION SELECT 'spatio-temporal'
+            UNION SELECT 'spatial and temporal'
+            UNION SELECT 'space-time'
+            UNION SELECT 'geographic distribution'
+            UNION SELECT 'phylogeograph'
+            UNION SELECT 'biogeograph'
+            UNION SELECT 'dispersal'
+            UNION SELECT 'migration'
+            UNION SELECT 'range expansion'
+            UNION SELECT 'population dynamics'
+        )
+        SELECT DISTINCT
+            p.paper_id,
+            p.updated_on,
+            p.abstract,
+            d.concept AS dataset,
+            GROUP_CONCAT(DISTINCT stk.keyword) AS spatiotemporal_keywords_found
+        FROM
+            evo_bio_papers ebp
+        JOIN
+            papers p ON ebp.paper_id = p.paper_id
+        JOIN
+            predictions d ON p.paper_id = d.paper_id
+        JOIN
+            predictions st ON p.paper_id = st.paper_id
+        JOIN
+            spatiotemporal_keywords stk
+        WHERE
+            d.tag_type = 'dataset'
+            AND st.tag_type = 'modality'
+            AND LOWER(st.concept) LIKE '%' || stk.keyword || '%'
+        GROUP BY
+            p.paper_id, p.updated_on, p.abstract, d.concept
+        ORDER BY
+            p.updated_on DESC
+        """,
+    ),
+    (
+        "What percentage of papers use only galaxy images, only spectra, both, or neither?",
+        """
+        WITH paper_modalities AS (
+            SELECT
+                p.paper_id,
+                MAX(CASE WHEN LOWER(pred.concept) LIKE '%imag%' THEN 1 ELSE 0 END) AS uses_galaxy_images,
+                MAX(CASE WHEN LOWER(pred.concept) LIKE '%spectr%' THEN 1 ELSE 0 END) AS uses_spectra
+            FROM
+                papers p
+            LEFT JOIN
+                predictions pred ON p.paper_id = pred.paper_id
+            WHERE
+                p.primary_category LIKE '%astro-ph%'
+                AND pred.tag_type = 'modality'
+            GROUP BY
+                p.paper_id
+        ),
+        categorized_papers AS (
+            SELECT
+                CASE
+                    WHEN uses_galaxy_images = 1 AND uses_spectra = 1 THEN 'Both'
+                    WHEN uses_galaxy_images = 1 THEN 'Only Galaxy Images'
+                    WHEN uses_spectra = 1 THEN 'Only Spectra'
+                    ELSE 'Neither'
+                END AS category,
+                COUNT(*) AS paper_count
+            FROM
+                paper_modalities
+            GROUP BY
+                CASE
+                    WHEN uses_galaxy_images = 1 AND uses_spectra = 1 THEN 'Both'
+                    WHEN uses_galaxy_images = 1 THEN 'Only Galaxy Images'
+                    WHEN uses_spectra = 1 THEN 'Only Spectra'
+                    ELSE 'Neither'
+                END
+        )
+        SELECT
+            category,
+            paper_count,
+            ROUND(CAST(paper_count AS FLOAT) / (SELECT SUM(paper_count) FROM categorized_papers) * 100, 2) AS percentage
+        FROM
+            categorized_papers
+        ORDER BY
+            paper_count DESC
+        """,
+    ),
+    (
+        "What are all the next highest data modalities after images and spectra?",
+        """
+        SELECT
+            LOWER(concept) AS modality,
+            COUNT(DISTINCT paper_id) AS usage_count
+        FROM
+            predictions
+        WHERE
+            tag_type = 'modality'
+            AND LOWER(concept) NOT LIKE '%imag%'
+            AND LOWER(concept) NOT LIKE '%spectr%'
+        GROUP BY
+            LOWER(concept)
+        ORDER BY
+            usage_count DESC
+        """,
+    ),
+    (
+        "If we include the next biggest data modality, how much does coverage change?",
+        """
+        WITH modality_counts AS (
+            SELECT
+                LOWER(concept) AS modality,
+                COUNT(DISTINCT paper_id) AS usage_count
+            FROM
+                predictions
+            WHERE
+                tag_type = 'modality'
+                AND LOWER(concept) NOT LIKE '%imag%'
+                AND LOWER(concept) NOT LIKE '%spectr%'
+            GROUP BY
+                LOWER(concept)
+            ORDER BY
+                usage_count DESC
+            LIMIT 1
+        ),
+        paper_modalities AS (
+            SELECT
+                p.paper_id,
+                MAX(CASE WHEN LOWER(pred.concept) LIKE '%imag%' THEN 1 ELSE 0 END) AS uses_galaxy_images,
+                MAX(CASE WHEN LOWER(pred.concept) LIKE '%spectr%' THEN 1 ELSE 0 END) AS uses_spectra,
+                MAX(CASE WHEN LOWER(pred.concept) LIKE (SELECT '%' || modality || '%' FROM modality_counts) THEN 1 ELSE 0 END) AS uses_third_modality
+            FROM
+                papers p
+            LEFT JOIN
+                predictions pred ON p.paper_id = pred.paper_id
+            WHERE
+                p.primary_category LIKE '%astro-ph%'
+                AND pred.tag_type = 'modality'
+            GROUP BY
+                p.paper_id
+        ),
+        coverage_before AS (
+            SELECT
+                SUM(CASE WHEN uses_galaxy_images = 1 OR uses_spectra = 1 THEN 1 ELSE 0 END) AS covered_papers,
+                COUNT(*) AS total_papers
+            FROM
+                paper_modalities
+        ),
+        coverage_after AS (
+            SELECT
+                SUM(CASE WHEN uses_galaxy_images = 1 OR uses_spectra = 1 OR uses_third_modality = 1 THEN 1 ELSE 0 END) AS covered_papers,
+                COUNT(*) AS total_papers
+            FROM
+                paper_modalities
+        )
+        SELECT
+            (SELECT modality FROM modality_counts) AS third_modality,
+            ROUND(CAST(covered_papers AS FLOAT) / total_papers * 100, 2) AS coverage_before_percent,
+            ROUND(CAST((SELECT covered_papers FROM coverage_after) AS FLOAT) / total_papers * 100, 2) AS coverage_after_percent,
+            ROUND(CAST((SELECT covered_papers FROM coverage_after) AS FLOAT) / total_papers * 100, 2) -
+            ROUND(CAST(covered_papers AS FLOAT) / total_papers * 100, 2) AS coverage_increase_percent
+        FROM
+            coverage_before
+        """,
+    ),
+    (
+        "Coverage if we select the next 5 highest modalities?",
+        """
+        WITH ranked_modalities AS (
+            SELECT
+                LOWER(concept) AS modality,
+                COUNT(DISTINCT paper_id) AS usage_count,
+                ROW_NUMBER() OVER (ORDER BY COUNT(DISTINCT paper_id) DESC) AS rank
+            FROM
+                predictions
+            WHERE
+                tag_type = 'modality'
+                AND LOWER(concept) NOT LIKE '%imag%'
+                AND LOWER(concept) NOT LIKE '%spectr%'
+            GROUP BY
+                LOWER(concept)
+        ),
+        paper_modalities AS (
+            SELECT
+                p.paper_id,
+                MAX(CASE WHEN LOWER(pred.concept) LIKE '%imag%' THEN 1 ELSE 0 END) AS uses_images,
+                MAX(CASE WHEN LOWER(pred.concept) LIKE '%spectr%' THEN 1 ELSE 0 END) AS uses_spectra,
+                MAX(CASE WHEN rm.rank = 1 THEN 1 ELSE 0 END) AS uses_modality_1,
+                MAX(CASE WHEN rm.rank = 2 THEN 1 ELSE 0 END) AS uses_modality_2,
+                MAX(CASE WHEN rm.rank = 3 THEN 1 ELSE 0 END) AS uses_modality_3,
+                MAX(CASE WHEN rm.rank = 4 THEN 1 ELSE 0 END) AS uses_modality_4,
+                MAX(CASE WHEN rm.rank = 5 THEN 1 ELSE 0 END) AS uses_modality_5
+            FROM
+                papers p
+            LEFT JOIN
+                predictions pred ON p.paper_id = pred.paper_id
+            LEFT JOIN
+                ranked_modalities rm ON LOWER(pred.concept) = rm.modality
+            WHERE
+                p.primary_category LIKE '%astro-ph%'
+                AND pred.tag_type = 'modality'
+            GROUP BY
+                p.paper_id
+        ),
+        cumulative_coverage AS (
+            SELECT
+                'Images and Spectra' AS modalities,
+                0 AS added_modality_rank,
+                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 THEN 1 ELSE 0 END) AS covered_papers,
+                COUNT(*) AS total_papers
+            FROM
+                paper_modalities
+
+            UNION ALL
+
+            SELECT
+                'Images, Spectra, and Modality 1' AS modalities,
+                1 AS added_modality_rank,
+                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 THEN 1 ELSE 0 END) AS covered_papers,
+                COUNT(*) AS total_papers
+            FROM
+                paper_modalities
+
+            UNION ALL
+
+            SELECT
+                'Images, Spectra, Modality 1, and 2' AS modalities,
+                2 AS added_modality_rank,
+                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 THEN 1 ELSE 0 END) AS covered_papers,
+                COUNT(*) AS total_papers
+            FROM
+                paper_modalities
+
+            UNION ALL
+
+            SELECT
+                'Images, Spectra, Modality 1, 2, and 3' AS modalities,
+                3 AS added_modality_rank,
+                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 OR uses_modality_3 = 1 THEN 1 ELSE 0 END) AS covered_papers,
+                COUNT(*) AS total_papers
+            FROM
+                paper_modalities
+
+            UNION ALL
+
+            SELECT
+                'Images, Spectra, Modality 1, 2, 3, and 4' AS modalities,
+                4 AS added_modality_rank,
+                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 OR uses_modality_3 = 1 OR uses_modality_4 = 1 THEN 1 ELSE 0 END) AS covered_papers,
+                COUNT(*) AS total_papers
+            FROM
+                paper_modalities
+
+            UNION ALL
+
+            SELECT
+                'Images, Spectra, Modality 1, 2, 3, 4, and 5' AS modalities,
+                5 AS added_modality_rank,
+                SUM(CASE WHEN uses_images = 1 OR uses_spectra = 1 OR uses_modality_1 = 1 OR uses_modality_2 = 1 OR uses_modality_3 = 1 OR uses_modality_4 = 1 OR uses_modality_5 = 1 THEN 1 ELSE 0 END) AS covered_papers,
+                COUNT(*) AS total_papers
+            FROM
+                paper_modalities
+        )
+        SELECT
+            cc.modalities,
+            COALESCE(rm.modality, 'N/A') AS added_modality,
+            rm.usage_count AS added_modality_usage,
+            ROUND(CAST(cc.covered_papers AS FLOAT) / cc.total_papers * 100, 2) AS coverage_percent,
+            ROUND(CAST(cc.covered_papers AS FLOAT) / cc.total_papers * 100, 2) -
+            LAG(ROUND(CAST(cc.covered_papers AS FLOAT) / cc.total_papers * 100, 2), 1, 0) OVER (ORDER BY cc.added_modality_rank) AS coverage_increase_percent
+        FROM
+            cumulative_coverage cc
+        LEFT JOIN
+            ranked_modalities rm ON cc.added_modality_rank = rm.rank
+        ORDER BY
+            cc.added_modality_rank
+        """,
+    ),
+    (
+        "List all papers",
+        "SELECT paper_id, abstract AS abstract_preview, authors, primary_category FROM papers",
+    ),
+    (
+        "Count papers by category",
+        "SELECT primary_category, COUNT(*) as paper_count FROM papers GROUP BY primary_category ORDER BY paper_count DESC",
+    ),
+    (
+        "Top authors with most papers",
+        """
+        WITH author_papers AS (
+            SELECT json_each.value AS author
+            FROM papers, json_each(papers.authors)
+        )
+        SELECT author, COUNT(*) as paper_count
+        FROM author_papers
+        GROUP BY author
+        ORDER BY paper_count DESC
+        """,
+    ),
+    (
+        "Papers with 'quantum' in abstract",
+        "SELECT paper_id, abstract AS abstract_preview FROM papers WHERE abstract LIKE '%quantum%'",
+    ),
+    (
+        "Most common concepts",
+        "SELECT concept, COUNT(*) as concept_count FROM predictions GROUP BY concept ORDER BY concept_count DESC",
+    ),
+    (
+        "Papers with multiple authors",
+        """
+        SELECT paper_id, json_array_length(authors) as author_count, authors
+        FROM papers
+        WHERE json_array_length(authors) > 1
+        ORDER BY author_count DESC
+        """,
+    ),
+]
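
The interface scripts below consume `COOCCURRENCE_QUERY` by turning each (concept1, concept2) row into a weighted edge. A minimal sketch of that step, assuming `db_path` points at a database built by `scripts/create_db.py`:

```python
# Each row of COOCCURRENCE_QUERY becomes an edge; only pairs sharing more
# than 5 papers survive the query's HAVING threshold.
import sqlite3

import networkx as nx
import pandas as pd

from config import COOCCURRENCE_QUERY


def build_cooccurrence_graph(db_path: str) -> nx.Graph:
    with sqlite3.connect(db_path) as conn:
        df = pd.read_sql_query(COOCCURRENCE_QUERY, conn)
    return nx.from_pandas_edgelist(df, "concept1", "concept2", "co_occurrences")
```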
requirements.txt ADDED
@@ -0,0 +1,8 @@
+gradio==4.40.0
+networkx==3.3
+pandas==2.2.2
+plotly==5.23.0
+tabulate==0.9.0
+fastapi==0.104.1
+pydantic==2.5.3
+uvicorn==0.27.1
scripts/__init__.py ADDED
@@ -0,0 +1,6 @@
+import os
+import sys
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if ROOT_DIR not in sys.path:
+    sys.path.insert(0, ROOT_DIR)
scripts/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (566 Bytes)
scripts/__pycache__/create_db.cpython-311.pyc ADDED
Binary file (14.6 kB)
scripts/__pycache__/run_db_interface.cpython-311.pyc ADDED
Binary file (29 kB)
scripts/__pycache__/run_db_interface_improved.cpython-311.pyc ADDED
Binary file (29.2 kB)
scripts/create_db.py ADDED
@@ -0,0 +1,246 @@
+import click
+import json
+import os
+import sqlite3
+import sys
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+from config import DEFAULT_TABLES_DIR, DEFAULT_MODEL_ID, DEFAULT_INTERFACE_MODEL_ID
+from src.processing.generate import get_sentences, generate_prediction
+from src.utils.utils import load_model_and_tokenizer
+
+
+class ArxivDatabase:
+    def __init__(self, db_path, model_id=None):
+        self.conn = None
+        self.cursor = None
+        self.db_path = db_path
+        self.model_id = model_id if model_id else DEFAULT_INTERFACE_MODEL_ID
+        self.model = None
+        self.tokenizer = None
+        self.is_db_empty = True
+        self.paper_table = """CREATE TABLE IF NOT EXISTS papers
+            (paper_id TEXT PRIMARY KEY, abstract TEXT, authors TEXT,
+            primary_category TEXT, url TEXT, updated_on TEXT, sentence_count INTEGER)"""
+        self.pred_table = """CREATE TABLE IF NOT EXISTS predictions
+            (id INTEGER PRIMARY KEY AUTOINCREMENT, paper_id TEXT, sentence_index INTEGER,
+            tag_type TEXT, concept TEXT,
+            FOREIGN KEY (paper_id) REFERENCES papers(paper_id))"""
+
+    # def init_db(self):
+    #     self.cursor.execute(self.paper_table)
+    #     self.cursor.execute(self.pred_table)
+
+    #     print("Database and tables created successfully.")
+    #     self.is_db_empty = self.is_empty()
+
+    def init_db(self):
+        self.conn = sqlite3.connect(self.db_path)
+        self.cursor = self.conn.cursor()
+        self.cursor.execute(self.paper_table)
+        self.cursor.execute(self.pred_table)
+        self.conn.commit()
+        self.is_db_empty = self.is_empty()
+        if not self.is_db_empty:
+            print("Database already contains data.")
+        else:
+            print("Database and tables created successfully.")
+
+    def is_empty(self):
+        try:
+            self.cursor.execute("SELECT COUNT(*) FROM papers")
+            count = self.cursor.fetchone()[0]
+            return count == 0
+        except sqlite3.OperationalError:
+            return True
+
+    def get_connection(self):
+        return sqlite3.connect(self.db_path)
+
+    def populate_db(self, data_path, pred_path):
+        papers_info = self._insert_papers(data_path)
+        self._insert_predictions(pred_path, papers_info)
+        print("Database population completed.")
+
+    def _insert_papers(self, data_path):
+        papers_info = []
+        seen_papers = set()
+        with open(data_path, "r") as f:
+            for line in f:
+                paper = json.loads(line)
+                if paper["id"] in seen_papers:
+                    continue
+                seen_papers.add(paper["id"])
+                sentence_count = len(get_sentences(paper["id"])) + len(
+                    get_sentences(paper["abstract"])
+                )
+                papers_info.append((paper["id"], sentence_count))
+                self.cursor.execute(
+                    """INSERT OR REPLACE INTO papers VALUES (?, ?, ?, ?, ?, ?, ?)""",
+                    (
+                        paper["id"],
+                        paper["abstract"],
+                        json.dumps(paper["authors"]),
+                        json.dumps(paper["primary_category"]),
+                        json.dumps(paper["url"]),
+                        json.dumps(paper["updated"]),
+                        sentence_count,
+                    ),
+                )
+        print(f"Inserted {len(papers_info)} papers.")
+        return papers_info
+
+    def _insert_predictions(self, pred_path, papers_info):
+        with open(pred_path, "r") as f:
+            predictions = json.load(f)
+        predicted_tags = predictions["predicted_tags"]
+
+        k = 0
+        papers_with_predictions = set()
+        papers_without_predictions = []
+        for paper_id, sentence_count in papers_info:
+            paper_predictions = predicted_tags[k : k + sentence_count]
+
+            has_predictions = False
+            for sentence_index, pred in enumerate(paper_predictions):
+                if pred:  # If the prediction is not an empty dictionary
+                    has_predictions = True
+                    for tag_type, concepts in pred.items():
+                        for concept in concepts:
+                            self.cursor.execute(
+                                """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
+                                VALUES (?, ?, ?, ?)""",
+                                (paper_id, sentence_index, tag_type, concept),
+                            )
+                else:
+                    # Insert a null prediction to ensure the paper is counted
+                    self.cursor.execute(
+                        """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
+                        VALUES (?, ?, ?, ?)""",
+                        (paper_id, sentence_index, "null", "null"),
+                    )
+
+            if has_predictions:
+                papers_with_predictions.add(paper_id)
+            else:
+                papers_without_predictions.append(paper_id)
+
+            k += sentence_count
+
+        print(f"Inserted predictions for {len(papers_with_predictions)} papers.")
+        print(f"Papers without any predictions: {len(papers_without_predictions)}")
+
+        if k < len(predicted_tags):
+            print(f"Warning: {len(predicted_tags) - k} predictions were not inserted.")
+
+    def load_model(self):
+        if self.model is None:
+            try:
+                self.model, self.tokenizer = load_model_and_tokenizer(self.model_id)
+                return f"Model {self.model_id} loaded successfully."
+            except Exception as e:
+                return f"Error loading model: {str(e)}"
+        else:
+            return "Model is already loaded."
+
+    def natural_language_to_sql(self, question):
+        system_prompt = "You are an assistant who converts natural language questions to SQL queries to query a database of scientific papers."
+        table = self.paper_table + "; " + self.pred_table
+        prefix = (
+            f"[INST] Write a SQLite query to answer the following question given the database schema. Please wrap your code answer using "
+            f"```: Schema: {table} Question: {question}[/INST] Here is the SQLite query to answer the question: {question}: ``` "
+        )
+
+        sql_query = generate_prediction(
+            self.model, self.tokenizer, prefix, question, "sql", system_prompt
+        )
+
+        sql_query = sql_query.split("```")[1]
+
+        return sql_query
+
+    def execute_query(self, sql_query):
+        try:
+            self.cursor.execute(sql_query)
+            results = self.cursor.fetchall()
+            return results if results else []
+        except sqlite3.Error as e:
+            return [(f"An error occurred: {e}",)]
+
+    def query_db(self, question, is_sql):
+        if self.is_db_empty:
+            return "The database is empty. Please populate it with data first."
+
+        try:
+            if is_sql:
+                sql_query = question.strip()
+            else:
+                nl_to_sql = self.natural_language_to_sql(question)
+                sql_query = nl_to_sql.replace("```sql", "").replace("```", "").strip()
+
+            results = self.execute_query(sql_query)
+
+            output = f"SQL Query: {sql_query}\n\nResults:\n"
+            if isinstance(results, list):
+                if len(results) > 0:
+                    for row in results:
+                        output += str(row) + "\n"
+                else:
+                    output += "No results found."
+            else:
+                output += str(results)  # In case of an error message
+
+            return output
+        except Exception as e:
+            return f"An error occurred: {str(e)}"
+
+    def close(self):
+        self.conn.commit()
+        self.conn.close()
+
+
+def check_db_exists(db_path):
+    return os.path.exists(db_path) and os.path.getsize(db_path) > 0
+
+
+@click.command()
+@click.option(
+    "--data_path", help="Path to the data file containing the papers information."
+)
+@click.option("--pred_path", help="Path to the predictions file.")
+@click.option("--db_name", default="arxiv.db", help="Name of the database to create.")
+@click.option(
+    "--force", is_flag=True, help="Force overwrite if database already exists"
+)
+def main(data_path, pred_path, db_name, force):
+    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+    tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
+    os.makedirs(tables_dir, exist_ok=True)
+    db_path = os.path.join(tables_dir, db_name)
+
+    db_exists = check_db_exists(db_path)
+
+    db = ArxivDatabase(db_path)
+    db.init_db()
+
+    if db_exists and not db.is_db_empty:
+        if not force:
+            print(f"Warning: The database '{db_name}' already exists and is not empty.")
+            overwrite = input("Do you want to overwrite it? (y/N): ").lower().strip()
+            if overwrite != "y":
+                print("Operation cancelled.")
+                db.close()
+                return
+        else:
+            print(
+                f"Warning: Overwriting existing database '{db_name}' due to --force flag."
+            )
+
+    db.populate_db(data_path, pred_path)
+    db.close()
+
+    print(f"Database created and populated at: {db_path}")
+
+
+if __name__ == "__main__":
+    main()
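
Besides the click CLI, `ArxivDatabase` can be driven programmatically. A sketch under two assumptions: the parent project's `src` package is importable (it is imported at the top of `create_db.py`), and the two input paths are placeholders for a JSONL papers dump and a JSON predictions file:

```python
# Programmatic equivalent of the CLI entry point above.
from scripts.create_db import ArxivDatabase

db = ArxivDatabase("data/databases/arxiv.db")      # assumed output path
db.init_db()                                       # creates papers/predictions tables
if db.is_db_empty:
    db.populate_db("data/papers.jsonl", "data/predictions.json")  # placeholder paths
print(db.query_db("SELECT COUNT(*) FROM papers", is_sql=True))
db.close()
```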
scripts/run_db_interface.py ADDED
@@ -0,0 +1,704 @@
+import gradio as gr
+import os
+import json
+import networkx as nx
+import pandas as pd
+import plotly.graph_objects as go
+import re
+import sys
+import sqlite3
+import tempfile
+import time
+import uvicorn
+
+from contextlib import contextmanager
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from gradio.routes import mount_gradio_app
+from plotly.subplots import make_subplots
+from tabulate import tabulate
+from typing import Optional
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if ROOT_DIR not in sys.path:
+    sys.path.insert(0, ROOT_DIR)
+
+from scripts.create_db import ArxivDatabase
+from config import (
+    DEFAULT_TABLES_DIR,
+    DEFAULT_INTERFACE_MODEL_ID,
+    COOCCURRENCE_QUERY,
+    canned_queries,
+)
+
+app = FastAPI()
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+db: Optional[ArxivDatabase] = None
+
+last_update_time = 0
+update_delay = 0.5  # Delay in seconds
+
+
+def truncate_or_wrap_text(text, max_length=50, wrap=False):
+    """Truncate text to a maximum length, adding an ellipsis if truncated, or wrap if specified."""
+    if wrap:
+        return "\n".join(
+            text[i : i + max_length] for i in range(0, len(text), max_length)
+        )
+    return text[:max_length] + "..." if len(text) > max_length else text
+
+
+def format_url(url):
+    """Format a URL to be more compact in the table."""
+    return url.split("/")[-1] if url.startswith("http") else url
+
+
+def get_db_path():
+    """Get the database directory path based on the environment."""
+    # First try the local path
+    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+    tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
+
+    if not os.path.exists(tables_dir):
+        # If running on Spaces, try the root directory
+        tables_dir = os.path.join(ROOT, "data", "databases")
+        if not os.path.exists(tables_dir):
+            print("No database directory found")
+            return None
+
+    print(f"Using database directory: {tables_dir}")
+    return tables_dir
+
+
+def get_available_databases():
+    """Get available databases from either the local path or the Hugging Face cache."""
+    tables_dir = get_db_path()
+    if not tables_dir:
+        return []
+
+    files = os.listdir(tables_dir)
+    print(f"All files found: {files}")
+
+    # Include all files except .md files
+    databases = [f for f in files if not f.endswith(".md")]
+    print(f"Database files: {databases}")
+
+    return databases
+
+
+def query_db(query, is_sql, limit=None, wrap=False):
+    global db
+    if db is None:
+        return pd.DataFrame({"Error": ["Please load a database first."]})
+
+    try:
+        with sqlite3.connect(db.db_path) as conn:
+            cursor = conn.cursor()
+
+            query = " ".join(query.strip().split("\n")).rstrip(";")
+
+            if limit is not None:
+                if "LIMIT" in query.upper():
+                    # Replace existing LIMIT clause
+                    query = re.sub(
+                        r"LIMIT\s+\d+", f"LIMIT {limit}", query, flags=re.IGNORECASE
+                    )
+                else:
+                    query += f" LIMIT {limit}"
+
+            cursor.execute(query)
+
+            column_names = [description[0] for description in cursor.description]
+
+            results = cursor.fetchall()
+
+            df = pd.DataFrame(results, columns=column_names)
+
+            for column in df.columns:
+                if df[column].dtype == "object":
+                    df[column] = df[column].apply(
+                        lambda x: (
+                            format_url(x)
+                            if column == "url"
+                            else truncate_or_wrap_text(x, wrap=wrap)
+                        )
+                    )
+
+            return df
+
+    except sqlite3.Error as e:
+        return pd.DataFrame({"Error": [f"Database error: {str(e)}"]})
+    except Exception as e:
+        return pd.DataFrame({"Error": [f"An unexpected error occurred: {str(e)}"]})
+
+
+def generate_concept_cooccurrence_graph(db_path, tag_type=None):
+    conn = sqlite3.connect(db_path)
+
+    query = COOCCURRENCE_QUERY
+    if tag_type and tag_type != "All":
+        query = query.replace(
+            "WHERE p1.tag_type = p2.tag_type",
+            f"WHERE p1.tag_type = p2.tag_type AND p1.tag_type = '{tag_type}'",
+        )
+
+    df = pd.read_sql_query(query, conn)
+    conn.close()
+
+    G = nx.from_pandas_edgelist(df, "concept1", "concept2", "co_occurrences")
+    pos = nx.spring_layout(G, k=0.5, iterations=50)
+
+    edge_trace = go.Scatter(
+        x=[], y=[], line=dict(width=0.5, color="#888"), hoverinfo="none", mode="lines"
+    )
+
+    node_trace = go.Scatter(
+        x=[],
+        y=[],
+        mode="markers",
+        hoverinfo="text",
+        marker=dict(
+            showscale=True,
+            colorscale="YlGnBu",
+            size=10,
+            colorbar=dict(
+                thickness=15,
+                title="Node Connections",
+                xanchor="left",
+                titleside="right",
+            ),
+        ),
+    )
+
+    def update_traces(selected_node=None, depth=0):
+        nonlocal edge_trace, node_trace
+
+        if selected_node and depth > 0:
+            nodes_to_show = set([selected_node])
+            frontier = set([selected_node])
+            for _ in range(depth):
+                new_frontier = set()
+                for node in frontier:
+                    new_frontier.update(G.neighbors(node))
+                nodes_to_show.update(new_frontier)
+                frontier = new_frontier
+            sub_G = G.subgraph(nodes_to_show)
+        else:
+            sub_G = G
+
+        edge_x, edge_y = [], []
+        for edge in sub_G.edges():
+            x0, y0 = pos[edge[0]]
+            x1, y1 = pos[edge[1]]
+            edge_x.extend([x0, x1, None])
+            edge_y.extend([y0, y1, None])
+
+        edge_trace.x = edge_x
+        edge_trace.y = edge_y
+
+        node_x, node_y = [], []
+        for node in sub_G.nodes():
+            x, y = pos[node]
+            node_x.append(x)
+            node_y.append(y)
+
+        node_trace.x = node_x
+        node_trace.y = node_y
+
+        node_adjacencies = []
+        node_text = []
+        for node in sub_G.nodes():
+            adjacencies = list(G.adj[node])
+            node_adjacencies.append(len(adjacencies))
+            node_text.append(f"{node}<br># of connections: {len(adjacencies)}")
+
+        node_trace.marker.color = node_adjacencies
+        node_trace.text = node_text
+
+    update_traces()
+
+    fig = go.Figure(
+        data=[edge_trace, node_trace],
+        layout=go.Layout(
+            title=f'Concept Co-occurrence Network {f"({tag_type})" if tag_type and tag_type != "All" else ""}',
+            titlefont_size=16,
+            showlegend=False,
+            hovermode="closest",
+            margin=dict(b=20, l=5, r=5, t=40),
+            annotations=[
+                dict(
+                    text="",
+                    showarrow=False,
+                    xref="paper",
+                    yref="paper",
+                    x=0.005,
+                    y=-0.002,
+                )
+            ],
+            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+        ),
+    )
+
+    fig.update_layout(
+        updatemenus=[
+            dict(
+                type="buttons",
+                direction="left",
+                buttons=[
+                    dict(
+                        args=[{"visible": [True, True]}],
+                        label="Full Graph",
+                        method="update",
+                    ),
+                    dict(
+                        args=[
+                            {
+                                "visible": [True, True],
+                                "xaxis.range": [-1, 1],
+                                "yaxis.range": [-1, 1],
+                            }
+                        ],
+                        label="Core View",
+                        method="relayout",
+                    ),
+                    dict(
+                        args=[
+                            {
+                                "visible": [True, True],
+                                "xaxis.range": [-0.2, 0.2],
+                                "yaxis.range": [-0.2, 0.2],
+                            }
+                        ],
+                        label="Detailed View",
+                        method="relayout",
+                    ),
+                ],
+                pad={"r": 10, "t": 10},
+                showactive=True,
+                x=0.11,
+                xanchor="left",
+                y=1.1,
+                yanchor="top",
+            ),
+        ]
+    )
+
+    return fig, G, pos, update_traces
+
+
+def load_database_with_graphs(db_name):
+    """Load a database from either the local path or the Hugging Face cache."""
+    global db
+    tables_dir = get_db_path()
+    if not tables_dir:
+        return "No database directory found.", None
+
+    db_path = os.path.join(tables_dir, db_name)
+    if not os.path.exists(db_path):
+        return f"Database {db_name} does not exist.", None
+
+    db = ArxivDatabase(db_path)
+    db.init_db()
+
+    if db.is_db_empty:
+        return (
+            f"Database loaded from {db_path}, but it is empty. Please populate it with data.",
+            None,
+        )
+
+    graph, _, _, _ = generate_concept_cooccurrence_graph(db_path)
+    return f"Database loaded from {db_path}", graph
+
+
+css = """
+#selected-query {
+    max-height: 100px;
+    overflow-y: auto;
+    white-space: pre-wrap;
+    word-break: break-word;
+}
+"""
+
+
+def create_demo():
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown("# ArXiv Database Query Interface")
+
+        with gr.Row():
+            db_dropdown = gr.Dropdown(
+                choices=get_available_databases(),
+                label="Select Database",
+                value=(get_available_databases() or [None])[0],  # default to the first database found
+            )
+            # load_db_btn = gr.Button("Load Database", size="sm")
+            status = gr.Textbox(label="Status")
+
+        with gr.Row():
+            graph_output = gr.Plot(label="Concept Co-occurrence Graph")
+
+        with gr.Row():
+            tag_type_dropdown = gr.Dropdown(
+                choices=[
+                    "All",
+                    "model",
+                    "task",
+                    "dataset",
+                    "field",
+                    "modality",
+                    "method",
+                    "object",
+                    "property",
+                    "instrument",
+                ],
+                label="Select Tag Type",
+                value="All",
+            )
+            highlight_input = gr.Textbox(label="Highlight Concepts (comma-separated)")
+
+        with gr.Row():
+            node_dropdown = gr.Dropdown(label="Select Node", choices=[])
+            depth_slider = gr.Slider(
+                minimum=0, maximum=5, step=1, value=0, label="Connection Depth"
+            )
+            update_graph_button = gr.Button("Update Graph")
+
+        with gr.Row():
+            wrap_checkbox = gr.Checkbox(label="Wrap long text", value=False)
+            canned_query_dropdown = gr.Dropdown(
+                choices=[q[0] for q in canned_queries], label="Select Query", scale=3
+            )
+            limit_input = gr.Number(
+                label="Limit", value=10000, step=1, minimum=1, scale=1
+            )
+            selected_query = gr.Textbox(
+                label="Selected Query",
+                interactive=False,
+                scale=2,
+                show_label=True,
+                show_copy_button=True,
+                elem_id="selected-query",
+            )
+            canned_query_submit = gr.Button("Submit Query", size="sm", scale=1)
+
+        with gr.Row():
+            sql_input = gr.Textbox(label="Custom SQL Query", lines=3, scale=4)
+            sql_submit = gr.Button("Submit Custom SQL", size="sm", scale=1)
+
+        # with gr.Row():
+        #     nl_query_input = gr.Textbox(
+        #         label="Natural Language Query", lines=2, scale=4
+        #     )
+        #     nl_query_submit = gr.Button("Convert to SQL", size="sm", scale=1)
+
+        output = gr.DataFrame(label="Results", wrap=True)
+
+        with gr.Row():
+            copy_button = gr.Button("Copy as Markdown")
+            download_button = gr.Button("Download as CSV")
+
+        def debounced_update_graph(
+            db_name, tag_type, highlight_concepts, selected_node, depth
+        ):
+            global last_update_time
+
+            current_time = time.time()
+            if current_time - last_update_time < update_delay:
+                return None, []  # Return early if not enough time has passed
+
+            last_update_time = current_time
+
+            if not db_name:
+                return None, []
+
+            ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+            db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db_name)
+            fig, G, pos, update_traces = generate_concept_cooccurrence_graph(
+                db_path, tag_type
+            )
+
+            if isinstance(selected_node, list):
+                selected_node = selected_node[0] if selected_node else None
+
+            highlight_nodes = (
+                [node.strip() for node in highlight_concepts.split(",")]
+                if highlight_concepts
+                else []
+            )
+            primary_node = highlight_nodes[0] if highlight_nodes else None
+
+            if primary_node and primary_node in G.nodes():
+                # Apply node selection and depth filter
+                nodes_to_show = set([primary_node])
+                if depth > 0:
+                    frontier = set([primary_node])
+                    for _ in range(depth):
+                        new_frontier = set()
+                        for node in frontier:
+                            new_frontier.update(G.neighbors(node))
+                        nodes_to_show.update(new_frontier)
+                        frontier = new_frontier
+
+                sub_G = G.subgraph(nodes_to_show)
+
+                # Update traces with the filtered graph
+                edge_x, edge_y = [], []
+                for edge in sub_G.edges():
+                    x0, y0 = pos[edge[0]]
+                    x1, y1 = pos[edge[1]]
+                    edge_x.extend([x0, x1, None])
+                    edge_y.extend([y0, y1, None])
+
+                fig.data[0].x = edge_x
+                fig.data[0].y = edge_y
+
+                node_x, node_y = [], []
+                for node in sub_G.nodes():
+                    x, y = pos[node]
+                    node_x.append(x)
+                    node_y.append(y)
+
+                fig.data[1].x = node_x
+                fig.data[1].y = node_y
+
+                # Color nodes based on their distance from the primary node and highlight status
+                node_colors = []
+                node_sizes = []
+                for node in sub_G.nodes():
+                    if node in highlight_nodes:
+                        node_colors.append(
+                            "rgba(255,0,0,1)"
+                        )  # Red for highlighted nodes
+                        node_sizes.append(15)
+                    else:
+                        distance = nx.shortest_path_length(
+                            sub_G, source=primary_node, target=node
+                        )
+                        intensity = max(0, 1 - (distance / (depth + 1)))
+                        node_colors.append(f"rgba(0,0,255,{intensity})")
+                        node_sizes.append(10)
+
+                fig.data[1].marker.color = node_colors
+                fig.data[1].marker.size = node_sizes
+
+                # Update node text
+                node_text = [
+                    f"{node}<br># of connections: {len(list(G.neighbors(node)))}"
+                    for node in sub_G.nodes()
+                ]
+                fig.data[1].text = node_text
+
+                # Get connected nodes for dropdown
+                connected_nodes = sorted(list(G.neighbors(primary_node)))
+            else:
+                # If no primary node or it's not in the graph, show the full graph
+                connected_nodes = sorted(list(G.nodes()))
+
+            return fig, connected_nodes
+
+        def update_node_dropdown(highlight_concepts):
+            if not highlight_concepts or not db:
+                return gr.Dropdown(choices=[])
+
+            ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+            db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db.db_path)
+            _, G, _, _ = generate_concept_cooccurrence_graph(db_path)
+
+            primary_node = highlight_concepts.split(",")[0].strip()
+            if primary_node in G.nodes():
+                connected_nodes = sorted(list(G.neighbors(primary_node)))
+                return gr.Dropdown(choices=connected_nodes)
+            else:
+                return gr.Dropdown(choices=[])
+
+        def update_selected_query(query_description):
+            for desc, sql in canned_queries:
+                if desc == query_description:
+                    return sql
+            return ""
+
+        def submit_canned_query(query_description, limit, wrap):
+            for desc, sql in canned_queries:
+                if desc == query_description:
+                    return query_db(sql, True, limit, wrap)
+            return pd.DataFrame({"Error": ["Selected query not found."]})
+
+        def copy_as_markdown(df):
+            return df.to_markdown()
+
+        def download_as_csv(df):
+            if df is None or df.empty:
+                return None
+
+            with tempfile.NamedTemporaryFile(
+                mode="w", delete=False, suffix=".csv"
+            ) as temp_file:
+                df.to_csv(temp_file.name, index=False)
+                temp_file_path = temp_file.name
+
+            return temp_file_path
+
+        # def nl_to_sql(nl_query):
+        #     # Placeholder function for natural language to SQL conversion
+        #     return f"SELECT * FROM papers WHERE abstract LIKE '%{nl_query}%' LIMIT 10;"
+
+        db_dropdown.change(
+            load_database_with_graphs,
+            inputs=[db_dropdown],
+            outputs=[status, graph_output],
+        )
+
+        # db_dropdown.change(
+        #     debounced_update_graph,
+        #     inputs=[db_dropdown, tag_type_dropdown, highlight_input, node_dropdown, depth_slider],
+        #     outputs=[graph_output, node_dropdown],
+        # )
+
+        tag_type_dropdown.change(
+            debounced_update_graph,
+            inputs=[
+                db_dropdown,
+                tag_type_dropdown,
+                highlight_input,
+                node_dropdown,
+                depth_slider,
+            ],
+            outputs=[graph_output, node_dropdown],
+        )
+
+        highlight_input.change(
+            update_node_dropdown,
+            inputs=[highlight_input],
+            outputs=[node_dropdown],
+        )
+        # node_dropdown.change(
+        #     debounced_update_graph,
+        #     inputs=[db_dropdown, tag_type_dropdown, highlight_input, node_dropdown, depth_slider],
+        #     outputs=[graph_output, node_dropdown],
+        # )
+
+        # depth_slider.change(
+        #     debounced_update_graph,
+        #     inputs=[db_dropdown, tag_type_dropdown, highlight_input, node_dropdown, depth_slider],
+        #     outputs=[graph_output, node_dropdown],
+        # )
+        update_graph_button.click(
+            debounced_update_graph,
+            inputs=[
+                db_dropdown,
+                tag_type_dropdown,
+                highlight_input,
+                node_dropdown,
+                depth_slider,
+            ],
+            outputs=[graph_output, node_dropdown],
+        )
+        canned_query_dropdown.change(
+            update_selected_query,
+            inputs=[canned_query_dropdown],
+            outputs=[selected_query],
+        )
+        canned_query_submit.click(
+            submit_canned_query,
+            inputs=[canned_query_dropdown, limit_input, wrap_checkbox],
+            outputs=output,
+        )
+        sql_submit.click(
+            query_db,
+            inputs=[sql_input, gr.Checkbox(value=True), limit_input, wrap_checkbox],
+            outputs=output,
+        )
+        copy_button.click(
+            copy_as_markdown,
+            inputs=[output],
+            outputs=[gr.Textbox(label="Markdown Output", show_copy_button=True)],
+        )
+        download_button.click(
+            download_as_csv, inputs=[output], outputs=[gr.File(label="CSV Output")]
+        )
+        # nl_query_submit.click(nl_to_sql, inputs=[nl_query_input], outputs=[sql_input])
+
+    return demo
+
+
+demo = create_demo()
+
+
+def close_db():
+    global db
+    if db is not None:
+        db.close()
+        db = None
+
+
+def launch():
+    print("Launching Gradio app...", flush=True)
+    shared_demo = demo.launch(share=True, prevent_thread_lock=True)
+
+    if isinstance(shared_demo, tuple):
+        if len(shared_demo) >= 2:
+            local_url, share_url = shared_demo[:2]
+        else:
+            local_url, share_url = shared_demo[0], "N/A"
+    else:
+        local_url = getattr(shared_demo, "local_url", "N/A")
+        share_url = getattr(shared_demo, "share_url", "N/A")
+
+    print(f"Local URL: {local_url}", flush=True)
+    print(f"Shareable link: {share_url}", flush=True)
+
+    print(
+        "Gradio app launched.",
+        flush=True,
+    )
+
+    # Keep the script running
+    demo.block_thread()
+
+
+if __name__ == "__main__":
+    launch()
+
+# Mount the Gradio app
+# app = mount_gradio_app(app, demo, path="/")
+
+# print(f"Shareable link: {demo.share_url}")
+
+# @app.exception_handler(Exception)
+# async def exception_handler(request: Request, exc: Exception):
+#     print(f"An error occurred: {str(exc)}")
+#     return {"error": str(exc)}
+
+# @contextmanager
+# def get_db_connection():
+#     global db
+#     conn = db.conn.cursor().connection
+#     try:
+#         yield conn
+#     finally:
+#         conn.close()
+
+# @app.on_event("startup")
+# async def startup_event():
+#     global db
+#     ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+#     db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, get_available_databases()[0])  # Use the first available database
+#     db = ArxivDatabase(db_path)
+#     db.init_db()
+
+# @app.on_event("shutdown")
+# async def shutdown_event():
+#     if db is not None:
+#         db.close()
+
+
+# if __name__ == "__main__":
+#     uvicorn.run(app, host="0.0.0.0", port=7860)
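
One detail worth isolating from `query_db` above is how a row limit is enforced on arbitrary SQL: the query is flattened to one line, then an existing LIMIT clause is rewritten in place, or one is appended. A standalone sketch of that technique:

```python
# The LIMIT handling from query_db, extracted for illustration.
import re


def clamp_limit(query: str, limit: int) -> str:
    # Flatten the query and drop a trailing semicolon, as query_db does.
    query = " ".join(query.strip().split("\n")).rstrip(";")
    if "LIMIT" in query.upper():
        # Rewrite an existing LIMIT clause in place.
        return re.sub(r"LIMIT\s+\d+", f"LIMIT {limit}", query, flags=re.IGNORECASE)
    # Otherwise append one.
    return f"{query} LIMIT {limit}"


assert clamp_limit("SELECT * FROM papers", 10) == "SELECT * FROM papers LIMIT 10"
assert clamp_limit("SELECT * FROM papers LIMIT 500;", 10) == "SELECT * FROM papers LIMIT 10"
```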
scripts/run_db_interface_basic.py ADDED
@@ -0,0 +1,361 @@
+import gradio as gr
+import os
+import json
+import networkx as nx
+import pandas as pd
+import plotly.graph_objects as go
+import re
+import sys
+import sqlite3
+import time
+import uvicorn
+
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from gradio.routes import mount_gradio_app
+from plotly.subplots import make_subplots
+from tabulate import tabulate
+from typing import Optional
+
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+if ROOT_DIR not in sys.path:
+    sys.path.insert(0, ROOT_DIR)
+
+from scripts.create_db import ArxivDatabase
+from config import (
+    DEFAULT_TABLES_DIR,
+    DEFAULT_INTERFACE_MODEL_ID,
+    COOCCURRENCE_QUERY,
+    canned_queries,
+)
+
+app = FastAPI()
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+db: Optional[ArxivDatabase] = None
+
+
+def truncate_or_wrap_text(text, max_length=50, wrap=False):
+    """Truncate text to a maximum length, adding an ellipsis if truncated, or wrap if specified."""
+    if wrap:
+        return "\n".join(
+            text[i : i + max_length] for i in range(0, len(text), max_length)
+        )
+    return text[:max_length] + "..." if len(text) > max_length else text
+
+
+def format_url(url):
+    """Format a URL to be more compact in the table."""
+    return url.split("/")[-1] if url.startswith("http") else url
+
+
+def get_available_databases():
+    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+    tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
+    return [f for f in os.listdir(tables_dir) if f.endswith(".db")]
+
+
+def query_db(query, is_sql, limit=None, wrap=False):
+    global db
+    if db is None:
+        return pd.DataFrame({"Error": ["Please load a database first."]})
+
+    try:
+        cursor = db.conn.cursor()
+
+        query = " ".join(query.strip().split("\n")).rstrip(";")
+
+        if limit is not None:
+            if "LIMIT" in query.upper():
+                # Replace existing LIMIT clause
+                query = re.sub(
+                    r"LIMIT\s+\d+", f"LIMIT {limit}", query, flags=re.IGNORECASE
+                )
+            else:
+                query += f" LIMIT {limit}"
+
+        cursor.execute(query)
+
+        column_names = [description[0] for description in cursor.description]
+
+        results = cursor.fetchall()
+
+        df = pd.DataFrame(results, columns=column_names)
+
+        for column in df.columns:
+            if df[column].dtype == "object":
+                df[column] = df[column].apply(
+                    lambda x: (
+                        format_url(x)
+                        if column == "url"
+                        else truncate_or_wrap_text(x, wrap=wrap)
+                    )
+                )
+
+        return df
+
+    except sqlite3.Error as e:
+        return pd.DataFrame({"Error": [f"Database error: {str(e)}"]})
+    except Exception as e:
+        return pd.DataFrame({"Error": [f"An unexpected error occurred: {str(e)}"]})
+
+
+def generate_concept_cooccurrence_graph(db_path):
+    conn = sqlite3.connect(db_path)
+    df = pd.read_sql_query(COOCCURRENCE_QUERY, conn)
+    conn.close()
+
+    G = nx.from_pandas_edgelist(df, "concept1", "concept2", "co_occurrences")
+    pos = nx.spring_layout(G)
+
+    edge_x = []
+    edge_y = []
+    for edge in G.edges():
+        x0, y0 = pos[edge[0]]
+        x1, y1 = pos[edge[1]]
+        edge_x.extend([x0, x1, None])
+        edge_y.extend([y0, y1, None])
+
+    edge_trace = go.Scatter(
+        x=edge_x,
+        y=edge_y,
+        line=dict(width=0.5, color="#888"),
+        hoverinfo="none",
+        mode="lines",
+    )
+
+    node_x = [pos[node][0] for node in G.nodes()]
+    node_y = [pos[node][1] for node in G.nodes()]
+
+    node_trace = go.Scatter(
+        x=node_x,
+        y=node_y,
+        mode="markers",
+        hoverinfo="text",
+        marker=dict(
+            showscale=True,
+            colorscale="YlGnBu",
+            size=10,
+            colorbar=dict(
+                thickness=15,
+                title="Node Connections",
+                xanchor="left",
+                titleside="right",
+            ),
+        ),
+    )
+
+    node_adjacencies = []
+    node_text = []
+    for node, adjacencies in G.adjacency():
+        node_adjacencies.append(len(adjacencies))
+        node_text.append(f"{node}<br># of connections: {len(adjacencies)}")
+
+    node_trace.marker.color = node_adjacencies
+    node_trace.text = node_text
+
+    fig = go.Figure(
+        data=[edge_trace, node_trace],
+        layout=go.Layout(
+            title="Concept Co-occurrence Network",
+            titlefont_size=16,
+            showlegend=False,
+            hovermode="closest",
+            margin=dict(b=20, l=5, r=5, t=40),
+            annotations=[
+                dict(
+                    text="",
+                    showarrow=False,
+                    xref="paper",
+                    yref="paper",
+                    x=0.005,
+                    y=-0.002,
+                )
+            ],
+            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+        ),
+    )
+    return fig
+
+
+# def load_database_with_graphs(db_name):
+#     global db
+#     ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+#     db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db_name)
+#     if not os.path.exists(db_path):
+#         return f"Database {db_name} does not exist.", None
+#     db = ArxivDatabase(db_path)
+#     db.init_db()
+#     if db.is_db_empty:
+#         return (
+#             f"Database loaded from {db_path}, but it is empty. Please populate it with data.",
+#             None,
+#         )
+
+#     # Generate graph
+#     graph = generate_concept_cooccurrence_graph(db_path)
+
+#     return f"Database loaded from {db_path}", graph
+
+
+def load_database_with_graphs(db_name):
+    global db
+    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+    db_path = os.path.join(ROOT, DEFAULT_TABLES_DIR, db_name)
+    if not os.path.exists(db_path):
+        return f"Database {db_name} does not exist.", None
+
+    if db is None or db.db_path != db_path:
+        db = ArxivDatabase(db_path)
+        db.init_db()
+
+    if db.is_db_empty:
+        return (
+            f"Database loaded from {db_path}, but it is empty. Please populate it with data.",
+            None,
+        )
+
+    graph = generate_concept_cooccurrence_graph(db_path)
+    return f"Database loaded from {db_path}", graph
+
+
+css = """
+#selected-query {
+    max-height: 100px;
+    overflow-y: auto;
+    white-space: pre-wrap;
+    word-break: break-word;
+}
+"""
+
+
+def create_demo():
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown("# ArXiv Database Query Interface")
+
+        with gr.Row():
+            db_dropdown = gr.Dropdown(
+                choices=get_available_databases(), label="Select Database"
+            )
+            load_db_btn = gr.Button("Load Database", size="sm")
+            status = gr.Textbox(label="Status")
+
+        with gr.Row():
+            graph_output = gr.Plot(label="Concept Co-occurrence Graph")
+
+        with gr.Row():
+            wrap_checkbox = gr.Checkbox(label="Wrap long text", value=False)
+            canned_query_dropdown = gr.Dropdown(
+                choices=[q[0] for q in canned_queries], label="Select Query", scale=3
+            )
+            limit_input = gr.Number(
+                label="Limit", value=10000, step=1, minimum=1, scale=1
+            )
+            selected_query = gr.Textbox(
+                label="Selected Query",
+                interactive=False,
+                scale=2,
+                show_label=True,
+                show_copy_button=True,
+                elem_id="selected-query",
+            )
+            canned_query_submit = gr.Button("Submit Query", size="sm", scale=1)
+
+        with gr.Row():
+            sql_input = gr.Textbox(label="Custom SQL Query", lines=3, scale=4)
+            sql_submit = gr.Button("Submit Custom SQL", size="sm", scale=1)
+
+        output = gr.DataFrame(label="Results", wrap=True)
+
+        def update_selected_query(query_description):
+            for desc, sql in canned_queries:
+                if desc == query_description:
+                    return sql
+            return ""
+
+        def submit_canned_query(query_description, limit, wrap):
+            for desc, sql in canned_queries:
+                if desc == query_description:
+                    return query_db(sql, True, limit, wrap)
+            return pd.DataFrame({"Error": ["Selected query not found."]})
+
+        load_db_btn.click(
+            load_database_with_graphs,
+            inputs=[db_dropdown],
+            outputs=[status, graph_output],
+        )
+        canned_query_dropdown.change(
+            update_selected_query,
+            inputs=[canned_query_dropdown],
+            outputs=[selected_query],
+        )
+        canned_query_submit.click(
+            submit_canned_query,
+            inputs=[canned_query_dropdown, limit_input, wrap_checkbox],
+            outputs=output,
+        )
+        sql_submit.click(
+            query_db,
+            inputs=[sql_input, gr.Checkbox(value=True), limit_input, wrap_checkbox],
+            outputs=output,
+        )
+
+    return demo
+
+
+demo = create_demo()
+
+
+def close_db():
+    global db
+    if db is not None:
+        db.close()
+        db = None
+
+
+# def launch():
+#     print("Launching Gradio app...", flush=True)
+#     demo.launch(share=True)
+#     print(
+#         "Gradio app launched. If you don't see a URL above, there might be network restrictions.",
+#         flush=True,
+#     )
+
+#     close_db()
+
+# if __name__ == "__main__":
+#     launch()
+
+# Mount the Gradio app
+app = mount_gradio_app(app, demo, path="/")
+
+
+@app.exception_handler(Exception)
+async def exception_handler(request: Request, exc: Exception):
+    print(f"An error occurred: {str(exc)}")
+    return {"error": str(exc)}
+
+
+@app.on_event("startup")
+async def startup_event():
+    # You can initialize the database here if needed
+    pass
+
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    close_db()
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=7860)
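
Unlike `run_db_interface.py`, this variant serves the Gradio UI from inside a FastAPI app. The serving pattern, reduced to a minimal self-contained sketch (port 7860 is the port Hugging Face Spaces expects):

```python
# Mount a Gradio Blocks app on a FastAPI instance and serve both with uvicorn.
import gradio as gr
import uvicorn
from fastapi import FastAPI
from gradio.routes import mount_gradio_app

app = FastAPI()

with gr.Blocks() as demo:
    gr.Markdown("# Demo")

# Gradio handles "/", while other FastAPI routes and handlers remain available.
app = mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
```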
scripts/run_db_interface_js.py ADDED
File without changes