VyLala committed on
Commit f94920a · verified · 1 Parent(s): 9ebbf94

Upload 55 files

Files changed (9)
  1. NER/html/extractHTML.py +363 -363
  2. app.py +0 -0
  3. better_offer.html +201 -201
  4. data_preprocess.py +876 -876
  5. model.py +0 -0
  6. mtdna_backend.py +1004 -1144
  7. mtdna_classifier.py +768 -768
  8. pipeline.py +0 -0
  9. smart_fallback.py +401 -401
NER/html/extractHTML.py CHANGED
@@ -1,364 +1,364 @@
 
1
+ # reference: https://www.crummy.com/software/BeautifulSoup/bs4/doc/#for-html-documents
2
+ from bs4 import BeautifulSoup
3
+ import os
+ import requests
4
+ from DefaultPackages import openFile, saveFile
5
+ from NER import cleanText
6
+ import pandas as pd
7
+ from lxml.etree import ParserError, XMLSyntaxError
8
+ import aiohttp
9
+ import asyncio
10
+ class HTML():
11
+ def __init__(self, htmlFile, htmlLink, htmlContent: str=None):
12
+ self.htmlLink = htmlLink
13
+ self.htmlFile = htmlFile
14
+ self.htmlContent = htmlContent # NEW: store raw HTML if provided
15
+ def fetch_crossref_metadata(self, doi):
16
+ """Fetch metadata from CrossRef API for a given DOI."""
17
+ try:
18
+ url = f"https://api.crossref.org/works/{doi}"
19
+ r = requests.get(url, timeout=10)
20
+ if r.status_code == 200:
21
+ return r.json().get("message", {})
22
+ else:
23
+ print(f"⚠️ CrossRef fetch failed ({r.status_code}) for DOI: {doi}")
24
+ return {}
25
+ except Exception as e:
26
+ print(f"❌ CrossRef exception: {e}")
27
+ return {}
28
+ # def openHTMLFile(self):
29
+ # headers = {
30
+ # "User-Agent": (
31
+ # "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
32
+ # "AppleWebKit/537.36 (KHTML, like Gecko) "
33
+ # "Chrome/114.0.0.0 Safari/537.36"
34
+ # ),
35
+ # "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
36
+ # "Referer": self.htmlLink,
37
+ # "Connection": "keep-alive"
38
+ # }
39
+
40
+ # session = requests.Session()
41
+ # session.headers.update(headers)
42
+
43
+ # if self.htmlLink != "None":
44
+ # try:
45
+ # r = session.get(self.htmlLink, allow_redirects=True, timeout=15)
46
+ # if r.status_code != 200:
47
+ # print(f"❌ HTML GET failed: {r.status_code} — {self.htmlLink}")
48
+ # return BeautifulSoup("", 'html.parser')
49
+ # soup = BeautifulSoup(r.content, 'html.parser')
50
+ # except Exception as e:
51
+ # print(f"❌ Exception fetching HTML: {e}")
52
+ # return BeautifulSoup("", 'html.parser')
53
+ # else:
54
+ # with open(self.htmlFile) as fp:
55
+ # soup = BeautifulSoup(fp, 'html.parser')
56
+ # return soup
57
+
58
+ def openHTMLFile(self):
59
+ """Return a BeautifulSoup object from cached htmlContent, file, or requests."""
60
+ # If raw HTML already provided (from async aiohttp), use it directly
61
+ if self.htmlContent is not None:
62
+ return BeautifulSoup(self.htmlContent, "html.parser")
63
+
64
+ not_need_domain = ['https://broadinstitute.github.io/picard/',
65
+ 'https://software.broadinstitute.org/gatk/best-practices/',
66
+ 'https://www.ncbi.nlm.nih.gov/genbank/',
67
+ 'https://www.mitomap.org/']
68
+ if self.htmlLink in not_need_domain:
69
+ return BeautifulSoup("", 'html.parser')
70
+ headers = {
71
+ "User-Agent": (
72
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
73
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
74
+ "Chrome/114.0.0.0 Safari/537.36"
75
+ ),
76
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
77
+ "Accept-Language": "en-US,en;q=0.9",
78
+ "Referer": "https://www.google.com/",
79
+ #"Referer": self.htmlLink,
80
+ "Connection": "keep-alive"
81
+ }
82
+
83
+ session = requests.Session()
84
+ session.headers.update(headers)
85
+ try:
86
+ if self.htmlLink and self.htmlLink != "None":
87
+ r = session.get(self.htmlLink, allow_redirects=True, timeout=15)
88
+ if r.status_code != 200 or not r.text.strip():
89
+ print(f"❌ HTML GET failed ({r.status_code}) or empty page: {self.htmlLink}")
90
+ return BeautifulSoup("", 'html.parser')
91
+ soup = BeautifulSoup(r.content, 'html.parser')
92
+ elif self.htmlFile:
93
+ with open(self.htmlFile, encoding='utf-8') as fp:
94
+ soup = BeautifulSoup(fp, 'html.parser')
95
+ except (ParserError, XMLSyntaxError, OSError) as e:
96
+ print(f"🚫 HTML parse error for {self.htmlLink}: {type(e).__name__}")
97
+ return BeautifulSoup("", 'html.parser')
98
+ except Exception as e:
99
+ print(f"❌ General exception for {self.htmlLink}: {e}")
100
+ return BeautifulSoup("", 'html.parser')
101
+
102
+ return soup
103
+
104
+ async def async_fetch_html(self):
105
+ """Async fetch HTML content with aiohttp."""
106
+ not_need_domain = [
107
+ "https://broadinstitute.github.io/picard/",
108
+ "https://software.broadinstitute.org/gatk/best-practices/",
109
+ "https://www.ncbi.nlm.nih.gov/genbank/",
110
+ "https://www.mitomap.org/",
111
+ ]
112
+ if self.htmlLink in not_need_domain:
113
+ return "" # Skip domains we don't need
114
+
115
+ headers = {
116
+ "User-Agent": (
117
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
118
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
119
+ "Chrome/114.0.0.0 Safari/537.36"
120
+ ),
121
+ "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
122
+ "Accept-Language": "en-US,en;q=0.9",
123
+ "Referer": "https://www.google.com/",
124
+ "Connection": "keep-alive",
125
+ }
126
+
127
+ try:
128
+ async with aiohttp.ClientSession(headers=headers) as session:
129
+ async with session.get(self.htmlLink, timeout=15) as resp:
130
+ if resp.status != 200:
131
+ print(f"❌ HTML GET failed ({resp.status}) — {self.htmlLink}")
132
+ return ""
133
+ return await resp.text()
134
+ except Exception as e:
135
+ print(f"❌ Async fetch failed for {self.htmlLink}: {e}")
136
+ return ""
137
+
138
+ @classmethod
139
+ async def bulk_fetch(cls, links: list[str]):
140
+ """Fetch multiple links concurrently, return list of HTML() objects with htmlContent filled."""
141
+ tasks = [cls("", link).async_fetch_html() for link in links]
142
+ results = await asyncio.gather(*tasks, return_exceptions=True)
143
+
144
+ out = []
145
+ for link, content in zip(links, results):
146
+ if isinstance(content, Exception):
147
+ print(f"⚠️ Exception while fetching {link}: {content}")
148
+ out.append(cls("", link, htmlContent=""))
149
+ else:
150
+ out.append(cls("", link, htmlContent=content))
151
+ return out
152
+
153
+
154
+ def getText(self):
155
+ try:
156
+ soup = self.openHTMLFile()
157
+ s = soup.find_all("html")
158
+ text = ""
159
+ if s:
160
+ for t in range(len(s)):
161
+ text = s[t].get_text()
162
+ cl = cleanText.cleanGenText()
163
+ text = cl.removeExtraSpaceBetweenWords(text)
164
+ return text
165
+ except Exception as e:
166
+ print(f"failed to get text from html: {e}")
167
+ return ""
168
+
169
+ async def async_getListSection(self, scienceDirect=None):
170
+ try:
171
+ json = {}
172
+ textJson, textHTML = "", ""
173
+
174
+ # Use preloaded HTML (fast path)
175
+ soup = self.openHTMLFile()
176
+ h2_tags = soup.find_all('h2')
177
+ for idx, h2 in enumerate(h2_tags):
178
+ section_title = h2.get_text(strip=True)
179
+ json.setdefault(section_title, [])
180
+ next_h2 = h2_tags[idx+1] if idx+1 < len(h2_tags) else None
181
+ for p in h2.find_all_next("p"):
182
+ if next_h2 and p == next_h2:
183
+ break
184
+ json[section_title].append(p.get_text(strip=True))
185
+
186
+ # If no sections or explicitly ScienceDirect
187
+ if scienceDirect is not None or len(json) == 0:
188
+ print("async fetching ScienceDirect metadata...")
189
+ api_key = os.environ.get("SCIENCE_DIRECT_API", "")  # same env var as getListSection; avoid hardcoding the Elsevier key
190
+ doi = self.htmlLink.split("https://doi.org/")[-1]
191
+ base_url = f"https://api.elsevier.com/content/article/doi/{doi}"
192
+ headers = {"Accept": "application/json", "X-ELS-APIKey": api_key}
193
+
194
+ async with aiohttp.ClientSession() as session:
195
+ async with session.get(base_url, headers=headers, timeout=15) as resp:
196
+ if resp.status == 200:
197
+ data = await resp.json()
198
+ if isinstance(data, dict):
199
+ json["fullText"] = data
200
+
201
+ # Merge text
202
+ textJson = self.mergeTextInJson(json)
203
+ textHTML = self.getText()
204
+ return textHTML if len(textHTML) > len(textJson) else textJson
205
+
206
+ except Exception as e:
207
+ print("❌ async_getListSection failed:", e)
208
+ return ""
209
+
210
+ def getListSection(self, scienceDirect=None):
211
+ try:
212
+ json = {}
213
+ text = ""
214
+ textJson, textHTML = "",""
215
+ if scienceDirect == None:
216
+ # soup = self.openHTMLFile()
217
+ # # get list of section
218
+ # json = {}
219
+ # for h2Pos in range(len(soup.find_all('h2'))):
220
+ # if soup.find_all('h2')[h2Pos].text not in json:
221
+ # json[soup.find_all('h2')[h2Pos].text] = []
222
+ # if h2Pos + 1 < len(soup.find_all('h2')):
223
+ # content = soup.find_all('h2')[h2Pos].find_next("p")
224
+ # nexth2Content = soup.find_all('h2')[h2Pos+1].find_next("p")
225
+ # while content.text != nexth2Content.text:
226
+ # json[soup.find_all('h2')[h2Pos].text].append(content.text)
227
+ # content = content.find_next("p")
228
+ # else:
229
+ # content = soup.find_all('h2')[h2Pos].find_all_next("p",string=True)
230
+ # json[soup.find_all('h2')[h2Pos].text] = list(i.text for i in content)
231
+
232
+ soup = self.openHTMLFile()
233
+ h2_tags = soup.find_all('h2')
234
+ json = {}
235
+
236
+ for idx, h2 in enumerate(h2_tags):
237
+ section_title = h2.get_text(strip=True)
238
+ json.setdefault(section_title, [])
239
+
240
+ # Get paragraphs until next H2
241
+ next_h2 = h2_tags[idx+1] if idx+1 < len(h2_tags) else None
242
+ for p in h2.find_all_next("p"):
243
+ if next_h2 and p == next_h2:
244
+ break
245
+ json[section_title].append(p.get_text(strip=True))
246
+ # format
247
+ '''json = {'Abstract':[], 'Introduction':[], 'Methods'[],
248
+ 'Results':[], 'Discussion':[], 'References':[],
249
+ 'Acknowledgements':[], 'Author information':[], 'Ethics declarations':[],
250
+ 'Additional information':[], 'Electronic supplementary material':[],
251
+ 'Rights and permissions':[], 'About this article':[], 'Search':[], 'Navigation':[]}'''
252
+ if scienceDirect!= None or len(json)==0:
253
+ # Replace with your actual Elsevier API key
254
+ api_key = os.environ["SCIENCE_DIRECT_API"]
255
+ # ScienceDirect article DOI or PI (Example DOI)
256
+ doi = self.htmlLink.split("https://doi.org/")[-1] #"10.1016/j.ajhg.2011.01.009"
257
+ # Base URL for the Elsevier API
258
+ base_url = "https://api.elsevier.com/content/article/doi/"
259
+ # Set headers with API key
260
+ headers = {
261
+ "Accept": "application/json",
262
+ "X-ELS-APIKey": api_key
263
+ }
264
+ # Make the API request
265
+ response = requests.get(base_url + doi, headers=headers)
266
+ # Check if the request was successful
267
+ if response.status_code == 200:
268
+ data = response.json()
269
+ supp_data = data["full-text-retrieval-response"]#["coredata"]["link"]
270
+ # if "originalText" in list(supp_data.keys()):
271
+ # if type(supp_data["originalText"])==str:
272
+ # json["originalText"] = [supp_data["originalText"]]
273
+ # if type(supp_data["originalText"])==dict:
274
+ # json["originalText"] = [supp_data["originalText"][key] for key in supp_data["originalText"]]
275
+ # else:
276
+ # if type(supp_data)==dict:
277
+ # for key in supp_data:
278
+ # json[key] = [supp_data[key]]
279
+ if type(data)==dict:
280
+ json["fullText"] = data
281
+ textJson = self.mergeTextInJson(json)
282
+ textHTML = self.getText()
283
+ if len(textHTML) > len(textJson):
284
+ text = textHTML
285
+ else: text = textJson
286
+ return text #json
287
+ except Exception as e:
288
+ print(f"getListSection failed: {e}")
289
+ return ""
290
+ def getReference(self):
291
+ # get reference to collect more next data
292
+ ref = []
293
+ json = self.getListSection()
294
+ for key in json["References"]:
295
+ ct = cleanText.cleanGenText(key)
296
+ cleaned, filteredWord = ct.cleanText()  # avoid shadowing the imported cleanText module
297
+ if cleaned not in ref:
298
+ ref.append(cleaned)
299
+ return ref
300
+ def getSupMaterial(self):
301
+ # check if there is material or not
302
+ json = {}
303
+ soup = self.openHTMLFile()
304
+ for h2Pos in range(len(soup.find_all('h2'))):
305
+ if "supplementary" in soup.find_all('h2')[h2Pos].text.lower() or "material" in soup.find_all('h2')[h2Pos].text.lower() or "additional" in soup.find_all('h2')[h2Pos].text.lower() or "support" in soup.find_all('h2')[h2Pos].text.lower():
306
+ #print(soup.find_all('h2')[h2Pos].find_next("a").get("href"))
307
+ link, output = [],[]
308
+ if soup.find_all('h2')[h2Pos].text not in json:
309
+ json[soup.find_all('h2')[h2Pos].text] = []
310
+ for l in soup.find_all('h2')[h2Pos].find_all_next("a",href=True):
311
+ link.append(l["href"])
312
+ if h2Pos + 1 < len(soup.find_all('h2')):
313
+ nexth2Link = soup.find_all('h2')[h2Pos+1].find_next("a",href=True)["href"]
314
+ if nexth2Link in link:
315
+ link = link[:link.index(nexth2Link)]
316
+ # only take links having "https" in that
317
+ for i in link:
318
+ if "https" in i: output.append(i)
319
+ json[soup.find_all('h2')[h2Pos].text].extend(output)
320
+ return json
321
+ def extractTable(self):
322
+ soup = self.openHTMLFile()
323
+ df = []
324
+ if len(soup)>0:
325
+ try:
326
+ df = pd.read_html(str(soup))
327
+ except ValueError:
328
+ df = []
329
+ print("No tables found in HTML file")
330
+ return df
331
+ def mergeTextInJson(self,jsonHTML):
332
+ try:
333
+ #cl = cleanText.cleanGenText()
334
+ htmlText = ""
335
+ if jsonHTML:
336
+ # try:
337
+ # for sec, entries in jsonHTML.items():
338
+ # for i, entry in enumerate(entries):
339
+ # # Only process if it's actually text
340
+ # if isinstance(entry, str):
341
+ # if entry.strip():
342
+ # entry, filteredWord = cl.textPreprocessing(entry, keepPeriod=True)
343
+ # else:
344
+ # # Skip or convert dicts/lists to string if needed
345
+ # entry = str(entry)
346
+
347
+ # jsonHTML[sec][i] = entry
348
+
349
+ # # Add spacing between sentences
350
+ # if i - 1 >= 0 and jsonHTML[sec][i - 1] and jsonHTML[sec][i - 1][-1] != ".":
351
+ # htmlText += ". "
352
+ # htmlText += entry
353
+
354
+ # # Add final period if needed
355
+ # if entries and isinstance(entries[-1], str) and entries[-1] and entries[-1][-1] != ".":
356
+ # htmlText += "."
357
+ # htmlText += "\n\n"
358
+ # except:
359
+ htmlText += str(jsonHTML)
360
+ return htmlText
361
+ except Exception as e:
362
+ print(f"failed to merge text in json: {e}")
363
+ return ""
364
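
For orientation, a minimal usage sketch of the class above, assuming the NER.html.extractHTML import path used elsewhere in this commit; the DOI link is illustrative only:

    import asyncio
    from NER.html import extractHTML  # import path used in data_preprocess.py below

    async def demo():
        links = ["https://doi.org/10.1016/j.ajhg.2011.01.009"]  # illustrative DOI from the comments above
        # bulk_fetch pre-fills htmlContent, so later parsing skips the extra network round-trip
        pages = await extractHTML.HTML.bulk_fetch(links)
        for page in pages:
            text = await page.async_getListSection()
            print(page.htmlLink, len(text))

    asyncio.run(demo())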
 
app.py CHANGED
The diff for this file is too large to render. See raw diff
 
better_offer.html CHANGED
@@ -1,201 +1,201 @@
 
1
+ <div id="classifier-page">
2
+ <style>
3
+ /* Force light mode inside this block only */
4
+ #classifier-page {
5
+ background: #ffffff !important;
6
+ color: #0f172a !important;
7
+ font-family: Inter, system-ui, -apple-system, Segoe UI, Roboto, Arial, sans-serif !important;
8
+
9
+ max-width: 900px !important;
10
+ margin: 24px auto !important;
11
+ padding: 28px 20px !important;
12
+ border-radius: 12px !important;
13
+ box-shadow: 0 2px 8px rgba(0,0,0,0.05) !important;
14
+ }
15
+
16
+ /* All text dark */
17
+ #classifier-page * {
18
+ color: #0f172a !important;
19
+ }
20
+
21
+ /* Pills */
22
+ #classifier-page .pill {
23
+ display: inline-block !important;
24
+ padding: 6px 10px !important;
25
+ border-radius: 999px !important;
26
+ font-size: 12px !important;
27
+ font-weight: 500 !important;
28
+ margin: 2px !important;
29
+ }
30
+ #classifier-page .pill-blue { background:#eef2ff !important; color:#3730a3 !important; }
31
+ #classifier-page .pill-cyan { background:#ecfeff !important; color:#155e75 !important; }
32
+ #classifier-page .pill-green { background:#f0fdf4 !important; color:#166534 !important; }
33
+ #classifier-page .pill-orange { background:#fff7ed !important; color:#9a3412 !important; }
34
+
35
+ /* Explicitly restore button background overrides */
36
+ #classifier-page a {
37
+ text-decoration: none !important;
38
+ }
39
+ #classifier-page a[href*="mtDNALocation"] {
40
+ background: #111827 !important; /* black */
41
+ color: #ffffff !important; /* white text */
42
+ text-decoration: none !important;
43
+ padding: 12px 16px !important;
44
+ border-radius: 10px !important;
45
+ font-weight: 600 !important;
46
+ display: inline-block !important;
47
+ }
48
+
49
+
50
+ </style>
51
+ <!-- Header -->
52
+ <h1 style="margin:0 0 8px; font-size:32px;">mtDNA Location Classifier</h1>
53
+ <p style="margin:0 0 16px; font-size:18px; color:#334155;">
54
+ <strong>AI + Human Intelligence, working together.</strong><br>
55
+ The tool suggests structured labels fast — you decide which ones to trust and refine.
56
+ </p>
57
+
58
+ <!-- Badges -->
59
+ <div style="display:flex; gap:8px; flex-wrap:wrap; margin:12px 0 24px;">
60
+ <span class="pill pill-blue">84% country accuracy (n=4,934)</span>
61
+ <span class="pill pill-cyan">92% modern/ancient accuracy (n=4,934)</span>
62
+ <span class="pill pill-green">Source-backed explanations</span>
63
+ <span class="pill pill-orange">Report → free credit</span>
64
+ </div>
65
+
66
+ <hr style="border:none; border-top:1px solid #e2e8f0; margin:24px 0;">
67
+
68
+ <!-- Purpose -->
69
+ <h2 style="margin:0 0 8px; font-size:22px;">Purpose</h2>
70
+ <p style="margin:0 0 12px;">
71
+ Make biological data <strong>reusable</strong> by labeling it better.
72
+ <br><br>
73
+ Many GenBank / NCBI samples have <strong>incomplete/missing metadata</strong> (country, sample type, optional ethnicity/specific location).
74
+ This tool helps researchers generate <strong>clean, structured labels</strong> — ready for papers, datasets, or analysis.
75
+ <br><br>
76
+ <em>This is not a black-box AI. It’s a partnership between AI speed and human expertise.</em>
77
+ </p>
78
+
79
+ <!-- What you get -->
80
+ <h2 style="margin:24px 0 8px; font-size:22px;">What you get</h2>
81
+ <ul style="margin:0 0 12px 18px;">
82
+ <li>AI-powered inference from GenBank accession alone.</li>
83
+ <li>Country + Sample Type by default; optional labels (e.g. ethnicity &amp; specific location) on request.</li>
84
+ <li>Transparent outputs: explanations, citations.</li>
85
+ <li>Excel export; batch upload; multi-ID input.</li>
86
+ <li><strong>Human-in-the-loop control:</strong> 1-click feedback ensures you decide what counts.</li>
87
+ </ul>
88
+
89
+ <div style="background:#f8fafc; border:1px solid #e2e8f0; border-radius:12px; padding:14px; margin:16px 0;">
90
+ <strong>Positioning:</strong> This tool is an <em>accelerator</em>, not a replacement.
91
+ AI surfaces leads quickly → Human Intelligence validates tricky cases.
92
+ </div>
93
+
94
+ <hr style="border:none; border-top:1px solid #e2e8f0; margin:24px 0;">
95
+
96
+ <!-- Free tier -->
97
+ <h2 style="margin:0 0 8px; font-size:22px;">Free tier</h2>
98
+ <ul style="margin:0 0 12px 18px;">
99
+ <li><strong>30</strong> free samples (no email).</li>
100
+ <li>Add email → <strong>+20</strong> bonus samples (<strong>50 total</strong>) and downloads.</li>
101
+ <li>Not satisfied? Click “Report” → that row doesn’t count, and you get a credit back.</li>
102
+ </ul>
103
+
104
+ <!-- Pricing -->
105
+ <h2 style="margin:24px 0 8px; font-size:22px;">Simple pricing</h2>
106
+ <table style="width:100%; border-collapse:collapse; border:1px solid #e2e8f0; border-radius:10px; overflow:hidden;">
107
+ <thead>
108
+ <tr style="background:#f1f5f9;">
109
+ <th style="text-align:left; padding:10px; font-weight:600;">Plan</th>
110
+ <th style="text-align:left; padding:10px; font-weight:600;">What’s included</th>
111
+ <th style="text-align:left; padding:10px; font-weight:600;">Price</th>
112
+ </tr>
113
+ </thead>
114
+ <tbody>
115
+ <tr>
116
+ <td style="padding:10px; border-top:1px solid #e2e8f0;"><strong>Pay-as-you-go (especially for <a href="#edge_cases" style="color:#1d4ed8; text-decoration:underline;">edge cases</a>)</strong></td>
117
+ <td style="padding:10px; border-top:1px solid #e2e8f0;">Country + Sample Type, explanations, citations, export, report→credit</td>
118
+ <td style="padding:10px; border-top:1px solid #e2e8f0;"><strong>$0.10 / sample</strong></td>
119
+ </tr>
120
+ <tr>
121
+ <td style="padding:10px; border-top:1px solid #e2e8f0;"><strong>Custom labels (optional)</strong></td>
122
+ <td style="padding:10px; border-top:1px solid #e2e8f0;">Ethnicity, specific location granularity, phenotype, or bespoke fields</td>
123
+ <td style="padding:10px; border-top:1px solid #e2e8f0;">Quote on request</td>
124
+ </tr>
125
+ <tr style="background:#fcfcff;">
126
+ <td style="padding:10px; border-top:1px solid #e2e8f0;"><strong> <a href="#research-supporter" style="color:#1d4ed8; text-decoration:underline;">Research Partner (Supporter) </a></strong></td>
127
+ <td style="padding:10px; border-top:1px solid #e2e8f0;">~3,000 samples worth of credits + early access, custom label runs, direct feedback channel, recognition</td>
128
+ <td style="padding:10px; border-top:1px solid #e2e8f0;"><strong>$300 / 3 months</strong></td>
129
+ </tr>
130
+ </tbody>
131
+ </table>
132
+
133
+ <p style="margin:10px 0 0; font-size:14px; color:#475569;">
134
+ <em>Note:</em> For very small sets that you can easily verify manually, we’ll advise you to skip paid runs.
135
+ We optimize for your outcomes, not usage.
136
+ </p>
137
+
138
+ <!-- Edge case highlight -->
139
+ <h2 id="edge_cases" style="margin:24px 0 8px; font-size:22px;">Edge Cases (our specialty)</h2>
140
+ <p>
141
+ Some samples are especially hard to label because they don’t have a DOI, PubMed ID, or linked article.
142
+ Normally these are ignored or left as “unknown.” We call them <strong>edge cases</strong>.
143
+ </p>
144
+ <ul>
145
+ <li>Priced the same as normal runs ($0.10/sample) — no penalty for difficulty</li>
146
+ <li>Custom labels (e.g. ethnicity, city/province) can also be applied to edge cases on request</li>
147
+ </ul>
148
+ <div style="background:#fff7ed; border:1px solid #fed7aa; border-radius:12px; padding:14px; margin:20px 0;">
149
+ <strong>Why it matters:</strong> One early researcher tested 4,932 samples and found our predictions
150
+ for some <em>edge cases</em> were more accurate than his manual annotations — even when metadata was missing.
151
+ </div>
152
+
153
+ <hr>
154
+
155
+ <h2 id="research-supporter" style="margin:24px 0 8px; font-size:22px;">Research Partner Plan (for early supporters)</h2>
156
+ <p>
157
+ Designed for researchers running larger studies who want to support ongoing development while staying on budget.
158
+ Instead of paying strictly per sample, you can join as a <strong>Research Partner</strong>:
159
+ </p>
160
+ <ul>
161
+ <li><strong>$300 flat contribution</strong> (covers ~3,000 samples at $0.10 each, with flexibility on usage)</li>
162
+ <li>Includes early access to new features and custom labels</li>
163
+ <li>Direct feedback channel — help shape how the tool evolves</li>
164
+ <li>Recognition as an early research supporter</li>
165
+ </ul>
166
+ <p>
167
+ <em>This tier was inspired by our very first paying researcher, who contributed $300 to support
168
+ continued development after testing thousands of samples. It’s ideal if you see the potential
169
+ and want to support the mission, even if you’re still validating outputs in your workflow.</em>
170
+ </p>
171
+
172
+ <!-- Who it's for -->
173
+ <h2 style="margin:24px 0 8px; font-size:22px;">Best for</h2>
174
+ <ul style="margin:0 0 12px 18px;">
175
+ <li>Labs cleaning large mtDNA cohorts where manual labeling is slow or inconsistent.</li>
176
+ <li>Researchers who want fast leads + citations, then validate edge cases themselves.</li>
177
+ <li>Teams that value transparency and iterative improvement.</li>
178
+ </ul>
179
+
180
+ <!-- CTA -->
181
+ <div style="display:flex; gap:12px; flex-wrap:wrap; margin:20px 0;">
182
+ <a href="https://huggingface.co/spaces/VyLala/mtDNALocation" target="_blank"
183
+ style="background:#111827; color:#fff; text-decoration:none; padding:12px 16px; border-radius:10px; font-weight:600;">
184
+ Try the Classifier
185
+ </a>
186
+ <a href="mailto:khanhphungvy@gmail.com" target="_blank"
187
+ style="background:#e2e8f0; color:#0f172a; text-decoration:none; padding:12px 16px; border-radius:10px; font-weight:600;">
188
+ Bulk / Research Partner request
189
+ </a>
190
+ </div>
191
+
192
+ <hr style="border:none; border-top:1px solid #e2e8f0; margin:24px 0;">
193
+
194
+ <!-- Mission -->
195
+ <h2 style="margin:0 0 8px; font-size:22px;">Mission</h2>
196
+ <p style="margin:0;">
197
+ Rebuild trust in genomic metadata—one mtDNA sample at a time—through transparency, citations, and a tight feedback loop with researchers.
198
+ </p>
199
+ </div>
200
+
201
+
data_preprocess.py CHANGED
@@ -1,877 +1,877 @@
1
- import re
2
- import os
3
- #import streamlit as st
4
- import subprocess
5
- import re
6
- from Bio import Entrez
7
- from docx import Document
8
- import fitz
9
- import spacy
10
- from spacy.cli import download
11
- from NER.PDF import pdf
12
- from NER.WordDoc import wordDoc
13
- from NER.html import extractHTML
14
- from NER.word2Vec import word2vec
15
- #from transformers import pipeline
16
- import urllib.parse, requests
17
- from pathlib import Path
18
- import pandas as pd
19
- import model
20
- import pipeline
21
- import tempfile
22
- import nltk
23
- nltk.download('punkt_tab')
24
- def download_excel_file(url, save_path="temp.xlsx"):
25
- if "view.officeapps.live.com" in url:
26
- parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
27
- real_url = urllib.parse.unquote(parsed_url["src"][0])
28
- response = requests.get(real_url)
29
- with open(save_path, "wb") as f:
30
- f.write(response.content)
31
- return save_path
32
- elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
33
- response = requests.get(url)
34
- response.raise_for_status() # Raises error if download fails
35
- with open(save_path, "wb") as f:
36
- f.write(response.content)
37
- print(len(response.content))
38
- return save_path
39
- else:
40
- print("URL must point directly to an .xls or .xlsx file, or the file has already been downloaded.")
41
- return url
42
-
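
A quick sketch of how download_excel_file is meant to be called; the URL and save path are placeholders, not values from this repo:

    local_xlsx = download_excel_file(
        "https://example.org/supp/table_s1.xlsx",  # placeholder: any direct .xls/.xlsx URL
        save_path="temp.xlsx",
    )
    sheets = pd.read_excel(local_xlsx, sheet_name=None)  # one DataFrame per sheet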
43
- from pathlib import Path
44
- import pandas as pd
45
-
46
- def process_file(link, saveFolder):
47
- """Returns (file_type, full_path, name) for a given link."""
48
- name = Path(link).name
49
- ext = Path(name).suffix.lower()
50
- file_path = Path(saveFolder) / name
51
-
52
- # If it's already in saveFolder, update link to local path
53
- if file_path.is_file():
54
- link = str(file_path)
55
-
56
- return ext, link, file_path
57
-
58
- import asyncio
59
- import aiohttp
60
- _html_cache = {}
61
-
62
- async def async_fetch_html(link: str, timeout: int = 15) -> str:
63
- """Fetch HTML asynchronously with caching."""
64
- if link in _html_cache:
65
- return _html_cache[link]
66
-
67
- try:
68
- async with aiohttp.ClientSession() as session:
69
- async with session.get(link, timeout=timeout) as resp:
70
- if resp.status != 200:
71
- print(f"⚠️ Failed {link} ({resp.status})")
72
- return ""
73
- html_content = await resp.text()
74
- _html_cache[link] = html_content
75
- return html_content
76
- except Exception as e:
77
- print(f"❌ async_fetch_html error for {link}: {e}")
78
- return ""
79
-
80
- async def ensure_local_file(link: str, saveFolder: str) -> str:
81
- """Ensure file is available locally (Drive or web). Returns local path."""
82
- name = link.split("/")[-1]
83
- local_temp_path = os.path.join(tempfile.gettempdir(), name)
84
-
85
- if os.path.exists(local_temp_path):
86
- return local_temp_path
87
-
88
- # Try Drive first (blocking → offload)
89
- file_id = await asyncio.to_thread(pipeline.find_drive_file, name, saveFolder)
90
- if file_id:
91
- await asyncio.to_thread(pipeline.download_file_from_drive, name, saveFolder, local_temp_path)
92
- else:
93
- # Web download asynchronously
94
- async with aiohttp.ClientSession() as session:
95
- async with session.get(link, timeout=20) as resp:
96
- resp.raise_for_status()
97
- content = await resp.read()
98
- with open(local_temp_path, "wb") as f:
99
- f.write(content)
100
- # Upload back to Drive (offload)
101
- await asyncio.to_thread(pipeline.upload_file_to_drive, local_temp_path, name, saveFolder)
102
-
103
- return local_temp_path
104
-
105
- async def async_extract_text(link, saveFolder):
106
- try:
107
- if link.endswith(".pdf"):
108
- local_path = await ensure_local_file(link, saveFolder)
109
- return await asyncio.to_thread(lambda: pdf.PDFFast(local_path, saveFolder).extract_text())
110
-
111
- elif link.endswith((".doc", ".docx")):
112
- local_path = await ensure_local_file(link, saveFolder)
113
- return await asyncio.to_thread(lambda: wordDoc.WordDocFast(local_path, saveFolder).extractText())
114
-
115
- elif link.endswith((".xls", ".xlsx")):
116
- return ""
117
-
118
- elif link.startswith("http") or "html" in link:
119
- html_content = await async_fetch_html(link)
120
- html = extractHTML.HTML(htmlContent=html_content, htmlLink=link, htmlFile="")
121
- # If you implement async_getListSection, call it here
122
- if hasattr(html, "async_getListSection"):
123
- article_text = await html.async_getListSection()
124
- else:
125
- # fallback: run sync getListSection in a thread
126
- article_text = await asyncio.to_thread(html.getListSection)
127
-
128
- if not article_text:
129
- metadata_text = html.fetch_crossref_metadata(link)
130
- if metadata_text:
131
- article_text = html.mergeTextInJson(metadata_text)
132
- return article_text
133
-
134
- else:
135
- return ""
136
- except Exception as e:
137
- print(f"❌ async_extract_text failed for {link}: {e}")
138
- return ""
139
-
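
The async helpers above are building blocks; a minimal driver that fans several links out concurrently might look like this (the links and folder name are placeholders):

    async def extract_all(links, saveFolder):
        results = await asyncio.gather(
            *(async_extract_text(link, saveFolder) for link in links),
            return_exceptions=True,  # keep one bad link from cancelling the rest
        )
        return {link: ("" if isinstance(r, Exception) else r) for link, r in zip(links, results)}

    # texts = asyncio.run(extract_all(["https://doi.org/10.1000/example"], "shared_folder"))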
140
-
141
- def extract_text(link,saveFolder):
142
- try:
143
- text = ""
144
- name = link.split("/")[-1]
145
- print("name: ", name)
146
- #file_path = Path(saveFolder) / name
147
- local_temp_path = os.path.join(tempfile.gettempdir(), name)
148
- print("this is local temp path: ", local_temp_path)
149
- if os.path.exists(local_temp_path):
150
- input_to_class = local_temp_path
151
- print("exist")
152
- else:
153
- #input_to_class = link # Let the class handle downloading
154
- # 1. Check if file exists in shared Google Drive folder
155
- file_id = pipeline.find_drive_file(name, saveFolder)
156
- if file_id:
157
- print("📥 Downloading from Google Drive...")
158
- pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
159
- else:
160
- print("🌐 Downloading from web link...")
161
- response = requests.get(link)
162
- with open(local_temp_path, 'wb') as f:
163
- f.write(response.content)
164
- print("✅ Saved locally.")
165
-
166
- # 2. Upload to Drive so it's available for later
167
- pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
168
-
169
- input_to_class = local_temp_path
170
- print(input_to_class)
171
- # pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
172
- # pdf
173
- if link.endswith(".pdf"):
174
- # if file_path.is_file():
175
- # link = saveFolder + "/" + name
176
- # print("File exists.")
177
- #p = pdf.PDF(local_temp_path, saveFolder)
178
- print("inside pdf and input to class: ", input_to_class)
179
- print("save folder in extract text: ", saveFolder)
180
- #p = pdf.PDF(input_to_class, saveFolder)
181
- #p = pdf.PDF(link,saveFolder)
182
- #text = p.extractTextWithPDFReader()
183
- #text = p.extractText()
184
- p = pdf.PDFFast(input_to_class, saveFolder)
185
- text = p.extract_text()
186
-
187
- print("len text from pdf:")
188
- print(len(text))
189
- #text_exclude_table = p.extract_text_excluding_tables()
190
- # worddoc
191
- elif link.endswith(".doc") or link.endswith(".docx"):
192
- #d = wordDoc.wordDoc(local_temp_path,saveFolder)
193
- # d = wordDoc.wordDoc(input_to_class,saveFolder)
194
- # text = d.extractTextByPage()
195
- d = wordDoc.WordDocFast(input_to_class, saveFolder)
196
- text = d.extractText()
197
-
198
- # html
199
- else:
200
- if link.split(".")[-1].lower() not in ("xls", "xlsx"):
201
- if "http" in link or "html" in link:
202
- print("html link: ", link)
203
- html = extractHTML.HTML("",link)
204
- text = html.getListSection() # the text already clean
205
- print("len text html: ")
206
- print(len(text))
207
- # Cleanup: delete the local temp file
208
- if name:
209
- if os.path.exists(local_temp_path):
210
- os.remove(local_temp_path)
211
- print(f"🧹 Deleted local temp file: {local_temp_path}")
212
- print("done extract text")
213
- except Exception as e:
- print(f"❌ extract_text failed for {link}: {e}")
214
- text = ""
215
- return text
216
-
217
- def extract_table(link,saveFolder):
218
- try:
219
- table = []
220
- name = link.split("/")[-1]
221
- #file_path = Path(saveFolder) / name
222
- local_temp_path = os.path.join(tempfile.gettempdir(), name)
223
- if os.path.exists(local_temp_path):
224
- input_to_class = local_temp_path
225
- print("exist")
226
- else:
227
- #input_to_class = link # Let the class handle downloading
228
- # 1. Check if file exists in shared Google Drive folder
229
- file_id = pipeline.find_drive_file(name, saveFolder)
230
- if file_id:
231
- print("📥 Downloading from Google Drive...")
232
- pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
233
- else:
234
- print("🌐 Downloading from web link...")
235
- response = requests.get(link)
236
- with open(local_temp_path, 'wb') as f:
237
- f.write(response.content)
238
- print("✅ Saved locally.")
239
-
240
- # 2. Upload to Drive so it's available for later
241
- pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
242
-
243
- input_to_class = local_temp_path
244
- print(input_to_class)
245
- #pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
246
- # pdf
247
- if link.endswith(".pdf"):
248
- # if file_path.is_file():
249
- # link = saveFolder + "/" + name
250
- # print("File exists.")
251
- #p = pdf.PDF(local_temp_path,saveFolder)
252
- p = pdf.PDF(input_to_class,saveFolder)
253
- table = p.extractTable()
254
- # worddoc
255
- elif link.endswith(".doc") or link.endswith(".docx"):
256
- #d = wordDoc.wordDoc(local_temp_path,saveFolder)
257
- # d = wordDoc.wordDoc(input_to_class,saveFolder)
258
- # table = d.extractTableAsList()
259
- d = wordDoc.WordDocFast(input_to_class, saveFolder)
260
- table = d.extractTableAsList()
261
- # excel
262
- elif link.split(".")[-1].lower() in "xlsx":
263
- # download the Excel file if it has not been downloaded yet
264
- savePath = saveFolder +"/"+ link.split("/")[-1]
265
- excelPath = download_excel_file(link, savePath)
266
- try:
267
- #xls = pd.ExcelFile(excelPath)
268
- xls = pd.ExcelFile(local_temp_path)
269
- table_list = []
270
- for sheet_name in xls.sheet_names:
271
- df = pd.read_excel(xls, sheet_name=sheet_name)
272
- cleaned_table = df.fillna("").astype(str).values.tolist()
273
- table_list.append(cleaned_table)
274
- table = table_list
275
- except Exception as e:
276
- print("❌ Failed to extract tables from Excel:", e)
277
- # html
278
- elif "http" in link or "html" in link:
279
- html = extractHTML.HTML("",link)
280
- table = html.extractTable() # table is a list
281
- table = clean_tables_format(table)
282
- # Cleanup: delete the local temp file
283
- if os.path.exists(local_temp_path):
284
- os.remove(local_temp_path)
285
- print(f"🧹 Deleted local temp file: {local_temp_path}")
286
- except:
287
- table = []
288
- return table
289
-
290
- def clean_tables_format(tables):
291
- """
292
- Ensures all tables are in consistent format: List[List[List[str]]]
293
- Cleans by:
294
- - Removing empty strings and rows
295
- - Converting all cells to strings
296
- - Handling DataFrames and list-of-lists
297
- """
298
- cleaned = []
299
- if tables:
300
- for table in tables:
301
- standardized = []
302
-
303
- # Case 1: Pandas DataFrame
304
- if isinstance(table, pd.DataFrame):
305
- table = table.fillna("").astype(str).values.tolist()
306
-
307
- # Case 2: List of Lists
308
- if isinstance(table, list) and all(isinstance(row, list) for row in table):
309
- for row in table:
310
- filtered_row = [str(cell).strip() for cell in row if str(cell).strip()]
311
- if filtered_row:
312
- standardized.append(filtered_row)
313
-
314
- if standardized:
315
- cleaned.append(standardized)
316
-
317
- return cleaned
318
-
319
- import json
320
- def normalize_text_for_comparison(s: str) -> str:
321
- """
322
- Normalizes text for robust comparison by:
323
- 1. Converting to lowercase.
324
- 2. Replacing all types of newlines with a single consistent newline (\n).
325
- 3. Removing extra spaces (e.g., multiple spaces, leading/trailing spaces on lines).
326
- 4. Stripping leading/trailing whitespace from the entire string.
327
- """
328
- s = s.lower()
329
- s = s.replace('\r\n', '\n') # Handle Windows newlines
330
- s = s.replace('\r', '\n') # Handle Mac classic newlines
331
-
332
- # Replace sequences of whitespace (including multiple newlines) with a single space
333
- # This might be too aggressive if you need to preserve paragraph breaks,
334
- # but good for exact word-sequence matching.
335
- s = re.sub(r'\s+', ' ', s)
336
-
337
- return s.strip()
338
- def merge_text_and_tables(text, tables, max_tokens=12000, keep_tables=True, tokenizer="cl100k_base", accession_id=None, isolate=None):
339
- """
340
- Merge cleaned text and table into one string for LLM input.
341
- - Avoids duplicating tables already in text
342
- - Extracts only relevant rows from large tables
343
- - Skips or saves oversized tables
344
- """
345
- import importlib
346
- json = importlib.import_module("json")
347
-
348
- def estimate_tokens(text_str):
349
- try:
350
- enc = tiktoken.get_encoding(tokenizer)
351
- return len(enc.encode(text_str))
352
- except:
353
- return len(text_str) // 4 # Fallback estimate
354
-
355
- def is_table_relevant(table, keywords, accession_id=None):
356
- flat = " ".join(" ".join(row).lower() for row in table)
357
- if accession_id and accession_id.lower() in flat:
358
- return True
359
- return any(kw.lower() in flat for kw in keywords)
360
- preview, preview1 = "",""
361
- llm_input = "## Document Text\n" + text.strip() + "\n"
362
- clean_text = normalize_text_for_comparison(text)
363
-
364
- if tables:
365
- for idx, table in enumerate(tables):
366
- keywords = ["province","district","region","village","location", "country", "region", "origin", "ancient", "modern"]
367
- if accession_id: keywords += [accession_id.lower()]
368
- if isolate: keywords += [isolate.lower()]
369
- if is_table_relevant(table, keywords, accession_id):
370
- if len(table) > 0:
371
- for tab in table:
372
- preview = " ".join(tab) if tab else ""
373
- preview1 = "\n".join(tab) if tab else ""
374
- clean_preview = normalize_text_for_comparison(preview)
375
- clean_preview1 = normalize_text_for_comparison(preview1)
376
- if clean_preview not in clean_text:
377
- if clean_preview1 not in clean_text:
378
- table_str = json.dumps([tab], indent=2)
379
- llm_input += f"## Table {idx+1}\n{table_str}\n"
380
- return llm_input.strip()
381
-
382
- def preprocess_document(link, saveFolder, accession=None, isolate=None, article_text=None):
383
- if article_text:
384
- print("article text already available")
385
- text = article_text
386
- else:
387
- try:
388
- print("start preprocess and extract text")
389
- text = extract_text(link, saveFolder)
390
- except: text = ""
391
- try:
392
- print("extract table start")
393
- success, the_output = pipeline.run_with_timeout(extract_table,args=(link,saveFolder),timeout=10)
394
- print("Returned from timeout logic")
395
- if success:
396
- tables = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
397
- print("yes succeed for extract table")
398
- else:
399
- print("not suceed etxract table")
400
- tables = []
401
- #tables = extract_table(link, saveFolder)
402
- except: tables = []
403
- if accession: accession = accession
404
- if isolate: isolate = isolate
405
- try:
406
- # print("merge text and table start")
407
- # success, the_output = pipeline.run_with_timeout(merge_text_and_tables,kwargs={"text":text,"tables":tables,"accession_id":accession, "isolate":isolate},timeout=30)
408
- # print("Returned from timeout logic")
409
- # if success:
410
- # final_input = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
411
- # print("yes succeed")
412
- # else:
413
- # print("not suceed")
414
- print("just merge text and tables")
415
- final_input = text + ", ".join(tables)
416
- #final_input = pipeline.timeout(merge_text_and_tables(text, tables, max_tokens=12000, accession_id=accession, isolate=isolate)
417
- except:
418
- print("no succeed here in preprocess docu")
419
- final_input = ""
420
- return text, tables, final_input
421
-
422
- def extract_sentences(text):
423
- sentences = re.split(r'(?<=[.!?])\s+', text)
424
- return [s.strip() for s in sentences if s.strip()]
425
-
426
- def is_irrelevant_number_sequence(text):
427
- if re.search(r'\b[A-Z]{2,}\d+\b|\b[A-Za-z]+\s+\d+\b', text, re.IGNORECASE):
428
- return False
429
- word_count = len(re.findall(r'\b[A-Za-z]{2,}\b', text))
430
- number_count = len(re.findall(r'\b\d[\d\.]*\b', text))
431
- total_tokens = len(re.findall(r'\S+', text))
432
- if total_tokens > 0 and (word_count / total_tokens < 0.2) and (number_count / total_tokens > 0.5):
433
- return True
434
- elif re.fullmatch(r'(\d+(\.\d+)?\s*)+', text.strip()):
435
- return True
436
- return False
437
-
438
- def remove_isolated_single_digits(sentence):
439
- tokens = sentence.split()
440
- filtered_tokens = []
441
- for token in tokens:
442
- if token == '0' or token == '1':
443
- pass
444
- else:
445
- filtered_tokens.append(token)
446
- return ' '.join(filtered_tokens).strip()
447
-
448
- def get_contextual_sentences_BFS(text_content, keyword, depth=2):
449
- def extract_codes(sentence):
450
- # Match codes like 'A1YU101', 'KM1', 'MO6' — at least 2 letters + numbers
451
- return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
452
- sentences = extract_sentences(text_content)
453
- relevant_sentences = set()
454
- initial_keywords = set()
455
-
456
- # Define a regex to capture codes like A1YU101 or KM1
457
- # This pattern looks for an alphanumeric sequence followed by digits at the end of the string
458
- code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
459
-
460
- # Attempt to parse the keyword into its prefix and numerical part using re.search
461
- keyword_match = code_pattern.search(keyword)
462
-
463
- keyword_prefix = None
464
- keyword_num = None
465
-
466
- if keyword_match:
467
- keyword_prefix = keyword_match.group(1).lower()
468
- keyword_num = int(keyword_match.group(2))
469
-
470
- for sentence in sentences:
471
- sentence_added = False
472
-
473
- # 1. Check for exact match of the keyword
474
- if re.search(r'\b' + re.escape(keyword) + r'\b', sentence, re.IGNORECASE):
475
- relevant_sentences.add(sentence.strip())
476
- initial_keywords.add(keyword.lower())
477
- sentence_added = True
478
-
479
- # 2. Check for range patterns (e.g., A1YU101-A1YU137)
480
- # The range pattern should be broad enough to capture the full code string within the range.
481
- range_matches = re.finditer(r'([A-Z0-9]+-\d+)', sentence, re.IGNORECASE) # More specific range pattern if needed, or rely on full code pattern below
482
- range_matches = re.finditer(r'([A-Z0-9]+\d+)-([A-Z0-9]+\d+)', sentence, re.IGNORECASE) # This is the more robust range pattern
483
-
484
- for r_match in range_matches:
485
- start_code_str = r_match.group(1)
486
- end_code_str = r_match.group(2)
487
-
488
- # CRITICAL FIX: Use code_pattern.search for start_match and end_match
489
- start_match = code_pattern.search(start_code_str)
490
- end_match = code_pattern.search(end_code_str)
491
-
492
- if keyword_prefix and keyword_num is not None and start_match and end_match:
493
- start_prefix = start_match.group(1).lower()
494
- end_prefix = end_match.group(1).lower()
495
- start_num = int(start_match.group(2))
496
- end_num = int(end_match.group(2))
497
-
498
- # Check if the keyword's prefix matches and its number is within the range
499
- if keyword_prefix == start_prefix and \
500
- keyword_prefix == end_prefix and \
501
- start_num <= keyword_num <= end_num:
502
- relevant_sentences.add(sentence.strip())
503
- initial_keywords.add(start_code_str.lower())
504
- initial_keywords.add(end_code_str.lower())
505
- sentence_added = True
506
- break # Only need to find one matching range per sentence
507
-
508
- # 3. If the sentence was added due to exact match or range, add all its alphanumeric codes
509
- # to initial_keywords to ensure graph traversal from related terms.
510
- if sentence_added:
511
- for word in extract_codes(sentence):
512
- initial_keywords.add(word.lower())
513
-
514
-
515
- # Build word_to_sentences mapping for all sentences
516
- word_to_sentences = {}
517
- for sent in sentences:
518
- codes_in_sent = set(extract_codes(sent))
519
- for code in codes_in_sent:
520
- word_to_sentences.setdefault(code.lower(), set()).add(sent.strip())
521
-
522
-
523
- # Build the graph
524
- graph = {}
525
- for sent in sentences:
526
- codes = set(extract_codes(sent))
527
- for word1 in codes:
528
- word1_lower = word1.lower()
529
- graph.setdefault(word1_lower, set())
530
- for word2 in codes:
531
- word2_lower = word2.lower()
532
- if word1_lower != word2_lower:
533
- graph[word1_lower].add(word2_lower)
534
-
535
-
536
- # Perform BFS/graph traversal
537
- queue = [(k, 0) for k in initial_keywords if k in word_to_sentences]
538
- visited_words = set(initial_keywords)
539
-
540
- while queue:
541
- current_word, level = queue.pop(0)
542
- if level >= depth:
543
- continue
544
-
545
- relevant_sentences.update(word_to_sentences.get(current_word, []))
546
-
547
- for neighbor in graph.get(current_word, []):
548
- if neighbor not in visited_words:
549
- visited_words.add(neighbor)
550
- queue.append((neighbor, level + 1))
551
-
552
- final_sentences = set()
553
- for sentence in relevant_sentences:
554
- if not is_irrelevant_number_sequence(sentence):
555
- processed_sentence = remove_isolated_single_digits(sentence)
556
- if processed_sentence:
557
- final_sentences.add(processed_sentence)
558
-
559
- return "\n".join(sorted(list(final_sentences)))
560
-
561
-
562
-
563
- def get_contextual_sentences_DFS(text_content, keyword, depth=2):
564
- sentences = extract_sentences(text_content)
565
-
566
- # Build word-to-sentences mapping
567
- word_to_sentences = {}
568
- for sent in sentences:
569
- words_in_sent = set(re.findall(r'\b[A-Za-z0-9\-_\/]+\b', sent))
570
- for word in words_in_sent:
571
- word_to_sentences.setdefault(word.lower(), set()).add(sent.strip())
572
-
573
- # Function to extract codes in a sentence
574
- def extract_codes(sentence):
575
- # Only codes like 'KSK1', 'MG272794', not pure numbers
576
- return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
577
-
578
- # DFS with priority based on distance to keyword and early stop if country found
579
- def dfs_traverse(current_word, current_depth, max_depth, visited_words, collected_sentences, parent_sentence=None):
580
- country = "unknown"
581
- if current_depth > max_depth:
582
- return country, False
583
-
584
- if current_word not in word_to_sentences:
585
- return country, False
586
-
587
- for sentence in word_to_sentences[current_word]:
588
- if sentence == parent_sentence:
589
- continue # avoid reusing the same sentence
590
-
591
- collected_sentences.add(sentence)
592
-
593
- #print("current_word:", current_word)
594
- small_sen = extract_context(sentence, current_word, int(len(sentence) / 4))
595
- #print(small_sen)
596
- country = model.get_country_from_text(small_sen)
597
- #print("small context country:", country)
598
- if country.lower() != "unknown":
599
- return country, True
600
- else:
601
- country = model.get_country_from_text(sentence)
602
- #print("full sentence country:", country)
603
- if country.lower() != "unknown":
604
- return country, True
605
-
606
- codes_in_sentence = extract_codes(sentence)
607
- idx = next((i for i, code in enumerate(codes_in_sentence) if code.lower() == current_word.lower()), None)
608
- if idx is None:
609
- continue
610
-
611
- sorted_children = sorted(
612
- [code for code in codes_in_sentence if code.lower() not in visited_words],
613
- key=lambda x: (abs(codes_in_sentence.index(x) - idx),
614
- 0 if codes_in_sentence.index(x) > idx else 1)
615
- )
616
-
617
- #print("sorted_children:", sorted_children)
618
- for child in sorted_children:
619
- child_lower = child.lower()
620
- if child_lower not in visited_words:
621
- visited_words.add(child_lower)
622
- country, should_stop = dfs_traverse(
623
- child_lower, current_depth + 1, max_depth,
624
- visited_words, collected_sentences, parent_sentence=sentence
625
- )
626
- if should_stop:
627
- return country, True
628
-
629
- return country, False
630
-
631
- # Begin DFS
632
- collected_sentences = set()
633
- visited_words = set([keyword.lower()])
634
- country, status = dfs_traverse(keyword.lower(), 0, depth, visited_words, collected_sentences)
635
-
636
- # Filter irrelevant sentences
637
- final_sentences = set()
638
- for sentence in collected_sentences:
639
- if not is_irrelevant_number_sequence(sentence):
640
- processed = remove_isolated_single_digits(sentence)
641
- if processed:
642
- final_sentences.add(processed)
643
- if not final_sentences:
644
- return country, text_content
645
- return country, "\n".join(sorted(list(final_sentences)))
646
-
647
- # Helper function for normalizing text for overlap comparison
648
- def normalize_for_overlap(s: str) -> str:
649
- s = re.sub(r'[^a-zA-Z0-9\s]', ' ', s).lower()
650
- s = re.sub(r'\s+', ' ', s).strip()
651
- return s
652
-
653
- def merge_texts_skipping_overlap(text1: str, text2: str) -> str:
654
- if not text1: return text2
655
- if not text2: return text1
656
-
657
- # Case 1: text2 is fully contained in text1 or vice-versa
658
- if text2 in text1:
659
- return text1
660
- if text1 in text2:
661
- return text2
662
-
663
- # --- Option 1: Original behavior (suffix of text1, prefix of text2) ---
664
- # This is what your function was primarily designed for.
665
- # It looks for the overlap at the "junction" of text1 and text2.
666
-
667
- max_junction_overlap = 0
668
- for i in range(min(len(text1), len(text2)), 0, -1):
669
- suffix1 = text1[-i:]
670
- prefix2 = text2[:i]
671
- # Prioritize exact match, then normalized match
672
- if suffix1 == prefix2:
673
- max_junction_overlap = i
674
- break
675
- elif normalize_for_overlap(suffix1) == normalize_for_overlap(prefix2):
676
- max_junction_overlap = i
677
- break # Take the first (longest) normalized match
678
-
679
- if max_junction_overlap > 0:
680
- merged_text = text1 + text2[max_junction_overlap:]
681
- return re.sub(r'\s+', ' ', merged_text).strip()
682
-
683
- # --- Option 2: Longest Common Prefix (for cases like "Hi, I am Vy.") ---
684
- # This addresses your specific test case where the overlap is at the very beginning of both strings.
685
- # This is often used when trying to deduplicate content that shares a common start.
686
-
687
- longest_common_prefix_len = 0
688
- min_len = min(len(text1), len(text2))
689
- for i in range(min_len):
690
- if text1[i] == text2[i]:
691
- longest_common_prefix_len = i + 1
692
- else:
693
- break
694
-
695
- # If a common prefix is found AND it's a significant portion (e.g., more than a few chars)
696
- # AND the remaining parts are distinct, then apply this merge.
697
- # This is a heuristic and might need fine-tuning.
698
- if longest_common_prefix_len > 0 and \
699
- text1[longest_common_prefix_len:].strip() and \
700
- text2[longest_common_prefix_len:].strip():
701
-
702
- # Only merge this way if the remaining parts are not empty (i.e., not exact duplicates)
703
- # For "Hi, I am Vy. Nice to meet you." and "Hi, I am Vy. Goodbye Vy."
704
- # common prefix is "Hi, I am Vy."
705
- # Remaining text1: " Nice to meet you."
706
- # Remaining text2: " Goodbye Vy."
707
- # So we merge common_prefix + remaining_text1 + remaining_text2
708
-
709
- common_prefix_str = text1[:longest_common_prefix_len]
710
- remainder_text1 = text1[longest_common_prefix_len:]
711
- remainder_text2 = text2[longest_common_prefix_len:]
712
-
713
- merged_text = common_prefix_str + remainder_text1 + remainder_text2
714
- return re.sub(r'\s+', ' ', merged_text).strip()
715
-
716
-
717
- # If neither specific overlap type is found, just concatenate
718
- merged_text = text1 + text2
719
- return re.sub(r'\s+', ' ', merged_text).strip()
720
-
721
- from docx import Document
722
- from pipeline import upload_file_to_drive
723
- # def save_text_to_docx(text_content: str, file_path: str):
724
- # """
725
- # Saves a given text string into a .docx file.
726
-
727
- # Args:
728
- # text_content (str): The text string to save.
729
- # file_path (str): The full path including the filename where the .docx file will be saved.
730
- # Example: '/content/drive/MyDrive/CollectData/Examples/test/SEA_1234/merged_document.docx'
731
- # """
732
- # try:
733
- # document = Document()
734
-
735
- # # Add the entire text as a single paragraph, or split by newlines for multiple paragraphs
736
- # for paragraph_text in text_content.split('\n'):
737
- # document.add_paragraph(paragraph_text)
738
-
739
- # document.save(file_path)
740
- # print(f"Text successfully saved to '{file_path}'")
741
- # except Exception as e:
742
- # print(f"Error saving text to docx file: {e}")
743
- # def save_text_to_docx(text_content: str, filename: str, drive_folder_id: str):
744
- # """
745
- # Saves a given text string into a .docx file locally, then uploads to Google Drive.
746
-
747
- # Args:
748
- # text_content (str): The text string to save.
749
- # filename (str): The target .docx file name, e.g. 'BRU18_merged_document.docx'.
750
- # drive_folder_id (str): Google Drive folder ID where to upload the file.
751
- # """
752
- # try:
753
- # # ✅ Save to temporary local path first
754
- # print("file name: ", filename)
755
- # print("length text content: ", len(text_content))
756
- # local_path = os.path.join(tempfile.gettempdir(), filename)
757
- # document = Document()
758
- # for paragraph_text in text_content.split('\n'):
759
- # document.add_paragraph(paragraph_text)
760
- # document.save(local_path)
761
- # print(f"✅ Text saved locally to: {local_path}")
762
-
763
- # # ✅ Upload to Drive
764
- # pipeline.upload_file_to_drive(local_path, filename, drive_folder_id)
765
- # print(f"✅ Uploaded '{filename}' to Google Drive folder ID: {drive_folder_id}")
766
-
767
- # except Exception as e:
768
- # print(f"❌ Error saving or uploading DOCX: {e}")
769
- def save_text_to_docx(text_content: str, full_local_path: str):
770
- document = Document()
771
- for paragraph_text in text_content.split('\n'):
772
- document.add_paragraph(paragraph_text)
773
- document.save(full_local_path)
774
- print(f"✅ Saved DOCX locally: {full_local_path}")
775
-
776
-
777
-
778
- '''Two scenarios:
779
- - quick look finds the keyword and the deep dive yields the location directly, then stop
780
- - quick look finds the keyword but the deep dive yields no location, so hold the related words and
781
- search the other files iteratively for each related word until a location is found, then stop'''
782
- def extract_context(text, keyword, window=500):
783
- # firstly try accession number
784
- code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
785
-
786
- # Attempt to parse the keyword into its prefix and numerical part using re.search
787
- keyword_match = code_pattern.search(keyword)
788
-
789
- keyword_prefix = None
790
- keyword_num = None
791
-
792
- if keyword_match:
793
- keyword_prefix = keyword_match.group(1).lower()
794
- keyword_num = int(keyword_match.group(2))
795
- text = text.lower()
796
- idx = text.find(keyword.lower())
797
- if idx == -1:
798
- if keyword_prefix:
799
- idx = text.find(keyword_prefix)
800
- if idx == -1:
801
- return "Sample ID not found."
802
- return text[max(0, idx-window): idx+window]
803
- return text[max(0, idx-window): idx+window]
804
- def process_inputToken(filePaths, saveLinkFolder,accession=None, isolate=None):
805
- cache = {}
806
- country = "unknown"
807
- output = ""
808
- tem_output, small_output = "",""
809
- keyword_appear = (False,"")
810
- keywords = []
811
- if isolate: keywords.append(isolate)
812
- if accession: keywords.append(accession)
813
- for f in filePaths:
814
- # scenario 1: direct location: truncate the context and then use the QA model
815
- if keywords:
816
- for keyword in keywords:
817
- text, tables, final_input = preprocess_document(f,saveLinkFolder, isolate=keyword)
818
- if keyword in final_input:
819
- context = extract_context(final_input, keyword)
820
- # quick look if country already in context and if yes then return
821
- country = model.get_country_from_text(context)
822
- if country != "unknown":
823
- return country, context, final_input
824
- else:
825
- country = model.get_country_from_text(final_input)
826
- if country != "unknown":
827
- return country, context, final_input
828
- else: # might be cross-ref
829
- keyword_appear = (True, f)
830
- cache[f] = context
831
- small_output = merge_texts_skipping_overlap(output, context) + "\n"
832
- chunkBFS = get_contextual_sentences_BFS(small_output, keyword)
833
- countryBFS = model.get_country_from_text(chunkBFS)
834
- countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
835
- output = merge_texts_skipping_overlap(output, final_input)
836
- if countryDFS != "unknown" and countryBFS != "unknown":
837
- if len(chunkDFS) <= len(chunkBFS):
838
- return countryDFS, chunkDFS, output
839
- else:
840
- return countryBFS, chunkBFS, output
841
- else:
842
- if countryDFS != "unknown":
843
- return countryDFS, chunkDFS, output
844
- if countryBFS != "unknown":
845
- return countryBFS, chunkBFS, output
846
- else:
847
- # scenario 2:
848
- '''cross-ref: ex: A1YU101 keyword in file 2 which includes KM1 but KM1 in file 1
849
- but if we look at file 1 first then maybe we can have lookup dict which country
850
- such as Thailand as the key and its re'''
851
- cache[f] = final_input
852
- if keyword_appear[0] == True:
853
- for c in cache:
854
- if c!=keyword_appear[1]:
855
- if cache[c].lower() not in output.lower():
856
- output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
857
- chunkBFS = get_contextual_sentences_BFS(output, keyword)
858
- countryBFS = model.get_country_from_text(chunkBFS)
859
- countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
860
- if countryDFS != "unknown" and countryBFS != "unknown":
861
- if len(chunkDFS) <= len(chunkBFS):
862
- return countryDFS, chunkDFS, output
863
- else:
864
- return countryBFS, chunkBFS, output
865
- else:
866
- if countryDFS != "unknown":
867
- return countryDFS, chunkDFS, output
868
- if countryBFS != "unknown":
869
- return countryBFS, chunkBFS, output
870
- else:
871
- if cache[f].lower() not in output.lower():
872
- output = merge_texts_skipping_overlap(output, cache[f]) + "\n"
873
- if len(output) == 0 or keyword_appear[0]==False:
874
- for c in cache:
875
- if cache[c].lower() not in output.lower():
876
- output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
877
  return country, "", output
 
1
+ import re
2
+ import os
3
+ #import streamlit as st
4
+ import subprocess
5
+ import re
6
+ from Bio import Entrez
7
+ from docx import Document
8
+ import fitz
9
+ import spacy
10
+ from spacy.cli import download
11
+ from NER.PDF import pdf
12
+ from NER.WordDoc import wordDoc
13
+ from NER.html import extractHTML
14
+ from NER.word2Vec import word2vec
15
+ #from transformers import pipeline
16
+ import urllib.parse, requests
17
+ from pathlib import Path
18
+ import pandas as pd
19
+ import model
20
+ import pipeline
21
+ import tempfile
22
+ import nltk
23
+ nltk.download('punkt_tab')
24
+ def download_excel_file(url, save_path="temp.xlsx"):
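+ # Download an Excel file to save_path: handles Office Online viewer URLs (by extracting the src parameter)
+ # and direct .xls/.xlsx links; any other URL is returned unchanged.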
25
+ if "view.officeapps.live.com" in url:
26
+ parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
27
+ real_url = urllib.parse.unquote(parsed_url["src"][0])
28
+ response = requests.get(real_url)
29
+ with open(save_path, "wb") as f:
30
+ f.write(response.content)
31
+ return save_path
32
+ elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
33
+ response = requests.get(url)
34
+ response.raise_for_status() # Raises error if download fails
35
+ with open(save_path, "wb") as f:
36
+ f.write(response.content)
37
+ print(len(response.content))
38
+ return save_path
39
+ else:
40
+ print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
41
+ return url
42
+
43
+ from pathlib import Path
44
+ import pandas as pd
45
+
46
+ def process_file(link, saveFolder):
47
+ """Returns (file_type, full_path, name) for a given link."""
48
+ name = Path(link).name
49
+ ext = Path(name).suffix.lower()
50
+ file_path = Path(saveFolder) / name
51
+
52
+ # If it's already in saveFolder, update link to local path
53
+ if file_path.is_file():
54
+ link = str(file_path)
55
+
56
+ return ext, link, file_path
57
+
58
+ import asyncio
59
+ import aiohttp
60
+ _html_cache = {}
61
+
62
+ async def async_fetch_html(link: str, timeout: int = 15) -> str:
63
+ """Fetch HTML asynchronously with caching."""
64
+ if link in _html_cache:
65
+ return _html_cache[link]
66
+
67
+ try:
68
+ async with aiohttp.ClientSession() as session:
69
+ async with session.get(link, timeout=timeout) as resp:
70
+ if resp.status != 200:
71
+ print(f"⚠️ Failed {link} ({resp.status})")
72
+ return ""
73
+ html_content = await resp.text()
74
+ _html_cache[link] = html_content
75
+ return html_content
76
+ except Exception as e:
77
+ print(f"❌ async_fetch_html error for {link}: {e}")
78
+ return ""
79
+
80
+ async def ensure_local_file(link: str, saveFolder: str) -> str:
81
+ """Ensure file is available locally (Drive or web). Returns local path."""
82
+ name = link.split("/")[-1]
83
+ local_temp_path = os.path.join(tempfile.gettempdir(), name)
84
+
85
+ if os.path.exists(local_temp_path):
86
+ return local_temp_path
87
+
88
+ # Try Drive first (blocking → offload)
89
+ file_id = await asyncio.to_thread(pipeline.find_drive_file, name, saveFolder)
90
+ if file_id:
91
+ await asyncio.to_thread(pipeline.download_file_from_drive, name, saveFolder, local_temp_path)
92
+ else:
93
+ # Web download asynchronously
94
+ async with aiohttp.ClientSession() as session:
95
+ async with session.get(link, timeout=20) as resp:
96
+ resp.raise_for_status()
97
+ content = await resp.read()
98
+ with open(local_temp_path, "wb") as f:
99
+ f.write(content)
100
+ # Upload back to Drive (offload)
101
+ await asyncio.to_thread(pipeline.upload_file_to_drive, local_temp_path, name, saveFolder)
102
+
103
+ return local_temp_path
104
+
105
+ async def async_extract_text(link, saveFolder):
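+ # Async text extraction: PDFs and Word files are resolved via ensure_local_file and parsed in worker threads,
+ # while HTML links are fetched with aiohttp and fall back to CrossRef metadata when the page yields no text.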
106
+ try:
107
+ if link.endswith(".pdf"):
108
+ local_path = await ensure_local_file(link, saveFolder)
109
+ return await asyncio.to_thread(lambda: pdf.PDFFast(local_path, saveFolder).extract_text())
110
+
111
+ elif link.endswith((".doc", ".docx")):
112
+ local_path = await ensure_local_file(link, saveFolder)
113
+ return await asyncio.to_thread(lambda: wordDoc.WordDocFast(local_path, saveFolder).extractText())
114
+
115
+ elif link.endswith((".xls", ".xlsx")):
116
+ return ""
117
+
118
+ elif link.startswith("http") or "html" in link:
119
+ html_content = await async_fetch_html(link)
120
+ html = extractHTML.HTML(htmlContent=html_content, htmlLink=link, htmlFile="")
121
+ # If you implement async_getListSection, call it here
122
+ if hasattr(html, "async_getListSection"):
123
+ article_text = await html.async_getListSection()
124
+ else:
125
+ # fallback: run sync getListSection in a thread
126
+ article_text = await asyncio.to_thread(html.getListSection)
127
+
128
+ if not article_text:
129
+ metadata_text = html.fetch_crossref_metadata(link)
130
+ if metadata_text:
131
+ article_text = html.mergeTextInJson(metadata_text)
132
+ return article_text
133
+
134
+ else:
135
+ return ""
136
+ except Exception as e:
137
+ print(f"❌ async_extract_text failed for {link}: {e}")
138
+ return ""
139
+
140
+
141
+ def extract_text(link,saveFolder):
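+ # Synchronous text extractor: resolve the file locally (temp dir, Google Drive, or web download),
+ # dispatch on extension (.pdf, .doc/.docx, otherwise HTML), then delete the temporary copy.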
142
+ try:
143
+ text = ""
144
+ name = link.split("/")[-1]
145
+ print("name: ", name)
146
+ #file_path = Path(saveFolder) / name
147
+ local_temp_path = os.path.join(tempfile.gettempdir(), name)
148
+ print("this is local temp path: ", local_temp_path)
149
+ if os.path.exists(local_temp_path):
150
+ input_to_class = local_temp_path
151
+ print("exist")
152
+ else:
153
+ #input_to_class = link # Let the class handle downloading
154
+ # 1. Check if file exists in shared Google Drive folder
155
+ file_id = pipeline.find_drive_file(name, saveFolder)
156
+ if file_id:
157
+ print("📥 Downloading from Google Drive...")
158
+ pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
159
+ else:
160
+ print("🌐 Downloading from web link...")
161
+ response = requests.get(link)
162
+ with open(local_temp_path, 'wb') as f:
163
+ f.write(response.content)
164
+ print("✅ Saved locally.")
165
+
166
+ # 2. Upload to Drive so it's available for later
167
+ pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
168
+
169
+ input_to_class = local_temp_path
170
+ print(input_to_class)
171
+ # pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
172
+ # pdf
173
+ if link.endswith(".pdf"):
174
+ # if file_path.is_file():
175
+ # link = saveFolder + "/" + name
176
+ # print("File exists.")
177
+ #p = pdf.PDF(local_temp_path, saveFolder)
178
+ print("inside pdf and input to class: ", input_to_class)
179
+ print("save folder in extract text: ", saveFolder)
180
+ #p = pdf.PDF(input_to_class, saveFolder)
181
+ #p = pdf.PDF(link,saveFolder)
182
+ #text = p.extractTextWithPDFReader()
183
+ #text = p.extractText()
184
+ p = pdf.PDFFast(input_to_class, saveFolder)
185
+ text = p.extract_text()
186
+
187
+ print("len text from pdf:")
188
+ print(len(text))
189
+ #text_exclude_table = p.extract_text_excluding_tables()
190
+ # worddoc
191
+ elif link.endswith(".doc") or link.endswith(".docx"):
192
+ #d = wordDoc.wordDoc(local_temp_path,saveFolder)
193
+ # d = wordDoc.wordDoc(input_to_class,saveFolder)
194
+ # text = d.extractTextByPage()
195
+ d = wordDoc.WordDocFast(input_to_class, saveFolder)
196
+ text = d.extractText()
197
+
198
+ # html
199
+ else:
200
+ if link.split(".")[-1].lower() not in "xlsx":
201
+ if "http" in link or "html" in link:
202
+ print("html link: ", link)
203
+ html = extractHTML.HTML("",link)
204
+ text = html.getListSection() # the text already clean
205
+ print("len text html: ")
206
+ print(len(text))
207
+ # Cleanup: delete the local temp file
208
+ if name:
209
+ if os.path.exists(local_temp_path):
210
+ os.remove(local_temp_path)
211
+ print(f"🧹 Deleted local temp file: {local_temp_path}")
212
+ print("done extract text")
213
+ except:
214
+ text = ""
215
+ return text
216
+
217
+ def extract_table(link,saveFolder):
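+ # Table extractor: resolve the file locally as in extract_text, then pull tables with pdf.PDF,
+ # WordDocFast, pandas (Excel sheets), or extractHTML depending on the file type.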
218
+ try:
219
+ table = []
220
+ name = link.split("/")[-1]
221
+ #file_path = Path(saveFolder) / name
222
+ local_temp_path = os.path.join(tempfile.gettempdir(), name)
223
+ if os.path.exists(local_temp_path):
224
+ input_to_class = local_temp_path
225
+ print("exist")
226
+ else:
227
+ #input_to_class = link # Let the class handle downloading
228
+ # 1. Check if file exists in shared Google Drive folder
229
+ file_id = pipeline.find_drive_file(name, saveFolder)
230
+ if file_id:
231
+ print("📥 Downloading from Google Drive...")
232
+ pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
233
+ else:
234
+ print("🌐 Downloading from web link...")
235
+ response = requests.get(link)
236
+ with open(local_temp_path, 'wb') as f:
237
+ f.write(response.content)
238
+ print("✅ Saved locally.")
239
+
240
+ # 2. Upload to Drive so it's available for later
241
+ pipeline.upload_file_to_drive(local_temp_path, name, saveFolder)
242
+
243
+ input_to_class = local_temp_path
244
+ print(input_to_class)
245
+ #pipeline.download_file_from_drive(name, saveFolder, local_temp_path)
246
+ # pdf
247
+ if link.endswith(".pdf"):
248
+ # if file_path.is_file():
249
+ # link = saveFolder + "/" + name
250
+ # print("File exists.")
251
+ #p = pdf.PDF(local_temp_path,saveFolder)
252
+ p = pdf.PDF(input_to_class,saveFolder)
253
+ table = p.extractTable()
254
+ # worddoc
255
+ elif link.endswith(".doc") or link.endswith(".docx"):
256
+ #d = wordDoc.wordDoc(local_temp_path,saveFolder)
257
+ # d = wordDoc.wordDoc(input_to_class,saveFolder)
258
+ # table = d.extractTableAsList()
259
+ d = wordDoc.WordDocFast(input_to_class, saveFolder)
260
+ table = d.extractTableAsList()
261
+ # excel
262
+ elif link.split(".")[-1].lower() in "xlsx":
263
+ # download the Excel file if it has not been downloaded yet
264
+ savePath = saveFolder +"/"+ link.split("/")[-1]
265
+ excelPath = download_excel_file(link, savePath)
266
+ try:
267
+ #xls = pd.ExcelFile(excelPath)
268
+ xls = pd.ExcelFile(local_temp_path)
269
+ table_list = []
270
+ for sheet_name in xls.sheet_names:
271
+ df = pd.read_excel(xls, sheet_name=sheet_name)
272
+ cleaned_table = df.fillna("").astype(str).values.tolist()
273
+ table_list.append(cleaned_table)
274
+ table = table_list
275
+ except Exception as e:
276
+ print("❌ Failed to extract tables from Excel:", e)
277
+ # html
278
+ elif "http" in link or "html" in link:
279
+ html = extractHTML.HTML("",link)
280
+ table = html.extractTable() # table is a list
281
+ table = clean_tables_format(table)
282
+ # Cleanup: delete the local temp file
283
+ if os.path.exists(local_temp_path):
284
+ os.remove(local_temp_path)
285
+ print(f"🧹 Deleted local temp file: {local_temp_path}")
286
+ except:
287
+ table = []
288
+ return table
289
+
290
+ def clean_tables_format(tables):
291
+ """
292
+ Ensures all tables are in consistent format: List[List[List[str]]]
293
+ Cleans by:
294
+ - Removing empty strings and rows
295
+ - Converting all cells to strings
296
+ - Handling DataFrames and list-of-lists
297
+ """
298
+ cleaned = []
299
+ if tables:
300
+ for table in tables:
301
+ standardized = []
302
+
303
+ # Case 1: Pandas DataFrame
304
+ if isinstance(table, pd.DataFrame):
305
+ table = table.fillna("").astype(str).values.tolist()
306
+
307
+ # Case 2: List of Lists
308
+ if isinstance(table, list) and all(isinstance(row, list) for row in table):
309
+ for row in table:
310
+ filtered_row = [str(cell).strip() for cell in row if str(cell).strip()]
311
+ if filtered_row:
312
+ standardized.append(filtered_row)
313
+
314
+ if standardized:
315
+ cleaned.append(standardized)
316
+
317
+ return cleaned
318
+
319
+ import json
320
+ def normalize_text_for_comparison(s: str) -> str:
321
+ """
322
+ Normalizes text for robust comparison by:
323
+ 1. Converting to lowercase.
324
+ 2. Replacing all types of newlines with a single consistent newline (\n).
325
+ 3. Removing extra spaces (e.g., multiple spaces, leading/trailing spaces on lines).
326
+ 4. Stripping leading/trailing whitespace from the entire string.
327
+ """
328
+ s = s.lower()
329
+ s = s.replace('\r\n', '\n') # Handle Windows newlines
330
+ s = s.replace('\r', '\n') # Handle Mac classic newlines
331
+
332
+ # Replace sequences of whitespace (including multiple newlines) with a single space
333
+ # This might be too aggressive if you need to preserve paragraph breaks,
334
+ # but good for exact word-sequence matching.
335
+ s = re.sub(r'\s+', ' ', s)
336
+
337
+ return s.strip()
338
+ def merge_text_and_tables(text, tables, max_tokens=12000, keep_tables=True, tokenizer="cl100k_base", accession_id=None, isolate=None):
339
+ """
340
+ Merge cleaned text and table into one string for LLM input.
341
+ - Avoids duplicating tables already in text
342
+ - Extracts only relevant rows from large tables
343
+ - Skips or saves oversized tables
344
+ """
345
+ import importlib
346
+ json = importlib.import_module("json")
347
+
348
+ def estimate_tokens(text_str):
349
+ try:
350
+ enc = tiktoken.get_encoding(tokenizer)
351
+ return len(enc.encode(text_str))
352
+ except:
353
+ return len(text_str) // 4 # Fallback estimate
354
+
355
+ def is_table_relevant(table, keywords, accession_id=None):
356
+ flat = " ".join(" ".join(row).lower() for row in table)
357
+ if accession_id and accession_id.lower() in flat:
358
+ return True
359
+ return any(kw.lower() in flat for kw in keywords)
360
+ preview, preview1 = "",""
361
+ llm_input = "## Document Text\n" + text.strip() + "\n"
362
+ clean_text = normalize_text_for_comparison(text)
363
+
364
+ if tables:
365
+ for idx, table in enumerate(tables):
366
+ keywords = ["province","district","region","village","location", "country", "region", "origin", "ancient", "modern"]
367
+ if accession_id: keywords += [accession_id.lower()]
368
+ if isolate: keywords += [isolate.lower()]
369
+ if is_table_relevant(table, keywords, accession_id):
370
+ if len(table) > 0:
371
+ for tab in table:
372
+ preview = " ".join(tab) if tab else ""
373
+ preview1 = "\n".join(tab) if tab else ""
374
+ clean_preview = normalize_text_for_comparison(preview)
375
+ clean_preview1 = normalize_text_for_comparison(preview1)
376
+ if clean_preview not in clean_text:
377
+ if clean_preview1 not in clean_text:
378
+ table_str = json.dumps([tab], indent=2)
379
+ llm_input += f"## Table {idx+1}\n{table_str}\n"
380
+ return llm_input.strip()
381
+
382
+ def preprocess_document(link, saveFolder, accession=None, isolate=None, article_text=None):
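+ # Extract text and tables for one document and concatenate them into a single LLM input string;
+ # table extraction runs under a 10-second timeout. Returns (text, tables, final_input).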
383
+ if article_text:
384
+ print("article text already available")
385
+ text = article_text
386
+ else:
387
+ try:
388
+ print("start preprocess and extract text")
389
+ text = extract_text(link, saveFolder)
390
+ except: text = ""
391
+ try:
392
+ print("extract table start")
393
+ success, the_output = pipeline.run_with_timeout(extract_table,args=(link,saveFolder),timeout=10)
394
+ print("Returned from timeout logic")
395
+ if success:
396
+ tables = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
397
+ print("yes succeed for extract table")
398
+ else:
399
+ print("not suceed etxract table")
400
+ tables = []
401
+ #tables = extract_table(link, saveFolder)
402
+ except: tables = []
403
+ if accession: accession = accession
404
+ if isolate: isolate = isolate
405
+ try:
406
+ # print("merge text and table start")
407
+ # success, the_output = pipeline.run_with_timeout(merge_text_and_tables,kwargs={"text":text,"tables":tables,"accession_id":accession, "isolate":isolate},timeout=30)
408
+ # print("Returned from timeout logic")
409
+ # if success:
410
+ # final_input = the_output#data_preprocess.merge_texts_skipping_overlap(all_output, final_input_link)
411
+ # print("yes succeed")
412
+ # else:
413
+ # print("not suceed")
414
+ print("just merge text and tables")
415
+ final_input = text + "\n" + "\n".join(json.dumps(t) for t in tables)  # tables are nested lists, so serialize before joining
416
+ #final_input = pipeline.timeout(merge_text_and_tables(text, tables, max_tokens=12000, accession_id=accession, isolate=isolate)
417
+ except:
418
+ print("no succeed here in preprocess docu")
419
+ final_input = ""
420
+ return text, tables, final_input
421
+
422
+ def extract_sentences(text):
423
+ sentences = re.split(r'(?<=[.!?])\s+', text)
424
+ return [s.strip() for s in sentences if s.strip()]
425
+
426
+ def is_irrelevant_number_sequence(text):
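+ # Heuristic filter: a chunk counts as irrelevant when it is mostly numeric (under 20% word tokens,
+ # over 50% number tokens) and contains no sample-style codes.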
427
+ if re.search(r'\b[A-Z]{2,}\d+\b|\b[A-Za-z]+\s+\d+\b', text, re.IGNORECASE):
428
+ return False
429
+ word_count = len(re.findall(r'\b[A-Za-z]{2,}\b', text))
430
+ number_count = len(re.findall(r'\b\d[\d\.]*\b', text))
431
+ total_tokens = len(re.findall(r'\S+', text))
432
+ if total_tokens > 0 and (word_count / total_tokens < 0.2) and (number_count / total_tokens > 0.5):
433
+ return True
434
+ elif re.fullmatch(r'(\d+(\.\d+)?\s*)+', text.strip()):
435
+ return True
436
+ return False
437
+
438
+ def remove_isolated_single_digits(sentence):
439
+ tokens = sentence.split()
440
+ filtered_tokens = []
441
+ for token in tokens:
442
+ if token == '0' or token == '1':
443
+ pass
444
+ else:
445
+ filtered_tokens.append(token)
446
+ return ' '.join(filtered_tokens).strip()
447
+
448
+ def get_contextual_sentences_BFS(text_content, keyword, depth=2):
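+ # Collect sentences that mention the keyword (including ranges such as A1YU101-A1YU137 that cover it),
+ # then expand breadth-first over sentences sharing sample codes, up to `depth` hops.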
449
+ def extract_codes(sentence):
450
+ # Match codes like 'A1YU101', 'KM1', 'MO6' — at least 2 letters + numbers
451
+ return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
452
+ sentences = extract_sentences(text_content)
453
+ relevant_sentences = set()
454
+ initial_keywords = set()
455
+
456
+ # Define a regex to capture codes like A1YU101 or KM1
457
+ # This pattern looks for an alphanumeric sequence followed by digits at the end of the string
458
+ code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
459
+
460
+ # Attempt to parse the keyword into its prefix and numerical part using re.search
461
+ keyword_match = code_pattern.search(keyword)
462
+
463
+ keyword_prefix = None
464
+ keyword_num = None
465
+
466
+ if keyword_match:
467
+ keyword_prefix = keyword_match.group(1).lower()
468
+ keyword_num = int(keyword_match.group(2))
469
+
470
+ for sentence in sentences:
471
+ sentence_added = False
472
+
473
+ # 1. Check for exact match of the keyword
474
+ if re.search(r'\b' + re.escape(keyword) + r'\b', sentence, re.IGNORECASE):
475
+ relevant_sentences.add(sentence.strip())
476
+ initial_keywords.add(keyword.lower())
477
+ sentence_added = True
478
+
479
+ # 2. Check for range patterns (e.g., A1YU101-A1YU137)
480
+ # The range pattern should be broad enough to capture the full code string within the range.
481
+ range_matches = re.finditer(r'([A-Z0-9]+-\d+)', sentence, re.IGNORECASE) # More specific range pattern if needed, or rely on full code pattern below
482
+ range_matches = re.finditer(r'([A-Z0-9]+\d+)-([A-Z0-9]+\d+)', sentence, re.IGNORECASE) # This is the more robust range pattern
483
+
484
+ for r_match in range_matches:
485
+ start_code_str = r_match.group(1)
486
+ end_code_str = r_match.group(2)
487
+
488
+ # CRITICAL FIX: Use code_pattern.search for start_match and end_match
489
+ start_match = code_pattern.search(start_code_str)
490
+ end_match = code_pattern.search(end_code_str)
491
+
492
+ if keyword_prefix and keyword_num is not None and start_match and end_match:
493
+ start_prefix = start_match.group(1).lower()
494
+ end_prefix = end_match.group(1).lower()
495
+ start_num = int(start_match.group(2))
496
+ end_num = int(end_match.group(2))
497
+
498
+ # Check if the keyword's prefix matches and its number is within the range
499
+ if keyword_prefix == start_prefix and \
500
+ keyword_prefix == end_prefix and \
501
+ start_num <= keyword_num <= end_num:
502
+ relevant_sentences.add(sentence.strip())
503
+ initial_keywords.add(start_code_str.lower())
504
+ initial_keywords.add(end_code_str.lower())
505
+ sentence_added = True
506
+ break # Only need to find one matching range per sentence
507
+
508
+ # 3. If the sentence was added due to exact match or range, add all its alphanumeric codes
509
+ # to initial_keywords to ensure graph traversal from related terms.
510
+ if sentence_added:
511
+ for word in extract_codes(sentence):
512
+ initial_keywords.add(word.lower())
513
+
514
+
515
+ # Build word_to_sentences mapping for all sentences
516
+ word_to_sentences = {}
517
+ for sent in sentences:
518
+ codes_in_sent = set(extract_codes(sent))
519
+ for code in codes_in_sent:
520
+ word_to_sentences.setdefault(code.lower(), set()).add(sent.strip())
521
+
522
+
523
+ # Build the graph
524
+ graph = {}
525
+ for sent in sentences:
526
+ codes = set(extract_codes(sent))
527
+ for word1 in codes:
528
+ word1_lower = word1.lower()
529
+ graph.setdefault(word1_lower, set())
530
+ for word2 in codes:
531
+ word2_lower = word2.lower()
532
+ if word1_lower != word2_lower:
533
+ graph[word1_lower].add(word2_lower)
534
+
535
+
536
+ # Perform BFS/graph traversal
537
+ queue = [(k, 0) for k in initial_keywords if k in word_to_sentences]
538
+ visited_words = set(initial_keywords)
539
+
540
+ while queue:
541
+ current_word, level = queue.pop(0)
542
+ if level >= depth:
543
+ continue
544
+
545
+ relevant_sentences.update(word_to_sentences.get(current_word, []))
546
+
547
+ for neighbor in graph.get(current_word, []):
548
+ if neighbor not in visited_words:
549
+ visited_words.add(neighbor)
550
+ queue.append((neighbor, level + 1))
551
+
552
+ final_sentences = set()
553
+ for sentence in relevant_sentences:
554
+ if not is_irrelevant_number_sequence(sentence):
555
+ processed_sentence = remove_isolated_single_digits(sentence)
556
+ if processed_sentence:
557
+ final_sentences.add(processed_sentence)
558
+
559
+ return "\n".join(sorted(list(final_sentences)))
560
+
561
+
562
+
563
+ def get_contextual_sentences_DFS(text_content, keyword, depth=2):
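+ # Depth-first variant: follow co-occurring sample codes sentence by sentence, querying
+ # model.get_country_from_text at each step and stopping once a country is found; returns (country, context).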
564
+ sentences = extract_sentences(text_content)
565
+
566
+ # Build word-to-sentences mapping
567
+ word_to_sentences = {}
568
+ for sent in sentences:
569
+ words_in_sent = set(re.findall(r'\b[A-Za-z0-9\-_\/]+\b', sent))
570
+ for word in words_in_sent:
571
+ word_to_sentences.setdefault(word.lower(), set()).add(sent.strip())
572
+
573
+ # Function to extract codes in a sentence
574
+ def extract_codes(sentence):
575
+ # Only codes like 'KSK1', 'MG272794', not pure numbers
576
+ return [code for code in re.findall(r'\b[A-Z]{2,}[0-9]+\b', sentence, re.IGNORECASE)]
577
+
578
+ # DFS with priority based on distance to keyword and early stop if country found
579
+ def dfs_traverse(current_word, current_depth, max_depth, visited_words, collected_sentences, parent_sentence=None):
580
+ country = "unknown"
581
+ if current_depth > max_depth:
582
+ return country, False
583
+
584
+ if current_word not in word_to_sentences:
585
+ return country, False
586
+
587
+ for sentence in word_to_sentences[current_word]:
588
+ if sentence == parent_sentence:
589
+ continue # avoid reusing the same sentence
590
+
591
+ collected_sentences.add(sentence)
592
+
593
+ #print("current_word:", current_word)
594
+ small_sen = extract_context(sentence, current_word, int(len(sentence) / 4))
595
+ #print(small_sen)
596
+ country = model.get_country_from_text(small_sen)
597
+ #print("small context country:", country)
598
+ if country.lower() != "unknown":
599
+ return country, True
600
+ else:
601
+ country = model.get_country_from_text(sentence)
602
+ #print("full sentence country:", country)
603
+ if country.lower() != "unknown":
604
+ return country, True
605
+
606
+ codes_in_sentence = extract_codes(sentence)
607
+ idx = next((i for i, code in enumerate(codes_in_sentence) if code.lower() == current_word.lower()), None)
608
+ if idx is None:
609
+ continue
610
+
611
+ sorted_children = sorted(
612
+ [code for code in codes_in_sentence if code.lower() not in visited_words],
613
+ key=lambda x: (abs(codes_in_sentence.index(x) - idx),
614
+ 0 if codes_in_sentence.index(x) > idx else 1)
615
+ )
616
+
617
+ #print("sorted_children:", sorted_children)
618
+ for child in sorted_children:
619
+ child_lower = child.lower()
620
+ if child_lower not in visited_words:
621
+ visited_words.add(child_lower)
622
+ country, should_stop = dfs_traverse(
623
+ child_lower, current_depth + 1, max_depth,
624
+ visited_words, collected_sentences, parent_sentence=sentence
625
+ )
626
+ if should_stop:
627
+ return country, True
628
+
629
+ return country, False
630
+
631
+ # Begin DFS
632
+ collected_sentences = set()
633
+ visited_words = set([keyword.lower()])
634
+ country, status = dfs_traverse(keyword.lower(), 0, depth, visited_words, collected_sentences)
635
+
636
+ # Filter irrelevant sentences
637
+ final_sentences = set()
638
+ for sentence in collected_sentences:
639
+ if not is_irrelevant_number_sequence(sentence):
640
+ processed = remove_isolated_single_digits(sentence)
641
+ if processed:
642
+ final_sentences.add(processed)
643
+ if not final_sentences:
644
+ return country, text_content
645
+ return country, "\n".join(sorted(list(final_sentences)))
646
+
647
+ # Helper function for normalizing text for overlap comparison
648
+ def normalize_for_overlap(s: str) -> str:
649
+ s = re.sub(r'[^a-zA-Z0-9\s]', ' ', s).lower()
650
+ s = re.sub(r'\s+', ' ', s).strip()
651
+ return s
652
+
653
+ def merge_texts_skipping_overlap(text1: str, text2: str) -> str:
654
+ if not text1: return text2
655
+ if not text2: return text1
656
+
657
+ # Case 1: text2 is fully contained in text1 or vice-versa
658
+ if text2 in text1:
659
+ return text1
660
+ if text1 in text2:
661
+ return text2
662
+
663
+ # --- Option 1: Original behavior (suffix of text1, prefix of text2) ---
664
+ # This is what your function was primarily designed for.
665
+ # It looks for the overlap at the "junction" of text1 and text2.
666
+
667
+ max_junction_overlap = 0
668
+ for i in range(min(len(text1), len(text2)), 0, -1):
669
+ suffix1 = text1[-i:]
670
+ prefix2 = text2[:i]
671
+ # Prioritize exact match, then normalized match
672
+ if suffix1 == prefix2:
673
+ max_junction_overlap = i
674
+ break
675
+ elif normalize_for_overlap(suffix1) == normalize_for_overlap(prefix2):
676
+ max_junction_overlap = i
677
+ break # Take the first (longest) normalized match
678
+
679
+ if max_junction_overlap > 0:
680
+ merged_text = text1 + text2[max_junction_overlap:]
681
+ return re.sub(r'\s+', ' ', merged_text).strip()
682
+
683
+ # --- Option 2: Longest Common Prefix (for cases like "Hi, I am Vy.") ---
684
+ # This addresses your specific test case where the overlap is at the very beginning of both strings.
685
+ # This is often used when trying to deduplicate content that shares a common start.
686
+
687
+ longest_common_prefix_len = 0
688
+ min_len = min(len(text1), len(text2))
689
+ for i in range(min_len):
690
+ if text1[i] == text2[i]:
691
+ longest_common_prefix_len = i + 1
692
+ else:
693
+ break
694
+
695
+ # If a common prefix is found AND it's a significant portion (e.g., more than a few chars)
696
+ # AND the remaining parts are distinct, then apply this merge.
697
+ # This is a heuristic and might need fine-tuning.
698
+ if longest_common_prefix_len > 0 and \
699
+ text1[longest_common_prefix_len:].strip() and \
700
+ text2[longest_common_prefix_len:].strip():
701
+
702
+ # Only merge this way if the remaining parts are not empty (i.e., not exact duplicates)
703
+ # For "Hi, I am Vy. Nice to meet you." and "Hi, I am Vy. Goodbye Vy."
704
+ # common prefix is "Hi, I am Vy."
705
+ # Remaining text1: " Nice to meet you."
706
+ # Remaining text2: " Goodbye Vy."
707
+ # So we merge common_prefix + remaining_text1 + remaining_text2
708
+
709
+ common_prefix_str = text1[:longest_common_prefix_len]
710
+ remainder_text1 = text1[longest_common_prefix_len:]
711
+ remainder_text2 = text2[longest_common_prefix_len:]
712
+
713
+ merged_text = common_prefix_str + remainder_text1 + remainder_text2
714
+ return re.sub(r'\s+', ' ', merged_text).strip()
715
+
716
+
717
+ # If neither specific overlap type is found, just concatenate
718
+ merged_text = text1 + text2
719
+ return re.sub(r'\s+', ' ', merged_text).strip()
720
+
721
+ from docx import Document
722
+ from pipeline import upload_file_to_drive
723
+ # def save_text_to_docx(text_content: str, file_path: str):
724
+ # """
725
+ # Saves a given text string into a .docx file.
726
+
727
+ # Args:
728
+ # text_content (str): The text string to save.
729
+ # file_path (str): The full path including the filename where the .docx file will be saved.
730
+ # Example: '/content/drive/MyDrive/CollectData/Examples/test/SEA_1234/merged_document.docx'
731
+ # """
732
+ # try:
733
+ # document = Document()
734
+
735
+ # # Add the entire text as a single paragraph, or split by newlines for multiple paragraphs
736
+ # for paragraph_text in text_content.split('\n'):
737
+ # document.add_paragraph(paragraph_text)
738
+
739
+ # document.save(file_path)
740
+ # print(f"Text successfully saved to '{file_path}'")
741
+ # except Exception as e:
742
+ # print(f"Error saving text to docx file: {e}")
743
+ # def save_text_to_docx(text_content: str, filename: str, drive_folder_id: str):
744
+ # """
745
+ # Saves a given text string into a .docx file locally, then uploads to Google Drive.
746
+
747
+ # Args:
748
+ # text_content (str): The text string to save.
749
+ # filename (str): The target .docx file name, e.g. 'BRU18_merged_document.docx'.
750
+ # drive_folder_id (str): Google Drive folder ID where to upload the file.
751
+ # """
752
+ # try:
753
+ # # ✅ Save to temporary local path first
754
+ # print("file name: ", filename)
755
+ # print("length text content: ", len(text_content))
756
+ # local_path = os.path.join(tempfile.gettempdir(), filename)
757
+ # document = Document()
758
+ # for paragraph_text in text_content.split('\n'):
759
+ # document.add_paragraph(paragraph_text)
760
+ # document.save(local_path)
761
+ # print(f"✅ Text saved locally to: {local_path}")
762
+
763
+ # # ✅ Upload to Drive
764
+ # pipeline.upload_file_to_drive(local_path, filename, drive_folder_id)
765
+ # print(f"✅ Uploaded '{filename}' to Google Drive folder ID: {drive_folder_id}")
766
+
767
+ # except Exception as e:
768
+ # print(f"❌ Error saving or uploading DOCX: {e}")
769
+ def save_text_to_docx(text_content: str, full_local_path: str):
770
+ document = Document()
771
+ for paragraph_text in text_content.split('\n'):
772
+ document.add_paragraph(paragraph_text)
773
+ document.save(full_local_path)
774
+ print(f"✅ Saved DOCX locally: {full_local_path}")
775
+
776
+
777
+
778
+ '''Two scenarios:
779
+ - quick look finds the keyword and the deep dive yields the location directly, then stop
780
+ - quick look finds the keyword but the deep dive yields no location, so hold the related words and
781
+ search the other files iteratively for each related word until a location is found, then stop'''
782
+ def extract_context(text, keyword, window=500):
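+ # Return a window of text (default +/-500 characters) around the keyword, falling back to the
+ # keyword's prefix (the part before its trailing digits) when the full ID is not found.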
783
+ # firstly try accession number
784
+ code_pattern = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
785
+
786
+ # Attempt to parse the keyword into its prefix and numerical part using re.search
787
+ keyword_match = code_pattern.search(keyword)
788
+
789
+ keyword_prefix = None
790
+ keyword_num = None
791
+
792
+ if keyword_match:
793
+ keyword_prefix = keyword_match.group(1).lower()
794
+ keyword_num = int(keyword_match.group(2))
795
+ text = text.lower()
796
+ idx = text.find(keyword.lower())
797
+ if idx == -1:
798
+ if keyword_prefix:
799
+ idx = text.find(keyword_prefix)
800
+ if idx == -1:
801
+ return "Sample ID not found."
802
+ return text[max(0, idx-window): idx+window]
803
+ return text[max(0, idx-window): idx+window]
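A small illustration of extract_context as defined above: it lowercases the text, falls back to the accession prefix when the full keyword is absent, and returns a slice of up to window characters on each side of the match (the strings below are made up for illustration):

from data_preprocess import extract_context

text = "Table S2: sample KU131308 (isolate BRU18) was collected in Brunei, Borneo."
print(extract_context(text, "KU131308", window=25))
# -> a lowercased slice centered on "ku131308"
print(extract_context(text, "XY999999", window=25))
# -> "Sample ID not found." (neither the keyword nor its "xy" prefix occurs)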
804
+ def process_inputToken(filePaths, saveLinkFolder,accession=None, isolate=None):
805
+ cache = {}
806
+ country = "unknown"
807
+ output = ""
808
+ tem_output, small_output = "",""
809
+ keyword_appear = (False,"")
810
+ keywords = []
811
+ if isolate: keywords.append(isolate)
812
+ if accession: keywords.append(accession)
813
+ for f in filePaths:
814
+ # scenario 1: direct location: truncate the context and then use the QA model
815
+ if keywords:
816
+ for keyword in keywords:
817
+ text, tables, final_input = preprocess_document(f,saveLinkFolder, isolate=keyword)
818
+ if keyword in final_input:
819
+ context = extract_context(final_input, keyword)
820
+ # quick look: if a country is already in the context, return it right away
821
+ country = model.get_country_from_text(context)
822
+ if country != "unknown":
823
+ return country, context, final_input
824
+ else:
825
+ country = model.get_country_from_text(final_input)
826
+ if country != "unknown":
827
+ return country, context, final_input
828
+ else: # might be cross-ref
829
+ keyword_appear = (True, f)
830
+ cache[f] = context
831
+ small_output = merge_texts_skipping_overlap(output, context) + "\n"
832
+ chunkBFS = get_contextual_sentences_BFS(small_output, keyword)
833
+ countryBFS = model.get_country_from_text(chunkBFS)
834
+ countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
835
+ output = merge_texts_skipping_overlap(output, final_input)
836
+ if countryDFS != "unknown" and countryBFS != "unknown":
837
+ if len(chunkDFS) <= len(chunkBFS):
838
+ return countryDFS, chunkDFS, output
839
+ else:
840
+ return countryBFS, chunkBFS, output
841
+ else:
842
+ if countryDFS != "unknown":
843
+ return countryDFS, chunkDFS, output
844
+ if countryBFS != "unknown":
845
+ return countryBFS, chunkBFS, output
846
+ else:
847
+ # scenario 2:
848
+ '''cross-ref: e.g. the A1YU101 keyword is in file 2, which also includes KM1, but KM1 itself is in file 1;
849
+ if we look at file 1 first, we can build a lookup dict with the country
850
+ such as Thailand as the key and its re'''
851
+ cache[f] = final_input
852
+ if keyword_appear[0] == True:
853
+ for c in cache:
854
+ if c!=keyword_appear[1]:
855
+ if cache[c].lower() not in output.lower():
856
+ output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
857
+ chunkBFS = get_contextual_sentences_BFS(output, keyword)
858
+ countryBFS = model.get_country_from_text(chunkBFS)
859
+ countryDFS, chunkDFS = get_contextual_sentences_DFS(output, keyword)
860
+ if countryDFS != "unknown" and countryBFS != "unknown":
861
+ if len(chunkDFS) <= len(chunkBFS):
862
+ return countryDFS, chunkDFS, output
863
+ else:
864
+ return countryBFS, chunkBFS, output
865
+ else:
866
+ if countryDFS != "unknown":
867
+ return countryDFS, chunkDFS, output
868
+ if countryBFS != "unknown":
869
+ return countryBFS, chunkBFS, output
870
+ else:
871
+ if cache[f].lower() not in output.lower():
872
+ output = merge_texts_skipping_overlap(output, cache[f]) + "\n"
873
+ if len(output) == 0 or keyword_appear[0]==False:
874
+ for c in cache:
875
+ if cache[c].lower() not in output.lower():
876
+ output = merge_texts_skipping_overlap(output, cache[c]) + "\n"
877
  return country, "", output
model.py CHANGED
The diff for this file is too large to render. See raw diff
 
mtdna_backend.py CHANGED
@@ -1,1145 +1,1005 @@
1
- import gradio as gr
2
- from collections import Counter
3
- import csv
4
- import os
5
- from functools import lru_cache
6
- #import app
7
- from mtdna_classifier import classify_sample_location
8
- import data_preprocess, model, pipeline
9
- import subprocess
10
- import json
11
- import pandas as pd
12
- import io
13
- import re
14
- import tempfile
15
- import gspread
16
- from oauth2client.service_account import ServiceAccountCredentials
17
- from io import StringIO
18
- import hashlib
19
- import threading
20
-
21
- # @lru_cache(maxsize=3600)
22
- # def classify_sample_location_cached(accession):
23
- # return classify_sample_location(accession)
24
-
25
- #@lru_cache(maxsize=3600)
26
- async def pipeline_classify_sample_location_cached(accession,stop_flag=None, save_df=None):
27
- print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
28
- print("len of save df: ", len(save_df))
29
- return await pipeline.pipeline_with_gemini([accession],stop_flag=stop_flag, save_df=save_df)
30
-
31
- # Count and suggest final location
32
- # def compute_final_suggested_location(rows):
33
- # candidates = [
34
- # row.get("Predicted Location", "").strip()
35
- # for row in rows
36
- # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
37
- # ] + [
38
- # row.get("Inferred Region", "").strip()
39
- # for row in rows
40
- # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
41
- # ]
42
-
43
- # if not candidates:
44
- # return Counter(), ("Unknown", 0)
45
- # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
46
- # tokens = []
47
- # for item in candidates:
48
- # # Split by comma, whitespace, and newlines
49
- # parts = re.split(r'[\s,]+', item)
50
- # tokens.extend(parts)
51
-
52
- # # Step 2: Clean and normalize tokens
53
- # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
54
-
55
- # # Step 3: Count
56
- # counts = Counter(tokens)
57
-
58
- # # Step 4: Get most common
59
- # top_location, count = counts.most_common(1)[0]
60
- # return counts, (top_location, count)
61
-
62
- # Store feedback (with required fields)
63
-
64
- def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
65
- if not answer1.strip() or not answer2.strip():
66
- return "⚠️ Please answer both questions before submitting."
67
-
68
- try:
69
- # ✅ Step: Load credentials from Hugging Face secret
70
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
71
- scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
72
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
73
-
74
- # Connect to Google Sheet
75
- client = gspread.authorize(creds)
76
- sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
77
-
78
- # Append feedback
79
- sheet.append_row([accession, answer1, answer2, contact])
80
- return "✅ Feedback submitted. Thank you!"
81
-
82
- except Exception as e:
83
- return f"❌ Error submitting feedback: {e}"
84
-
85
- import re
86
-
87
- ACCESSION_REGEX = re.compile(r'^[A-Z]{1,4}_?\d{6}(\.\d+)?$')
88
-
89
- def is_valid_accession(acc):
90
- return bool(ACCESSION_REGEX.match(acc))
91
-
92
- # helper function to extract accessions
93
- def extract_accessions_from_input(file=None, raw_text=""):
94
- print(f"RAW TEXT RECEIVED: {raw_text}")
95
- accessions, invalid_accessions = [], []
96
- seen = set()
97
- if file:
98
- try:
99
- if file.name.endswith(".csv"):
100
- df = pd.read_csv(file)
101
- elif file.name.endswith(".xlsx"):
102
- df = pd.read_excel(file)
103
- else:
104
- return [], "Unsupported file format. Please upload CSV or Excel."
105
- for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
106
- if acc not in seen:
107
- if is_valid_accession(acc):
108
- accessions.append(acc)
109
- seen.add(acc)
110
- else:
111
- invalid_accessions.append(acc)
112
-
113
- except Exception as e:
114
- return [],[], f"Failed to read file: {e}"
115
-
116
- if raw_text:
117
- try:
118
- text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
119
- for acc in text_ids:
120
- if acc not in seen:
121
- if is_valid_accession(acc):
122
- accessions.append(acc)
123
- seen.add(acc)
124
- else:
125
- invalid_accessions.append(acc)
126
- except Exception as e:
127
- return [],[], f"Failed to read file: {e}"
128
-
129
- return list(accessions), list(invalid_accessions), None
130
- # ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
131
- def get_incomplete_accessions(file_path):
132
- df = pd.read_excel(file_path)
133
-
134
- incomplete_accessions = []
135
- for _, row in df.iterrows():
136
- sample_id = str(row.get("Sample ID", "")).strip()
137
-
138
- # Skip if no sample ID
139
- if not sample_id:
140
- continue
141
-
142
- # Drop the Sample ID and check if the rest is empty
143
- other_cols = row.drop(labels=["Sample ID"], errors="ignore")
144
- if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
145
- # Extract the accession number from the sample ID using regex
146
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
147
- if match:
148
- incomplete_accessions.append(match.group(0))
149
- print(len(incomplete_accessions))
150
- return incomplete_accessions
151
-
152
- # GOOGLE_SHEET_NAME = "known_samples"
153
- # USAGE_DRIVE_FILENAME = "user_usage_log.json"
154
- def truncate_cell(value, max_len=49000):
155
- """Ensure cell content never exceeds Google Sheets 50k char limit."""
156
- if not isinstance(value, str):
157
- value = str(value)
158
- return value[:max_len] + ("... [TRUNCATED]" if len(value) > max_len else "")
159
-
160
-
161
- async def summarize_results(accession, stop_flag=None):
162
- # Early bail
163
- if stop_flag is not None and stop_flag.value:
164
- print(f"🛑 Skipping {accession} before starting.")
165
- return []
166
- # try cache first
167
- cached = check_known_output(accession)
168
- if cached:
169
- print(f" Using cached result for {accession}")
170
- return [[
171
- cached["Sample ID"] or "unknown",
172
- cached["Predicted Country"] or "unknown",
173
- cached["Country Explanation"] or "unknown",
174
- cached["Predicted Sample Type"] or "unknown",
175
- cached["Sample Type Explanation"] or "unknown",
176
- cached["Sources"] or "No Links",
177
- cached["Time cost"]
178
- ]]
179
- # only run when nothing in the cache
180
- try:
181
- print("try gemini pipeline: ",accession)
182
- # Load credentials from Hugging Face secret
183
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
184
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
185
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
186
- client = gspread.authorize(creds)
187
-
188
- spreadsheet = client.open("known_samples")
189
- sheet = spreadsheet.sheet1
190
-
191
- data = sheet.get_all_values()
192
- if not data:
193
- print("⚠️ Google Sheet 'known_samples' is empty.")
194
- return None
195
-
196
- save_df = pd.DataFrame(data[1:], columns=data[0])
197
- print("before pipeline, len of save df: ", len(save_df))
198
- outputs = await pipeline_classify_sample_location_cached(accession, stop_flag, save_df)
199
- if stop_flag is not None and stop_flag.value:
200
- print(f"🛑 Skipped {accession} mid-pipeline.")
201
- return []
202
- # outputs = {'KU131308': {'isolate':'BRU18',
203
- # 'country': {'brunei': ['ncbi',
204
- # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
205
- # 'sample_type': {'modern':
206
- # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
207
- # 'query_cost': 9.754999999999999e-05,
208
- # 'time_cost': '24.776 seconds',
209
- # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
210
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
211
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
212
- except Exception as e:
213
- return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
214
-
215
- if accession not in outputs:
216
- print("no accession in output ", accession)
217
- return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
218
-
219
- row_score = []
220
- rows = []
221
- save_rows = []
222
- for key in outputs:
223
- pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
224
- for section, results in outputs[key].items():
225
- if section == "country" or section =="sample_type":
226
- pred_output = []#"\n".join(list(results.keys()))
227
- output_explanation = ""
228
- for result, content in results.items():
229
- if len(result) == 0: result = "unknown"
230
- if len(content) == 0: output_explanation = "unknown"
231
- else:
232
- output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
233
- pred_output.append(result)
234
- pred_output = "\n".join(pred_output)
235
- if section == "country":
236
- pred_country, country_explanation = pred_output, output_explanation
237
- elif section == "sample_type":
238
- pred_sample, sample_explanation = pred_output, output_explanation
239
- if outputs[key]["isolate"].lower()!="unknown":
240
- label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
241
- else: label = key
242
- if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
243
- # row = {
244
- # "Sample ID": label or "unknown",
245
- # "Predicted Country": pred_country or "unknown",
246
- # "Country Explanation": country_explanation or "unknown",
247
- # "Predicted Sample Type":pred_sample or "unknown",
248
- # "Sample Type Explanation":sample_explanation or "unknown",
249
- # "Sources": "\n".join(outputs[key]["source"]) or "No Links",
250
- # "Time cost": outputs[key]["time_cost"]
251
- # }
252
- row = {
253
- "Sample ID": truncate_cell(label or "unknown"),
254
- "Predicted Country": truncate_cell(pred_country or "unknown"),
255
- "Country Explanation": truncate_cell(country_explanation or "unknown"),
256
- "Predicted Sample Type": truncate_cell(pred_sample or "unknown"),
257
- "Sample Type Explanation": truncate_cell(sample_explanation or "unknown"),
258
- "Sources": truncate_cell("\n".join(outputs[key]["source"]) or "No Links"),
259
- "Time cost": truncate_cell(outputs[key]["time_cost"])
260
- }
261
- #row_score.append(row)
262
- rows.append(list(row.values()))
263
-
264
- # save_row = {
265
- # "Sample ID": label or "unknown",
266
- # "Predicted Country": pred_country or "unknown",
267
- # "Country Explanation": country_explanation or "unknown",
268
- # "Predicted Sample Type":pred_sample or "unknown",
269
- # "Sample Type Explanation":sample_explanation or "unknown",
270
- # "Sources": "\n".join(outputs[key]["source"]) or "No Links",
271
- # "Query_cost": outputs[key]["query_cost"] or "",
272
- # "Time cost": outputs[key]["time_cost"] or "",
273
- # "file_chunk":outputs[key]["file_chunk"] or "",
274
- # "file_all_output":outputs[key]["file_all_output"] or ""
275
- # }
276
- save_row = {
277
- "Sample ID": truncate_cell(label or "unknown"),
278
- "Predicted Country": truncate_cell(pred_country or "unknown"),
279
- "Country Explanation": truncate_cell(country_explanation or "unknown"),
280
- "Predicted Sample Type": truncate_cell(pred_sample or "unknown"),
281
- "Sample Type Explanation": truncate_cell(sample_explanation or "unknown"),
282
- "Sources": truncate_cell("\n".join(outputs[key]["source"]) or "No Links"),
283
- "Query_cost": outputs[key]["query_cost"] or "",
284
- "Time cost": outputs[key]["time_cost"] or "",
285
- "file_chunk": truncate_cell(outputs[key]["file_chunk"] or ""),
286
- "file_all_output": truncate_cell(outputs[key]["file_all_output"] or "")
287
- }
288
-
289
- #row_score.append(row)
290
- save_rows.append(list(save_row.values()))
291
-
292
- # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
293
- # summary_lines = [f"### 🧭 Location Summary:\n"]
294
- # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
295
- # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
296
- # summary = "\n".join(summary_lines)
297
-
298
- # save the new running sample to known excel file
299
- # try:
300
- # df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
301
- # if os.path.exists(KNOWN_OUTPUT_PATH):
302
- # df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
303
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
304
- # else:
305
- # df_combined = df_new
306
- # df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
307
- # except Exception as e:
308
- # print(f"⚠️ Failed to save known output: {e}")
309
- # try:
310
- # df_new = pd.DataFrame(save_rows, columns=[
311
- # "Sample ID", "Predicted Country", "Country Explanation",
312
- # "Predicted Sample Type", "Sample Type Explanation",
313
- # "Sources", "Query_cost", "Time cost"
314
- # ])
315
-
316
- # # Google Sheets API setup
317
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
318
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
319
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
320
- # client = gspread.authorize(creds)
321
-
322
- # # Open the known_samples sheet
323
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
324
- # sheet = spreadsheet.sheet1
325
-
326
- # # ✅ Read old data
327
- # existing_data = sheet.get_all_values()
328
- # if existing_data:
329
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
330
- # else:
331
- # df_old = pd.DataFrame(columns=df_new.columns)
332
-
333
- # # Combine and remove duplicates
334
- # df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
335
-
336
- # # Clear and write back
337
- # sheet.clear()
338
- # sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
339
-
340
- # except Exception as e:
341
- # print(f"⚠️ Failed to save known output to Google Sheets: {e}")
342
- # try:
343
- # # Prepare as DataFrame
344
- # df_new = pd.DataFrame(save_rows, columns=[
345
- # "Sample ID", "Predicted Country", "Country Explanation",
346
- # "Predicted Sample Type", "Sample Type Explanation",
347
- # "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
348
- # ])
349
-
350
- # # ✅ Setup Google Sheets
351
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
352
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
353
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
354
- # client = gspread.authorize(creds)
355
- # spreadsheet = client.open("known_samples")
356
- # sheet = spreadsheet.sheet1
357
-
358
- # # ✅ Read existing data
359
- # existing_data = sheet.get_all_values()
360
- # headers = existing_data[0]
361
-
362
- # if existing_data:
363
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
364
-
365
- # else:
366
-
367
- # df_old = pd.DataFrame(columns=[
368
- # "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
369
- # "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
370
- # "Query_cost", "Sample Type Explanation", "Sources", "Time cost", "file_chunk", "file_all_output"
371
- # ])
372
-
373
-
374
- # # ✅ Index by Sample ID
375
- # df_old.set_index("Sample ID", inplace=True)
376
- # df_new.set_index("Sample ID", inplace=True)
377
-
378
- # # Update only matching fields
379
- # update_columns = [
380
- # "Predicted Country", "Predicted Sample Type", "Country Explanation",
381
- # "Sample Type Explanation", "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
382
- # ]
383
- # for idx, row in df_new.iterrows():
384
- # if idx not in df_old.index:
385
- # df_old.loc[idx] = "" # new row, fill empty first
386
- # for col in update_columns:
387
- # if pd.notna(row[col]) and row[col] != "":
388
- # df_old.at[idx, col] = row[col]
389
-
390
- # # ✅ Reset and write back
391
- # EXPECTED_COLUMNS = [
392
- # "Sample ID", "Predicted Country", "Country Explanation",
393
- # "Predicted Sample Type", "Sample Type Explanation",
394
- # "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
395
- # ]
396
-
397
- # # Force schema
398
- # for col in EXPECTED_COLUMNS:
399
- # if col not in df_old.columns:
400
- # df_old[col] = ""
401
-
402
- # df_old = df_old[EXPECTED_COLUMNS].reset_index(inplace=True) # reorder + drop unexpected
403
-
404
- # # ✅ Safe update
405
- # sheet.clear()
406
- # sheet.update([EXPECTED_COLUMNS] + df_old.astype(str).values.tolist())
407
-
408
- # # df_old.reset_index(inplace=True)
409
- # # sheet.clear()
410
- # # sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
411
- # print(" Match results saved to known_samples.")
412
-
413
- # except Exception as e:
414
- # print(f"❌ Failed to update known_samples: {e}")
415
- try:
416
- # Prepare as DataFrame
417
- df_new = pd.DataFrame(save_rows, columns=[
418
- "Sample ID", "Predicted Country", "Country Explanation",
419
- "Predicted Sample Type", "Sample Type Explanation",
420
- "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
421
- ])
422
-
423
- # Setup Google Sheets
424
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
425
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
426
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
427
- client = gspread.authorize(creds)
428
- spreadsheet = client.open("known_samples")
429
- sheet = spreadsheet.sheet1
430
-
431
- # Load existing data
432
- existing_data = sheet.get_all_values()
433
- headers = existing_data[0]
434
- existing_df = pd.DataFrame(existing_data[1:], columns=headers)
435
-
436
- # Build lookup: Sample ID → row index
437
- id_to_row = {sid: i+2 for i, sid in enumerate(existing_df["Sample ID"])}
438
- # +2 because gspread is 1-based and row 1 is headers
439
-
440
- for _, row in df_new.iterrows():
441
- sid = row["Sample ID"]
442
-
443
- # Row values in correct schema order
444
- # row_values = [
445
- # row.get("Sample ID", ""),
446
- # row.get("Predicted Country", ""),
447
- # row.get("Country Explanation", ""),
448
- # row.get("Predicted Sample Type", ""),
449
- # row.get("Sample Type Explanation", ""),
450
- # row.get("Sources", ""),
451
- # row.get("Query_cost", ""),
452
- # row.get("Time cost", ""),
453
- # row.get("file_chunk", ""),
454
- # row.get("file_all_output", "")
455
- # ]
456
- row_values = [
457
- truncate_cell(row.get("Sample ID", "")),
458
- truncate_cell(row.get("Predicted Country", "")),
459
- truncate_cell(row.get("Country Explanation", "")),
460
- truncate_cell(row.get("Predicted Sample Type", "")),
461
- truncate_cell(row.get("Sample Type Explanation", "")),
462
- truncate_cell(row.get("Sources", "")),
463
- truncate_cell(row.get("Query_cost", "")),
464
- truncate_cell(row.get("Time cost", "")),
465
- truncate_cell(row.get("file_chunk", "")),
466
- truncate_cell(row.get("file_all_output", ""))
467
- ]
468
-
469
-
470
- if sid in id_to_row:
471
- # Update existing row
472
- sheet.update(f"A{id_to_row[sid]}:J{id_to_row[sid]}", [row_values])
473
- else:
474
- # Append new row
475
- sheet.append_row(row_values)
476
-
477
- print("✅ Match results safely saved to known_samples.")
478
-
479
- except Exception as e:
480
- print(f"❌ Failed to update known_samples: {e}")
481
-
482
-
483
- return rows#, summary, labelAncient_Modern, explain_label
484
-
485
- # save the batch input in excel file
486
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
487
- # with pd.ExcelWriter(filename) as writer:
488
- # # Save table
489
- # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
490
- # df.to_excel(writer, sheet_name="Detailed Results", index=False)
491
- # try:
492
- # df_old = pd.read_excel(filename)
493
- # except:
494
- # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
495
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
496
- # # if os.path.exists(filename):
497
- # # df_old = pd.read_excel(filename)
498
- # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
499
- # # else:
500
- # # df_combined = df_new
501
- # df_combined.to_excel(filename, index=False)
502
- # # # Save summary
503
- # # summary_df = pd.DataFrame({"Summary": [summary_text]})
504
- # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
505
-
506
- # # # Save flag
507
- # # flag_df = pd.DataFrame({"Flag": [flag_text]})
508
- # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
509
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
510
- # df_new = pd.DataFrame(all_rows, columns=[
511
- # "Sample ID", "Predicted Country", "Country Explanation",
512
- # "Predicted Sample Type", "Sample Type Explanation",
513
- # "Sources", "Time cost"
514
- # ])
515
-
516
- # try:
517
- # if os.path.exists(filename):
518
- # df_old = pd.read_excel(filename)
519
- # else:
520
- # df_old = pd.DataFrame(columns=df_new.columns)
521
- # except Exception as e:
522
- # print(f"⚠️ Warning reading old Excel file: {e}")
523
- # df_old = pd.DataFrame(columns=df_new.columns)
524
-
525
- # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
526
- # df_old.set_index("Sample ID", inplace=True)
527
- # df_new.set_index("Sample ID", inplace=True)
528
-
529
- # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
530
-
531
- # df_combined = df_old.reset_index()
532
-
533
- # try:
534
- # df_combined.to_excel(filename, index=False)
535
- # except Exception as e:
536
- # print(f" Failed to write Excel file {filename}: {e}")
537
- def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
538
- df_new = pd.DataFrame(all_rows, columns=[
539
- "Sample ID", "Predicted Country", "Country Explanation",
540
- "Predicted Sample Type", "Sample Type Explanation",
541
- "Sources", "Time cost"
542
- ])
543
-
544
- if is_resume and os.path.exists(filename):
545
- try:
546
- df_old = pd.read_excel(filename)
547
- except Exception as e:
548
- print(f"⚠️ Warning reading old Excel file: {e}")
549
- df_old = pd.DataFrame(columns=df_new.columns)
550
-
551
- # Set index and update existing rows
552
- df_old.set_index("Sample ID", inplace=True)
553
- df_new.set_index("Sample ID", inplace=True)
554
- df_old.update(df_new)
555
-
556
- df_combined = df_old.reset_index()
557
- else:
558
- # If not resuming or file doesn't exist, just use new rows
559
- df_combined = df_new
560
-
561
- try:
562
- df_combined.to_excel(filename, index=False)
563
- except Exception as e:
564
- print(f"❌ Failed to write Excel file {filename}: {e}")
565
-
566
-
567
- # save the batch input in JSON file
568
- def save_to_json(all_rows, summary_text, flag_text, filename):
569
- output_dict = {
570
- "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
571
- # "Summary_Text": summary_text,
572
- # "Ancient_Modern_Flag": flag_text
573
- }
574
-
575
- # If all_rows is a DataFrame, convert it
576
- if isinstance(all_rows, pd.DataFrame):
577
- output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
578
-
579
- with open(filename, "w") as external_file:
580
- json.dump(output_dict, external_file, indent=2)
581
-
582
- # save the batch input in Text file
583
- def save_to_txt(all_rows, summary_text, flag_text, filename):
584
- if isinstance(all_rows, pd.DataFrame):
585
- detailed_results = all_rows.to_dict(orient="records")
586
- output = ""
587
- #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
588
- output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
589
- for r in detailed_results:
590
- output += ",".join([str(v) for v in r.values()]) + "\n\n"
591
- with open(filename, "w") as f:
592
- f.write("=== Detailed Results ===\n")
593
- f.write(output + "\n")
594
-
595
- # f.write("\n=== Summary ===\n")
596
- # f.write(summary_text + "\n")
597
-
598
- # f.write("\n=== Ancient/Modern Flag ===\n")
599
- # f.write(flag_text + "\n")
600
-
601
- def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
602
- tmp_dir = tempfile.mkdtemp()
603
-
604
- #html_table = all_rows.value # assuming this is stored somewhere
605
-
606
- # Parse back to DataFrame
607
- #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
608
- all_rows = pd.read_html(StringIO(all_rows))[0]
609
- print(all_rows)
610
-
611
- if output_type == "Excel":
612
- file_path = f"{tmp_dir}/batch_output.xlsx"
613
- save_to_excel(all_rows, summary_text, flag_text, file_path)
614
- elif output_type == "JSON":
615
- file_path = f"{tmp_dir}/batch_output.json"
616
- save_to_json(all_rows, summary_text, flag_text, file_path)
617
- print("Done with JSON")
618
- elif output_type == "TXT":
619
- file_path = f"{tmp_dir}/batch_output.txt"
620
- save_to_txt(all_rows, summary_text, flag_text, file_path)
621
- else:
622
- return gr.update(visible=False) # invalid option
623
-
624
- return gr.update(value=file_path, visible=True)
625
- # save cost by checking the known outputs
626
-
627
- # def check_known_output(accession):
628
- # if not os.path.exists(KNOWN_OUTPUT_PATH):
629
- # return None
630
-
631
- # try:
632
- # df = pd.read_excel(KNOWN_OUTPUT_PATH)
633
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
634
- # if match:
635
- # accession = match.group(0)
636
-
637
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
638
- # if not matched.empty:
639
- # return matched.iloc[0].to_dict() # Return the cached row
640
- # except Exception as e:
641
- # print(f"⚠️ Failed to load known samples: {e}")
642
- # return None
643
-
644
- # def check_known_output(accession):
645
- # try:
646
- # # Load credentials from Hugging Face secret
647
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
648
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
649
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
650
- # client = gspread.authorize(creds)
651
-
652
- # # ✅ Open the known_samples sheet
653
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
654
- # sheet = spreadsheet.sheet1
655
-
656
- # # ✅ Read all rows
657
- # data = sheet.get_all_values()
658
- # if not data:
659
- # return None
660
-
661
- # df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
662
-
663
- # # ✅ Normalize accession pattern
664
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
665
- # if match:
666
- # accession = match.group(0)
667
-
668
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
669
- # if not matched.empty:
670
- # return matched.iloc[0].to_dict()
671
-
672
- # except Exception as e:
673
- # print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
674
- # return None
675
- # def check_known_output(accession):
676
- # print("inside check known output function")
677
- # try:
678
- # # Load credentials from Hugging Face secret
679
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
680
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
681
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
682
- # client = gspread.authorize(creds)
683
-
684
- # spreadsheet = client.open("known_samples")
685
- # sheet = spreadsheet.sheet1
686
-
687
- # data = sheet.get_all_values()
688
- # if not data:
689
- # print("⚠️ Google Sheet 'known_samples' is empty.")
690
- # return None
691
-
692
- # df = pd.DataFrame(data[1:], columns=data[0])
693
- # if "Sample ID" not in df.columns:
694
- # print("❌ Column 'Sample ID' not found in Google Sheet.")
695
- # return None
696
-
697
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
698
- # if match:
699
- # accession = match.group(0)
700
-
701
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
702
- # if not matched.empty:
703
- # #return matched.iloc[0].to_dict()
704
- # row = matched.iloc[0]
705
- # country = row.get("Predicted Country", "").strip().lower()
706
- # sample_type = row.get("Predicted Sample Type", "").strip().lower()
707
-
708
- # if country and country != "unknown" and sample_type and sample_type != "unknown":
709
- # return row.to_dict()
710
- # else:
711
- # print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
712
- # return None
713
- # else:
714
- # print(f"🔍 Accession {accession} not found in known_samples.")
715
- # return None
716
-
717
- # except Exception as e:
718
- # import traceback
719
- # print("❌ Exception occurred during check_known_output:")
720
- # traceback.print_exc()
721
- # return None
722
-
723
- import os
724
- import re
725
- import json
726
- import time
727
- import gspread
728
- import pandas as pd
729
- from oauth2client.service_account import ServiceAccountCredentials
730
- from gspread.exceptions import APIError
731
-
732
- # --- Global cache ---
733
- _known_samples_cache = None
734
-
735
- def load_known_samples():
736
- """Load the Google Sheet 'known_samples' into a Pandas DataFrame and cache it."""
737
- global _known_samples_cache
738
- try:
739
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
740
- scope = [
741
- 'https://spreadsheets.google.com/feeds',
742
- 'https://www.googleapis.com/auth/drive'
743
- ]
744
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
745
- client = gspread.authorize(creds)
746
-
747
- sheet = client.open("known_samples").sheet1
748
- data = sheet.get_all_values()
749
-
750
- if not data:
751
- print("⚠️ Google Sheet 'known_samples' is empty.")
752
- _known_samples_cache = pd.DataFrame()
753
- else:
754
- _known_samples_cache = pd.DataFrame(data[1:], columns=data[0])
755
- print(f"✅ Cached {_known_samples_cache.shape[0]} rows from known_samples")
756
-
757
- except APIError as e:
758
- print(f"❌ APIError while loading known_samples: {e}")
759
- _known_samples_cache = pd.DataFrame()
760
- except Exception as e:
761
- import traceback
762
- print("❌ Exception occurred while loading known_samples:")
763
- traceback.print_exc()
764
- _known_samples_cache = pd.DataFrame()
765
-
766
- def check_known_output(accession):
767
- """Check if an accession exists in the cached 'known_samples' sheet."""
768
- global _known_samples_cache
769
- print("inside check known output function")
770
-
771
- try:
772
- # Load cache if not already loaded
773
- if _known_samples_cache is None:
774
- load_known_samples()
775
-
776
- if _known_samples_cache.empty:
777
- print("⚠️ No cached data available.")
778
- return None
779
-
780
- # Extract proper accession format (e.g. AB12345)
781
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
782
- if match:
783
- accession = match.group(0)
784
-
785
- matched = _known_samples_cache[
786
- _known_samples_cache["Sample ID"].str.contains(accession, case=False, na=False)
787
- ]
788
-
789
- if not matched.empty:
790
- row = matched.iloc[0]
791
- country = row.get("Predicted Country", "").strip().lower()
792
- sample_type = row.get("Predicted Sample Type", "").strip().lower()
793
-
794
- if country and country != "unknown" and sample_type and sample_type != "unknown":
795
- print(f"🎯 Found {accession} in cache")
796
- return row.to_dict()
797
- else:
798
- print(f"⚠️ Accession {accession} found but country/sample_type unknown or empty.")
799
- return None
800
- else:
801
- print(f"🔍 Accession {accession} not found in cache.")
802
- return None
803
-
804
- except Exception as e:
805
- import traceback
806
- print("❌ Exception occurred during check_known_output:")
807
- traceback.print_exc()
808
- return None
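A usage sketch for the cached lookup above, using the signatures visible in this version of the file; GCP_CREDS_JSON must hold service-account JSON for the Sheets call inside load_known_samples to succeed:

from mtdna_backend import load_known_samples, check_known_output

load_known_samples()                  # one Sheets read, cached in _known_samples_cache
row = check_known_output("KU131308")  # dict of the cached row, or None on a miss
if row:
    print(row["Predicted Country"], row["Predicted Sample Type"])
else:
    print("not cached; fall back to the full Gemini pipeline")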
809
-
810
-
811
-
812
- def hash_user_id(user_input):
813
- return hashlib.sha256(user_input.encode()).hexdigest()
814
-
815
- # Load and save usage count
816
-
817
- # def load_user_usage():
818
- # if not os.path.exists(USER_USAGE_TRACK_FILE):
819
- # return {}
820
-
821
- # try:
822
- # with open(USER_USAGE_TRACK_FILE, "r") as f:
823
- # content = f.read().strip()
824
- # if not content:
825
- # return {} # file is empty
826
- # return json.loads(content)
827
- # except (json.JSONDecodeError, ValueError):
828
- # print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
829
- # return {} # fallback to empty dict
830
- # def load_user_usage():
831
- # try:
832
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
833
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
834
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
835
- # client = gspread.authorize(creds)
836
-
837
- # sheet = client.open("user_usage_log").sheet1
838
- # data = sheet.get_all_records() # Assumes columns: email, usage_count
839
-
840
- # usage = {}
841
- # for row in data:
842
- # email = row.get("email", "").strip().lower()
843
- # count = int(row.get("usage_count", 0))
844
- # if email:
845
- # usage[email] = count
846
- # return usage
847
- # except Exception as e:
848
- # print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
849
- # return {}
850
- # def load_user_usage():
851
- # try:
852
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
853
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
854
-
855
- # found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
856
- # if not found:
857
- # return {} # not found, start fresh
858
-
859
- # #file_id = found[0]["id"]
860
- # file_id = found
861
- # content = pipeline.download_drive_file_content(file_id)
862
- # return json.loads(content.strip()) if content.strip() else {}
863
-
864
- # except Exception as e:
865
- # print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
866
- # return {}
867
- def load_user_usage():
868
- try:
869
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
870
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
871
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
872
- client = gspread.authorize(creds)
873
-
874
- sheet = client.open("user_usage_log").sheet1
875
- data = sheet.get_all_values()
876
- print("data: ", data)
877
- print("🧪 Raw header row from sheet:", data[0])
878
- print("🧪 Character codes in each header:")
879
- for h in data[0]:
880
- print([ord(c) for c in h])
881
-
882
- if not data or len(data) < 2:
883
- print("⚠️ Sheet is empty or missing rows.")
884
- return {}
885
-
886
- headers = [h.strip().lower() for h in data[0]]
887
- if "email" not in headers or "usage_count" not in headers:
888
- print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
889
- return {}
890
-
891
- permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
892
- df = pd.DataFrame(data[1:], columns=headers)
893
-
894
- usage = {}
895
- permitted = {}
896
- for _, row in df.iterrows():
897
- email = row.get("email", "").strip().lower()
898
- try:
899
- #count = int(row.get("usage_count", 0))
900
- try:
901
- count = int(float(row.get("usage_count", 0)))
902
- except Exception:
903
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
904
- count = 0
905
-
906
- if email:
907
- usage[email] = count
908
- if permitted_index is not None:
909
- try:
910
- permitted_count = int(float(row.get("permitted_samples", 50)))
911
- permitted[email] = permitted_count
912
- except:
913
- permitted[email] = 50
914
-
915
- except ValueError:
916
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
917
- return usage, permitted
918
-
919
- except Exception as e:
920
- print(f"❌ Error in load_user_usage: {e}")
921
- return {}, {}
922
-
923
-
924
-
925
- # def save_user_usage(usage):
926
- # with open(USER_USAGE_TRACK_FILE, "w") as f:
927
- # json.dump(usage, f, indent=2)
928
-
929
- # def save_user_usage(usage_dict):
930
- # try:
931
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
932
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
933
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
934
- # client = gspread.authorize(creds)
935
-
936
- # sheet = client.open("user_usage_log").sheet1
937
- # sheet.clear() # clear old contents first
938
-
939
- # # Write header + rows
940
- # rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
941
- # sheet.update(rows)
942
- # except Exception as e:
943
- # print(f"❌ Failed to save user usage to Google Sheets: {e}")
944
- # def save_user_usage(usage_dict):
945
- # try:
946
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
947
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
948
-
949
- # import tempfile
950
- # tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
951
- # print("💾 Saving this usage dict:", usage_dict)
952
- # with open(tmp_path, "w") as f:
953
- # json.dump(usage_dict, f, indent=2)
954
-
955
- # pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
956
-
957
- # except Exception as e:
958
- # print(f" Failed to save user_usage_log.json to Google Drive: {e}")
959
- # def save_user_usage(usage_dict):
960
- # try:
961
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
962
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
963
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
964
- # client = gspread.authorize(creds)
965
-
966
- # spreadsheet = client.open("user_usage_log")
967
- # sheet = spreadsheet.sheet1
968
-
969
- # # Step 1: Convert new usage to DataFrame
970
- # df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
971
- # df_new["email"] = df_new["email"].str.strip().str.lower()
972
-
973
- # # Step 2: Load existing data
974
- # existing_data = sheet.get_all_values()
975
- # print("🧪 Sheet existing_data:", existing_data)
976
-
977
- # # Try to load old data
978
- # if existing_data and len(existing_data[0]) >= 1:
979
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
980
-
981
- # # Fix missing columns
982
- # if "email" not in df_old.columns:
983
- # df_old["email"] = ""
984
- # if "usage_count" not in df_old.columns:
985
- # df_old["usage_count"] = 0
986
-
987
- # df_old["email"] = df_old["email"].str.strip().str.lower()
988
- # df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
989
- # else:
990
- # df_old = pd.DataFrame(columns=["email", "usage_count"])
991
-
992
- # # Step 3: Merge
993
- # df_combined = pd.concat([df_old, df_new], ignore_index=True)
994
- # df_combined = df_combined.groupby("email", as_index=False).sum()
995
-
996
- # # Step 4: Write back
997
- # sheet.clear()
998
- # sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
999
- # print("✅ Saved user usage to user_usage_log sheet.")
1000
-
1001
- # except Exception as e:
1002
- # print(f"❌ Failed to save user usage to Google Sheets: {e}")
1003
- def save_user_usage(usage_dict):
1004
- try:
1005
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
1006
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
1007
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
1008
- client = gspread.authorize(creds)
1009
-
1010
- spreadsheet = client.open("user_usage_log")
1011
- sheet = spreadsheet.sheet1
1012
-
1013
- # Build new df
1014
- df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
1015
- df_new["email"] = df_new["email"].str.strip().str.lower()
1016
- df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
1017
-
1018
- # Read existing data
1019
- existing_data = sheet.get_all_values()
1020
- if existing_data and len(existing_data[0]) >= 2:
1021
- df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
1022
- df_old["email"] = df_old["email"].str.strip().str.lower()
1023
- df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
1024
- else:
1025
- df_old = pd.DataFrame(columns=["email", "usage_count"])
1026
-
1027
- # ✅ Overwrite specific emails only
1028
- df_old = df_old.set_index("email")
1029
- for email, count in usage_dict.items():
1030
- email = email.strip().lower()
1031
- df_old.loc[email, "usage_count"] = count
1032
- df_old = df_old.reset_index()
1033
-
1034
- # Save
1035
- sheet.clear()
1036
- sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
1037
- print("✅ Saved user usage to user_usage_log sheet.")
1038
-
1039
- except Exception as e:
1040
- print(f"❌ Failed to save user usage to Google Sheets: {e}")
1041
-
1042
-
1043
-
1044
-
1045
- # def increment_usage(user_id, num_samples=1):
1046
- # usage = load_user_usage()
1047
- # if user_id not in usage:
1048
- # usage[user_id] = 0
1049
- # usage[user_id] += num_samples
1050
- # save_user_usage(usage)
1051
- # return usage[user_id]
1052
- # def increment_usage(email: str, count: int):
1053
- # usage = load_user_usage()
1054
- # email_key = email.strip().lower()
1055
- # usage[email_key] = usage.get(email_key, 0) + count
1056
- # save_user_usage(usage)
1057
- # return usage[email_key]
1058
- def increment_usage(email: str, count: int = 1):
1059
- usage, permitted = load_user_usage()
1060
- email_key = email.strip().lower()
1061
- #usage[email_key] = usage.get(email_key, 0) + count
1062
- current = usage.get(email_key, 0)
1063
- new_value = current + count
1064
- max_allowed = permitted.get(email_key) or 50
1065
- usage[email_key] = max(current, new_value) # ✅ Prevent overwrite with lower
1066
- print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
1067
- print("max allow is: ", max_allowed)
1068
- save_user_usage(usage)
1069
- return usage[email_key], max_allowed
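Note that increment_usage returns a (new_count, max_allowed) pair rather than a single integer, so callers should unpack both values; a minimal sketch with a made-up email:

from mtdna_backend import increment_usage

used, limit = increment_usage("user@example.com", 3)
if used >= limit:
    print(f"Quota reached: {used}/{limit} samples")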
1070
-
1071
-
1072
- # run the batch
1073
- def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
1074
- stop_flag=None, output_file_path=None,
1075
- limited_acc=50, yield_callback=None):
1076
- if user_email:
1077
- limited_acc += 10
1078
- accessions, error = extract_accessions_from_input(file, raw_text)
1079
- if error:
1080
- #return [], "", "", f"Error: {error}"
1081
- return [], f"Error: {error}", 0, "", ""
1082
- if resume_file:
1083
- accessions = get_incomplete_accessions(resume_file)
1084
- tmp_dir = tempfile.mkdtemp()
1085
- if not output_file_path:
1086
- if resume_file:
1087
- output_file_path = os.path.join(tmp_dir, resume_file)
1088
- else:
1089
- output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
1090
-
1091
- all_rows = []
1092
- # all_summaries = []
1093
- # all_flags = []
1094
- progress_lines = []
1095
- warning = ""
1096
- if len(accessions) > limited_acc:
1097
- accessions = accessions[:limited_acc]
1098
- warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
1099
- for i, acc in enumerate(accessions):
1100
- if stop_flag and stop_flag.value:
1101
- line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
1102
- progress_lines.append(line)
1103
- if yield_callback:
1104
- yield_callback(line)
1105
- print("🛑 User requested stop.")
1106
- break
1107
- print(f"[{i+1}/{len(accessions)}] Processing {acc}")
1108
- try:
1109
- # rows, summary, label, explain = summarize_results(acc)
1110
- rows = summarize_results(acc)
1111
- all_rows.extend(rows)
1112
- # all_summaries.append(f"**{acc}**\n{summary}")
1113
- # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
1114
- #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
1115
- save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
1116
- line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
1117
- progress_lines.append(line)
1118
- if yield_callback:
1119
- yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
1120
- except Exception as e:
1121
- print(f"❌ Failed to process {acc}: {e}")
1122
- continue
1123
- #all_summaries.append(f"**{acc}**: Failed - {e}")
1124
- #progress_lines.append(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
1125
- limited_acc -= 1
1126
- """for row in all_rows:
1127
- source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
1128
-
1129
- if source_column.startswith("http"): # Check if the source is a URL
1130
- # Wrap it with HTML anchor tags to make it clickable
1131
- row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
1132
- if not warning:
1133
- warning = f"You only have {limited_acc} left"
1134
- if user_email.strip():
1135
- user_hash = hash_user_id(user_email)
1136
- total_queries = increment_usage(user_hash, len(all_rows))
1137
- else:
1138
- total_queries = 0
1139
- yield_callback("✅ Finished!")
1140
-
1141
- # summary_text = "\n\n---\n\n".join(all_summaries)
1142
- # flag_text = "\n\n---\n\n".join(all_flags)
1143
- #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
1144
- #return all_rows, gr.update(visible=True), gr.update(visible=False)
1145
  return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
 
1
+ import gradio as gr
2
+ from collections import Counter
3
+ import csv
4
+ import os
5
+ from functools import lru_cache
6
+ #import app
7
+ from mtdna_classifier import classify_sample_location
8
+ import data_preprocess, model, pipeline
9
+ import subprocess
10
+ import json
11
+ import pandas as pd
12
+ import io
13
+ import re
14
+ import tempfile
15
+ import gspread
16
+ from oauth2client.service_account import ServiceAccountCredentials
17
+ from io import StringIO
18
+ import hashlib
19
+ import threading
20
+
21
+ # @lru_cache(maxsize=3600)
22
+ # def classify_sample_location_cached(accession):
23
+ # return classify_sample_location(accession)
24
+
25
+ #@lru_cache(maxsize=3600)
26
+ async def pipeline_classify_sample_location_cached(accession,stop_flag=None, save_df=None, niche_cases=None):
27
+ print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
28
+ print("len of save df: ", len(save_df))
29
+ if niche_cases: niche_cases=niche_cases.split(", ")
30
+ print("niche case in mtdna_backend.pipeline: ", niche_cases)
31
+ return await pipeline.pipeline_with_gemini([accession],stop_flag=stop_flag, save_df=save_df, niche_cases=niche_cases)
32
+
33
+ # Count and suggest final location
34
+ # def compute_final_suggested_location(rows):
35
+ # candidates = [
36
+ # row.get("Predicted Location", "").strip()
37
+ # for row in rows
38
+ # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
39
+ # ] + [
40
+ # row.get("Inferred Region", "").strip()
41
+ # for row in rows
42
+ # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
43
+ # ]
44
+
45
+ # if not candidates:
46
+ # return Counter(), ("Unknown", 0)
47
+ # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
48
+ # tokens = []
49
+ # for item in candidates:
50
+ # # Split by comma, whitespace, and newlines
51
+ # parts = re.split(r'[\s,]+', item)
52
+ # tokens.extend(parts)
53
+
54
+ # # Step 2: Clean and normalize tokens
55
+ # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
56
+
57
+ # # Step 3: Count
58
+ # counts = Counter(tokens)
59
+
60
+ # # Step 4: Get most common
61
+ # top_location, count = counts.most_common(1)[0]
62
+ # return counts, (top_location, count)
63
+
64
+ # Store feedback (with required fields)
65
+
66
+ def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
67
+ if not answer1.strip() or not answer2.strip():
68
+ return "⚠️ Please answer both questions before submitting."
69
+
70
+ try:
71
+ # Step: Load credentials from Hugging Face secret
72
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
73
+ scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
74
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
75
+
76
+ # Connect to Google Sheet
77
+ client = gspread.authorize(creds)
78
+ sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
79
+
80
+ # Append feedback
81
+ sheet.append_row([accession, answer1, answer2, contact])
82
+ return "✅ Feedback submitted. Thank you!"
83
+
84
+ except Exception as e:
85
+ return f"❌ Error submitting feedback: {e}"
86
+
87
+ import re
88
+
89
+ ACCESSION_REGEX = re.compile(r'^[A-Z]{1,4}_?\d{6}(\.\d+)?$')
90
+
91
+ def is_valid_accession(acc):
92
+ return bool(ACCESSION_REGEX.match(acc))
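A few examples of what ACCESSION_REGEX accepts, based on the pattern above (1-4 uppercase letters, an optional underscore, exactly six digits, and an optional version suffix):

from mtdna_backend import is_valid_accession

for acc in ["KU131308", "NC_012920.1", "MT12345", "ku131308"]:
    print(acc, is_valid_accession(acc))
# KU131308     True   (2 letters + 6 digits)
# NC_012920.1  True   (underscore and version suffix allowed)
# MT12345      False  (only 5 digits)
# ku131308     False  (the pattern is uppercase-only)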
93
+
94
+ # helper function to extract accessions
95
+ def extract_accessions_from_input(file=None, raw_text=""):
96
+ print(f"RAW TEXT RECEIVED: {raw_text}")
97
+ accessions, invalid_accessions = [], []
98
+ seen = set()
99
+ if file:
100
+ try:
101
+ if file.name.endswith(".csv"):
102
+ df = pd.read_csv(file)
103
+ elif file.name.endswith(".xlsx"):
104
+ df = pd.read_excel(file)
105
+ else:
106
+ return [], "Unsupported file format. Please upload CSV or Excel."
107
+ for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
108
+ if acc not in seen:
109
+ if is_valid_accession(acc):
110
+ accessions.append(acc)
111
+ seen.add(acc)
112
+ else:
113
+ invalid_accessions.append(acc)
114
+
115
+ except Exception as e:
116
+ return [],[], f"Failed to read file: {e}"
117
+
118
+ if raw_text:
119
+ try:
120
+ text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
121
+ for acc in text_ids:
122
+ if acc not in seen:
123
+ if is_valid_accession(acc):
124
+ accessions.append(acc)
125
+ seen.add(acc)
126
+ else:
127
+ invalid_accessions.append(acc)
128
+ except Exception as e:
129
+ return [],[], f"Failed to read file: {e}"
130
+
131
+ return list(accessions), list(invalid_accessions), None
132
+ # Helper to skip accessions that already have complete rows: `get_incomplete_accessions()`
133
+ def get_incomplete_accessions(file_path):
134
+ df = pd.read_excel(file_path)
135
+
136
+ incomplete_accessions = []
137
+ for _, row in df.iterrows():
138
+ sample_id = str(row.get("Sample ID", "")).strip()
139
+
140
+ # Skip if no sample ID
141
+ if not sample_id:
142
+ continue
143
+
144
+ # Drop the Sample ID and check if the rest is empty
145
+ other_cols = row.drop(labels=["Sample ID"], errors="ignore")
146
+ if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
147
+ # Extract the accession number from the sample ID using regex
148
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
149
+ if match:
150
+ incomplete_accessions.append(match.group(0))
151
+ print(len(incomplete_accessions))
152
+ return incomplete_accessions
153
+
154
+ # GOOGLE_SHEET_NAME = "known_samples"
155
+ # USAGE_DRIVE_FILENAME = "user_usage_log.json"
156
+ def truncate_cell(value, max_len=49000):
157
+ """Ensure cell content never exceeds Google Sheets 50k char limit."""
158
+ if not isinstance(value, str):
159
+ value = str(value)
160
+ return value[:max_len] + ("... [TRUNCATED]" if len(value) > max_len else "")
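A quick check of truncate_cell, which keeps each Google Sheets write under the 50,000-character cell limit:

from mtdna_backend import truncate_cell

cell = truncate_cell("x" * 60000)
print(len(cell))             # 49000 plus the 15-character "... [TRUNCATED]" marker
print(truncate_cell(12345))  # non-strings are stringified first: "12345"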
161
+
162
+
163
+ async def summarize_results(accession, stop_flag=None, niche_cases=None):
164
+ # Early bail
165
+ if stop_flag is not None and stop_flag.value:
166
+ print(f"🛑 Skipping {accession} before starting.")
167
+ return []
168
+ # try cache first
169
+ print("niche case in sum_result: ", niche_cases)
170
+ cached = check_known_output(accession, niche_cases)
171
+ if cached:
172
+ print(f" Using cached result for {accession}")
173
+ return [[
174
+ cached["Sample ID"] or "unknown",
175
+ cached["Predicted Country"] or "unknown",
176
+ cached["Country Explanation"] or "unknown",
177
+ cached["Predicted Sample Type"] or "unknown",
178
+ cached["Sample Type Explanation"] or "unknown",
179
+ cached["Sources"] or "No Links",
180
+ cached["Time cost"]
181
+ ]]
182
+ # only run when nothing in the cache
183
+ try:
184
+ print("try gemini pipeline: ",accession)
185
+ # Load credentials from Hugging Face secret
186
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
187
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
188
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
189
+ client = gspread.authorize(creds)
190
+
191
+ spreadsheet = client.open("known_samples")
192
+ sheet = spreadsheet.sheet1
193
+
194
+ data = sheet.get_all_values()
195
+ if not data:
196
+ print("⚠️ Google Sheet 'known_samples' is empty.")
197
+ return None
198
+
199
+ save_df = pd.DataFrame(data[1:], columns=data[0])
200
+ print("before pipeline, len of save df: ", len(save_df))
201
+ if niche_cases: niche_cases = ", ".join(niche_cases)
202
+ print("this is niche case inside summarize result: ", niche_cases)
203
+ outputs = await pipeline_classify_sample_location_cached(accession, stop_flag, save_df, niche_cases)
204
+ # outputs = {"KU131308":{"isolate":"BRU18",
205
+ # "country":{"brunei":['ncbi','rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples...']},
206
+ # "sample_type":{"modern":['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples']},
207
+ # "query_cost":9.754999999999999e-05,
208
+ # "time_cost":'24.776 seconds',
209
+ # "source":['https://doi.org/10.1007/s00439-015-1620-z'],
210
+ # "file_chunk":"filechunk",
211
+ # "file_all_output":"fileoutput",
212
+ # 'specific location':{'brunei':["some explain"]}}}
213
+ if stop_flag is not None and stop_flag.value:
214
+ print(f"🛑 Skipped {accession} mid-pipeline.")
215
+ return []
216
+ # outputs = {'KU131308': {'isolate':'BRU18',
217
+ # 'country': {'brunei': ['ncbi',
218
+ # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
219
+ # 'sample_type': {'modern':
220
+ # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
221
+ # 'query_cost': 9.754999999999999e-05,
222
+ # 'time_cost': '24.776 seconds',
223
+ # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
224
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
225
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
226
+ except Exception as e:
227
+ return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
228
+
229
+ if accession not in outputs:
230
+ print("no accession in output ", accession)
231
+ return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
232
+
233
+ row_score = []
234
+ rows = []
235
+ save_rows = []
236
+ for key in outputs:
237
+ pred_country, pred_sample, country_explanation, sample_explanation = "unknown", "unknown", "unknown", "unknown"
+ pred_niche, niche_explanation = "unknown", "unknown"
238
+ checked_sections = ["country", "sample_type"]
239
+ niche_cases = niche_cases.split(", ") if niche_cases else []
240
+ if niche_cases: checked_sections += niche_cases
241
+ print("checked sections: ", checked_sections)
242
+ for section, results in outputs[key].items():
243
+ pred_output = []#"\n".join(list(results.keys()))
244
+ output_explanation = ""
245
+ print(section, results)
246
+ if section not in checked_sections: continue
247
+ for result, content in results.items():
248
+ if len(result) == 0: result = "unknown"
249
+ if len(content) == 0: output_explanation = "unknown"
250
+ else:
251
+ output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
252
+ pred_output.append(result)
253
+ pred_output = "\n".join(pred_output)
254
+ if section == "country":
255
+ pred_country, country_explanation = pred_output, output_explanation
256
+ elif section == "sample_type":
257
+ pred_sample, sample_explanation = pred_output, output_explanation
258
+ else:
259
+ pred_niche, niche_explanation = pred_output, output_explanation
260
+ if outputs[key]["isolate"].lower()!="unknown":
261
+ label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
262
+ else: label = key
263
+ if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
264
+
265
+ if niche_cases:
266
+ row = {
267
+ "Sample ID": truncate_cell(label or "unknown"),
268
+ "Predicted Country": truncate_cell(pred_country or "unknown"),
269
+ "Country Explanation": truncate_cell(country_explanation or "unknown"),
270
+ "Predicted Sample Type": truncate_cell(pred_sample or "unknown"),
271
+ "Sample Type Explanation": truncate_cell(sample_explanation or "unknown"),
272
+ "Predicted " + niche_cases[0]: truncate_cell(pred_niche or "unknown"),
273
+ niche_cases[0] + " Explanation": truncate_cell(niche_explanation or "unknown"),
274
+ "Sources": truncate_cell("\n".join(outputs[key]["source"]) or "No Links"),
275
+ "Time cost": truncate_cell(outputs[key]["time_cost"])
276
+ }
277
+ #row_score.append(row)
278
+ # rows.append(list(row.values()))
279
+ rows.append(row)
280
+
281
+ save_row = {
282
+ "Sample ID": truncate_cell(label or "unknown"),
283
+ "Predicted Country": truncate_cell(pred_country or "unknown"),
284
+ "Country Explanation": truncate_cell(country_explanation or "unknown"),
285
+ "Predicted Sample Type": truncate_cell(pred_sample or "unknown"),
286
+ "Sample Type Explanation": truncate_cell(sample_explanation or "unknown"),
287
+ "Predicted " + niche_cases[0]: truncate_cell(pred_niche or "unknown"),
288
+ niche_cases[0] + " Explanation": truncate_cell(niche_explanation or "unknown"),
289
+ "Sources": truncate_cell("\n".join(outputs[key]["source"]) or "No Links"),
290
+ "Query_cost": outputs[key]["query_cost"] or "",
291
+ "Time cost": outputs[key]["time_cost"] or "",
292
+ "file_chunk": truncate_cell(outputs[key]["file_chunk"] or ""),
293
+ "file_all_output": truncate_cell(outputs[key]["file_all_output"] or "")
294
+ }
295
+
296
+ #row_score.append(row)
297
+ #save_rows.append(list(save_row.values()))
298
+ save_rows.append(save_row)
299
+ else:
300
+ row = {
301
+ "Sample ID": truncate_cell(label or "unknown"),
302
+ "Predicted Country": truncate_cell(pred_country or "unknown"),
303
+ "Country Explanation": truncate_cell(country_explanation or "unknown"),
304
+ "Predicted Sample Type": truncate_cell(pred_sample or "unknown"),
305
+ "Sample Type Explanation": truncate_cell(sample_explanation or "unknown"),
306
+ "Sources": truncate_cell("\n".join(outputs[key]["source"]) or "No Links"),
307
+ "Time cost": truncate_cell(outputs[key]["time_cost"])
308
+ }
309
+ #row_score.append(row)
310
+ # rows.append(list(row.values()))
311
+ rows.append(row)
312
+ save_row = {
313
+ "Sample ID": truncate_cell(label or "unknown"),
314
+ "Predicted Country": truncate_cell(pred_country or "unknown"),
315
+ "Country Explanation": truncate_cell(country_explanation or "unknown"),
316
+ "Predicted Sample Type": truncate_cell(pred_sample or "unknown"),
317
+ "Sample Type Explanation": truncate_cell(sample_explanation or "unknown"),
318
+ "Sources": truncate_cell("\n".join(outputs[key]["source"]) or "No Links"),
319
+ "Query_cost": outputs[key]["query_cost"] or "",
320
+ "Time cost": outputs[key]["time_cost"] or "",
321
+ "file_chunk": truncate_cell(outputs[key]["file_chunk"] or ""),
322
+ "file_all_output": truncate_cell(outputs[key]["file_all_output"] or "")
323
+ }
324
+
325
+ #row_score.append(row)
326
+ #save_rows.append(list(save_row.values()))
327
+ save_rows.append(save_row)
328
+ print("the final rows: ", rows)
329
+
330
+ try:
331
+ # Prepare as DataFrame
332
+ df_new = pd.DataFrame(save_rows)
333
+ print("done df_new and here are save_rows: ", save_rows)
334
+ # Setup Google Sheets
335
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
336
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
337
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
338
+ client = gspread.authorize(creds)
339
+ spreadsheet = client.open("known_samples")
340
+ sheet = spreadsheet.sheet1
341
+
342
+ # ✅ Load existing data + headers
343
+ existing_data = sheet.get_all_values()
344
+ headers = existing_data[0] if existing_data else []
345
+ existing_df = pd.DataFrame(existing_data[1:], columns=headers) if len(existing_data) > 1 else pd.DataFrame()
346
+
347
+ # Extend headers if new keys appear in save_rows
348
+ print("df_new.col: ", df_new.columns)
349
+ for col in df_new.columns:
350
+ print(col)
351
+ if col not in headers:
352
+ headers.append(col)
353
+ # Add new column header in the sheet
354
+ sheet.update_cell(1, len(headers), col)
355
+
356
+ # ✅ Align DataFrame with sheet headers (fill missing with "")
357
+ df_new = df_new.reindex(columns=headers, fill_value="")
358
+
359
+ # Build lookup: Sample ID → row index
360
+ if "Sample ID" in existing_df.columns:
361
+ id_to_row = {sid: i + 2 for i, sid in enumerate(existing_df["Sample ID"])}
362
+ else:
363
+ id_to_row = {}
364
+
365
+ for _, row in df_new.iterrows():
366
+ sid = row.get("Sample ID", "")
367
+ row_values = [truncate_cell(str(row.get(h, ""))) for h in headers]
368
+ print("row_val of df_new: ", row_values)
369
+ if sid in id_to_row:
370
+ # Update existing row in correct header order
371
+ sheet.update(f"A{id_to_row[sid]}:{chr(64+len(headers))}{id_to_row[sid]}", [row_values])
372
+ else:
373
+ # ✅ Append new row
374
+ sheet.append_row(row_values)
375
+
376
+ print(" Match results safely saved to known_samples with dynamic headers.")
377
+
378
+ except Exception as e:
379
+ print(f"❌ Failed to update known_samples: {e}")
380
+
381
+
382
+ return rows
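Since summarize_results is a coroutine, a synchronous caller can drive it with asyncio.run — a sketch using an accession that already appears in this file's comments:
rows = asyncio.run(summarize_results("KU131308"))
print(rows)   # cache hits return a list of value lists; fresh runs return a list of row dicts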
383
+
384
+ def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
385
+ df_new = pd.DataFrame(all_rows, columns=[
386
+ "Sample ID", "Predicted Country", "Country Explanation",
387
+ "Predicted Sample Type", "Sample Type Explanation",
388
+ "Sources", "Time cost"
389
+ ])
390
+
391
+ if is_resume and os.path.exists(filename):
392
+ try:
393
+ df_old = pd.read_excel(filename)
394
+ except Exception as e:
395
+ print(f"⚠️ Warning reading old Excel file: {e}")
396
+ df_old = pd.DataFrame(columns=df_new.columns)
397
+
398
+ # Set index and update existing rows
399
+ df_old.set_index("Sample ID", inplace=True)
400
+ df_new.set_index("Sample ID", inplace=True)
401
+ df_old.update(df_new)
402
+
403
+ df_combined = df_old.reset_index()
404
+ else:
405
+ # If not resuming or file doesn't exist, just use new rows
406
+ df_combined = df_new
407
+
408
+ try:
409
+ df_combined.to_excel(filename, index=False)
410
+ except Exception as e:
411
+ print(f" Failed to write Excel file {filename}: {e}")
412
+
413
+
414
+ # save the batch input in JSON file
415
+ def save_to_json(all_rows, summary_text, flag_text, filename):
416
+ output_dict = {
417
+ "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
418
+ # "Summary_Text": summary_text,
419
+ # "Ancient_Modern_Flag": flag_text
420
+ }
421
+
422
+ # If all_rows is a DataFrame, convert it
423
+ if isinstance(all_rows, pd.DataFrame):
424
+ output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
425
+
426
+ with open(filename, "w") as external_file:
427
+ json.dump(output_dict, external_file, indent=2)
428
+
429
+ # save the batch input in Text file
430
+ def save_to_txt(all_rows, summary_text, flag_text, filename):
431
+ if isinstance(all_rows, pd.DataFrame):
432
+ detailed_results = all_rows.to_dict(orient="records")
433
+ output = ""
434
+ #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
435
+ output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
436
+ for r in detailed_results:
437
+ output += ",".join([str(v) for v in r.values()]) + "\n\n"
438
+ with open(filename, "w") as f:
439
+ f.write("=== Detailed Results ===\n")
440
+ f.write(output + "\n")
441
+
442
+ # f.write("\n=== Summary ===\n")
443
+ # f.write(summary_text + "\n")
444
+
445
+ # f.write("\n=== Ancient/Modern Flag ===\n")
446
+ # f.write(flag_text + "\n")
447
+
448
+ def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
449
+ tmp_dir = tempfile.mkdtemp()
450
+
451
+ #html_table = all_rows.value # assuming this is stored somewhere
452
+
453
+ # Parse back to DataFrame
454
+ #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
455
+ all_rows = pd.read_html(StringIO(all_rows))[0]
456
+ print(all_rows)
457
+
458
+ if output_type == "Excel":
459
+ file_path = f"{tmp_dir}/batch_output.xlsx"
460
+ save_to_excel(all_rows, summary_text, flag_text, file_path)
461
+ elif output_type == "JSON":
462
+ file_path = f"{tmp_dir}/batch_output.json"
463
+ save_to_json(all_rows, summary_text, flag_text, file_path)
464
+ print("Done with JSON")
465
+ elif output_type == "TXT":
466
+ file_path = f"{tmp_dir}/batch_output.txt"
467
+ save_to_txt(all_rows, summary_text, flag_text, file_path)
468
+ else:
469
+ return gr.update(visible=False) # invalid option
470
+
471
+ return gr.update(value=file_path, visible=True)
472
+ # save cost by checking the known outputs
473
+
474
+ # def check_known_output(accession):
475
+ # if not os.path.exists(KNOWN_OUTPUT_PATH):
476
+ # return None
477
+
478
+ # try:
479
+ # df = pd.read_excel(KNOWN_OUTPUT_PATH)
480
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
481
+ # if match:
482
+ # accession = match.group(0)
483
+
484
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
485
+ # if not matched.empty:
486
+ # return matched.iloc[0].to_dict() # Return the cached row
487
+ # except Exception as e:
488
+ # print(f"⚠️ Failed to load known samples: {e}")
489
+ # return None
490
+
491
+ # def check_known_output(accession):
492
+ # try:
493
+ # # ✅ Load credentials from Hugging Face secret
494
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
495
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
496
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
497
+ # client = gspread.authorize(creds)
498
+
499
+ # # ✅ Open the known_samples sheet
500
+ # spreadsheet = client.open("known_samples") # Replace with your sheet name
501
+ # sheet = spreadsheet.sheet1
502
+
503
+ # # Read all rows
504
+ # data = sheet.get_all_values()
505
+ # if not data:
506
+ # return None
507
+
508
+ # df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
509
+
510
+ # # Normalize accession pattern
511
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
512
+ # if match:
513
+ # accession = match.group(0)
514
+
515
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
516
+ # if not matched.empty:
517
+ # return matched.iloc[0].to_dict()
518
+
519
+ # except Exception as e:
520
+ # print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
521
+ # return None
522
+ # def check_known_output(accession):
523
+ # print("inside check known output function")
524
+ # try:
525
+ # # Load credentials from Hugging Face secret
526
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
527
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
528
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
529
+ # client = gspread.authorize(creds)
530
+
531
+ # spreadsheet = client.open("known_samples")
532
+ # sheet = spreadsheet.sheet1
533
+
534
+ # data = sheet.get_all_values()
535
+ # if not data:
536
+ # print("⚠️ Google Sheet 'known_samples' is empty.")
537
+ # return None
538
+
539
+ # df = pd.DataFrame(data[1:], columns=data[0])
540
+ # if "Sample ID" not in df.columns:
541
+ # print(" Column 'Sample ID' not found in Google Sheet.")
542
+ # return None
543
+
544
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
545
+ # if match:
546
+ # accession = match.group(0)
547
+
548
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
549
+ # if not matched.empty:
550
+ # #return matched.iloc[0].to_dict()
551
+ # row = matched.iloc[0]
552
+ # country = row.get("Predicted Country", "").strip().lower()
553
+ # sample_type = row.get("Predicted Sample Type", "").strip().lower()
554
+
555
+ # if country and country != "unknown" and sample_type and sample_type != "unknown":
556
+ # return row.to_dict()
557
+ # else:
558
+ # print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
559
+ # return None
560
+ # else:
561
+ # print(f"🔍 Accession {accession} not found in known_samples.")
562
+ # return None
563
+
564
+ # except Exception as e:
565
+ # import traceback
566
+ # print("❌ Exception occurred during check_known_output:")
567
+ # traceback.print_exc()
568
+ # return None
569
+
570
+ import os
571
+ import re
572
+ import json
573
+ import time
574
+ import gspread
575
+ import pandas as pd
576
+ from oauth2client.service_account import ServiceAccountCredentials
577
+ from gspread.exceptions import APIError
578
+
579
+ # --- Global cache ---
580
+ _known_samples_cache = None
581
+
582
+ def load_known_samples():
583
+ """Load the Google Sheet 'known_samples' into a Pandas DataFrame and cache it."""
584
+ global _known_samples_cache
585
+ try:
586
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
587
+ scope = [
588
+ 'https://spreadsheets.google.com/feeds',
589
+ 'https://www.googleapis.com/auth/drive'
590
+ ]
591
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
592
+ client = gspread.authorize(creds)
593
+
594
+ sheet = client.open("known_samples").sheet1
595
+ data = sheet.get_all_values()
596
+
597
+ if not data:
598
+ print("⚠️ Google Sheet 'known_samples' is empty.")
599
+ _known_samples_cache = pd.DataFrame()
600
+ else:
601
+ _known_samples_cache = pd.DataFrame(data[1:], columns=data[0])
602
+ print(f"✅ Cached {_known_samples_cache.shape[0]} rows from known_samples")
603
+
604
+ except APIError as e:
605
+ print(f"❌ APIError while loading known_samples: {e}")
606
+ _known_samples_cache = pd.DataFrame()
607
+ except Exception as e:
608
+ import traceback
609
+ print("❌ Exception occurred while loading known_samples:")
610
+ traceback.print_exc()
611
+ _known_samples_cache = pd.DataFrame()
612
+
613
+ def check_known_output(accession, niche_cases=None):
614
+ """Check if an accession exists in the cached 'known_samples' sheet."""
615
+ global _known_samples_cache
616
+ print("inside check known output function")
617
+
618
+ try:
619
+ # Load cache if not already loaded
620
+ if _known_samples_cache is None:
621
+ load_known_samples()
622
+
623
+ if _known_samples_cache.empty:
624
+ print("⚠️ No cached data available.")
625
+ return None
626
+
627
+ # Extract proper accession format (e.g. AB12345)
628
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
629
+ if match:
630
+ accession = match.group(0)
631
+
632
+ matched = _known_samples_cache[
633
+ _known_samples_cache["Sample ID"].str.contains(accession, case=False, na=False)
634
+ ]
635
+
636
+ if not matched.empty:
637
+ row = matched.iloc[0]
638
+ country = row.get("Predicted Country", "").strip().lower()
639
+ sample_type = row.get("Predicted Sample Type", "").strip().lower()
640
+ output_niche = None
641
+ if niche_cases:
642
+ niche_col = "Predicted " + niche_cases[0]
643
+ if niche_col not in _known_samples_cache.columns:
644
+ print(f"⚠️ Niche column '{niche_col}' not found in known_samples. Skipping cache.")
645
+ return None
646
+ output_niche = row.get("Predicted " + niche_cases[0], "").strip().lower()
647
+ if country and country.lower() not in ["","unknown"] and sample_type and sample_type.lower() not in ["","unknown"] and output_niche and output_niche.lower() not in ["","unknown"]:
648
+ print(f"🎯 Found {accession} in cache")
649
+ return row.to_dict()
650
+ else:
651
+ print(f"⚠️ Accession {accession} found but country/sample_type unknown or empty.")
652
+ return None
653
+ else:
654
+ if country and country.lower() not in ["","unknown"] and sample_type and sample_type.lower() not in ["","unknown"]:
655
+ print(f"🎯 Found {accession} in cache")
656
+ return row.to_dict()
657
+ else:
658
+ print(f"⚠️ Accession {accession} found but country/sample_type unknown or empty.")
659
+ return None
660
+ else:
661
+ print(f"🔍 Accession {accession} not found in cache.")
662
+ return None
663
+
664
+ except Exception as e:
665
+ import traceback
666
+ print("❌ Exception occurred during check_known_output:")
667
+ traceback.print_exc()
668
+ return None
669
+
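A usage sketch for the cached lookup; the sheet is read once per process, so load_known_samples() can be called again to refresh the cache:
load_known_samples()                      # optional explicit refresh
hit = check_known_output("KU131308.1")    # normalized to "KU131308" before matching
if hit:
    print(hit["Predicted Country"], hit["Predicted Sample Type"])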
670
+
671
+
672
+ def hash_user_id(user_input):
673
+ return hashlib.sha256(user_input.encode()).hexdigest()
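A one-line sketch: hashing is deterministic, so the same email always maps to the same opaque key.
hash_user_id("user@example.com") == hash_user_id("user@example.com")   # True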
674
+
675
+ # Load and save usage count
676
+
677
+ # def load_user_usage():
678
+ # if not os.path.exists(USER_USAGE_TRACK_FILE):
679
+ # return {}
680
+
681
+ # try:
682
+ # with open(USER_USAGE_TRACK_FILE, "r") as f:
683
+ # content = f.read().strip()
684
+ # if not content:
685
+ # return {} # file is empty
686
+ # return json.loads(content)
687
+ # except (json.JSONDecodeError, ValueError):
688
+ # print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
689
+ # return {} # fallback to empty dict
690
+ # def load_user_usage():
691
+ # try:
692
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
693
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
694
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
695
+ # client = gspread.authorize(creds)
696
+
697
+ # sheet = client.open("user_usage_log").sheet1
698
+ # data = sheet.get_all_records() # Assumes columns: email, usage_count
699
+
700
+ # usage = {}
701
+ # for row in data:
702
+ # email = row.get("email", "").strip().lower()
703
+ # count = int(row.get("usage_count", 0))
704
+ # if email:
705
+ # usage[email] = count
706
+ # return usage
707
+ # except Exception as e:
708
+ # print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
709
+ # return {}
710
+ # def load_user_usage():
711
+ # try:
712
+ # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
713
+ # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
714
+
715
+ # found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
716
+ # if not found:
717
+ # return {} # not found, start fresh
718
+
719
+ # #file_id = found[0]["id"]
720
+ # file_id = found
721
+ # content = pipeline.download_drive_file_content(file_id)
722
+ # return json.loads(content.strip()) if content.strip() else {}
723
+
724
+ # except Exception as e:
725
+ # print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
726
+ # return {}
727
+ def load_user_usage():
728
+ try:
729
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
730
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
731
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
732
+ client = gspread.authorize(creds)
733
+
734
+ sheet = client.open("user_usage_log").sheet1
735
+ data = sheet.get_all_values()
736
+ print("data: ", data)
737
+ print("🧪 Raw header row from sheet:", data[0])
738
+ print("🧪 Character codes in each header:")
739
+ for h in data[0]:
740
+ print([ord(c) for c in h])
741
+
742
+ if not data or len(data) < 2:
743
+ print("⚠️ Sheet is empty or missing rows.")
744
+ return {}, {}
745
+
746
+ headers = [h.strip().lower() for h in data[0]]
747
+ if "email" not in headers or "usage_count" not in headers:
748
+ print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
749
+ return {}, {}
750
+
751
+ permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
752
+ df = pd.DataFrame(data[1:], columns=headers)
753
+
754
+ usage = {}
755
+ permitted = {}
756
+ for _, row in df.iterrows():
757
+ email = row.get("email", "").strip().lower()
758
+ try:
759
+ #count = int(row.get("usage_count", 0))
760
+ try:
761
+ count = int(float(row.get("usage_count", 0)))
762
+ except Exception:
763
+ print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
764
+ count = 0
765
+
766
+ if email:
767
+ usage[email] = count
768
+ if permitted_index is not None:
769
+ try:
770
+ permitted_count = int(float(row.get("permitted_samples", 50)))
771
+ permitted[email] = permitted_count
772
+ except:
773
+ permitted[email] = 50
774
+
775
+ except ValueError:
776
+ print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
777
+ return usage, permitted
778
+
779
+ except Exception as e:
780
+ print(f"❌ Error in load_user_usage: {e}")
781
+ return {}, {}
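A sketch of the expected user_usage_log layout and the two dicts it yields (the email below is a made-up placeholder):
# sheet headers: email | usage_count | permitted_samples (optional)
usage, permitted = load_user_usage()
print(usage.get("someone@lab.org", 0), permitted.get("someone@lab.org", 50))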
782
+
783
+
784
+
785
+ # def save_user_usage(usage):
786
+ # with open(USER_USAGE_TRACK_FILE, "w") as f:
787
+ # json.dump(usage, f, indent=2)
788
+
789
+ # def save_user_usage(usage_dict):
790
+ # try:
791
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
792
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
793
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
794
+ # client = gspread.authorize(creds)
795
+
796
+ # sheet = client.open("user_usage_log").sheet1
797
+ # sheet.clear() # clear old contents first
798
+
799
+ # # Write header + rows
800
+ # rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
801
+ # sheet.update(rows)
802
+ # except Exception as e:
803
+ # print(f"❌ Failed to save user usage to Google Sheets: {e}")
804
+ # def save_user_usage(usage_dict):
805
+ # try:
806
+ # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
807
+ # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
808
+
809
+ # import tempfile
810
+ # tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
811
+ # print("💾 Saving this usage dict:", usage_dict)
812
+ # with open(tmp_path, "w") as f:
813
+ # json.dump(usage_dict, f, indent=2)
814
+
815
+ # pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
816
+
817
+ # except Exception as e:
818
+ # print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
819
+ # def save_user_usage(usage_dict):
820
+ # try:
821
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
822
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
823
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
824
+ # client = gspread.authorize(creds)
825
+
826
+ # spreadsheet = client.open("user_usage_log")
827
+ # sheet = spreadsheet.sheet1
828
+
829
+ # # Step 1: Convert new usage to DataFrame
830
+ # df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
831
+ # df_new["email"] = df_new["email"].str.strip().str.lower()
832
+
833
+ # # Step 2: Load existing data
834
+ # existing_data = sheet.get_all_values()
835
+ # print("🧪 Sheet existing_data:", existing_data)
836
+
837
+ # # Try to load old data
838
+ # if existing_data and len(existing_data[0]) >= 1:
839
+ # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
840
+
841
+ # # Fix missing columns
842
+ # if "email" not in df_old.columns:
843
+ # df_old["email"] = ""
844
+ # if "usage_count" not in df_old.columns:
845
+ # df_old["usage_count"] = 0
846
+
847
+ # df_old["email"] = df_old["email"].str.strip().str.lower()
848
+ # df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
849
+ # else:
850
+ # df_old = pd.DataFrame(columns=["email", "usage_count"])
851
+
852
+ # # Step 3: Merge
853
+ # df_combined = pd.concat([df_old, df_new], ignore_index=True)
854
+ # df_combined = df_combined.groupby("email", as_index=False).sum()
855
+
856
+ # # Step 4: Write back
857
+ # sheet.clear()
858
+ # sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
859
+ # print("✅ Saved user usage to user_usage_log sheet.")
860
+
861
+ # except Exception as e:
862
+ # print(f"❌ Failed to save user usage to Google Sheets: {e}")
863
+ def save_user_usage(usage_dict):
864
+ try:
865
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
866
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
867
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
868
+ client = gspread.authorize(creds)
869
+
870
+ spreadsheet = client.open("user_usage_log")
871
+ sheet = spreadsheet.sheet1
872
+
873
+ # Build new df
874
+ df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
875
+ df_new["email"] = df_new["email"].str.strip().str.lower()
876
+ df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
877
+
878
+ # Read existing data
879
+ existing_data = sheet.get_all_values()
880
+ if existing_data and len(existing_data[0]) >= 2:
881
+ df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
882
+ df_old["email"] = df_old["email"].str.strip().str.lower()
883
+ df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
884
+ else:
885
+ df_old = pd.DataFrame(columns=["email", "usage_count"])
886
+
887
+ # Overwrite specific emails only
888
+ df_old = df_old.set_index("email")
889
+ for email, count in usage_dict.items():
890
+ email = email.strip().lower()
891
+ df_old.loc[email, "usage_count"] = count
892
+ df_old = df_old.reset_index()
893
+
894
+ # Save
895
+ sheet.clear()
896
+ sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
897
+ print("✅ Saved user usage to user_usage_log sheet.")
898
+
899
+ except Exception as e:
900
+ print(f"❌ Failed to save user usage to Google Sheets: {e}")
901
+
902
+
903
+
904
+
905
+ # def increment_usage(user_id, num_samples=1):
906
+ # usage = load_user_usage()
907
+ # if user_id not in usage:
908
+ # usage[user_id] = 0
909
+ # usage[user_id] += num_samples
910
+ # save_user_usage(usage)
911
+ # return usage[user_id]
912
+ # def increment_usage(email: str, count: int):
913
+ # usage = load_user_usage()
914
+ # email_key = email.strip().lower()
915
+ # usage[email_key] = usage.get(email_key, 0) + count
916
+ # save_user_usage(usage)
917
+ # return usage[email_key]
918
+ def increment_usage(email: str, count: int = 1):
919
+ usage, permitted = load_user_usage()
920
+ email_key = email.strip().lower()
921
+ #usage[email_key] = usage.get(email_key, 0) + count
922
+ current = usage.get(email_key, 0)
923
+ new_value = current + count
924
+ max_allowed = permitted.get(email_key) or 50
925
+ usage[email_key] = max(current, new_value) # ✅ Prevent overwrite with lower
926
+ print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
927
+ print("max allow is: ", max_allowed)
928
+ save_user_usage(usage)
929
+ return usage[email_key], max_allowed
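A sketch of quota checking with increment_usage (placeholder email):
used, allowed = increment_usage("someone@lab.org", count=3)
if used >= allowed:
    print("quota reached")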
930
+
931
+
932
+ # run the batch
933
+ def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
934
+ stop_flag=None, output_file_path=None,
935
+ limited_acc=50, yield_callback=None):
936
+ if user_email:
937
+ limited_acc += 10
938
+ accessions, invalid_accessions, error = extract_accessions_from_input(file, raw_text)
939
+ if error:
940
+ #return [], "", "", f"Error: {error}"
941
+ return [], f"Error: {error}", 0, "", ""
942
+ if resume_file:
943
+ accessions = get_incomplete_accessions(resume_file)
944
+ tmp_dir = tempfile.mkdtemp()
945
+ if not output_file_path:
946
+ if resume_file:
947
+ output_file_path = os.path.join(tmp_dir, resume_file)
948
+ else:
949
+ output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
950
+
951
+ all_rows = []
952
+ # all_summaries = []
953
+ # all_flags = []
954
+ progress_lines = []
955
+ warning = ""
956
+ if len(accessions) > limited_acc:
957
+ accessions = accessions[:limited_acc]
958
+ warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
959
+ for i, acc in enumerate(accessions):
960
+ if stop_flag and stop_flag.value:
961
+ line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
962
+ progress_lines.append(line)
963
+ if yield_callback:
964
+ yield_callback(line)
965
+ print("🛑 User requested stop.")
966
+ break
967
+ print(f"[{i+1}/{len(accessions)}] Processing {acc}")
968
+ try:
969
+ # rows, summary, label, explain = summarize_results(acc)
970
+ # summarize_results is a coroutine; run it to completion here (asyncio is part of the standard library)
+ rows = asyncio.run(summarize_results(acc, stop_flag=stop_flag))
971
+ all_rows.extend(rows)
972
+ # all_summaries.append(f"**{acc}**\n{summary}")
973
+ # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
974
+ #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
975
+ save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
976
+ line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
977
+ progress_lines.append(line)
978
+ if yield_callback:
979
+ yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
980
+ except Exception as e:
981
+ print(f"❌ Failed to process {acc}: {e}")
982
+ continue
983
+ #all_summaries.append(f"**{acc}**: Failed - {e}")
984
+ #progress_lines.append(f" Processed {acc} ({i+1}/{len(accessions)})")
985
+ limited_acc -= 1
986
+ """for row in all_rows:
987
+ source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
988
+
989
+ if source_column.startswith("http"): # Check if the source is a URL
990
+ # Wrap it with HTML anchor tags to make it clickable
991
+ row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
992
+ if not warning:
993
+ warning = f"You only have {limited_acc} left"
994
+ if user_email.strip():
995
+ user_hash = hash_user_id(user_email)
996
+ total_queries, _ = increment_usage(user_hash, len(all_rows))
997
+ else:
998
+ total_queries = 0
999
+ yield_callback("✅ Finished!")
1000
+
1001
+ # summary_text = "\n\n---\n\n".join(all_summaries)
1002
+ # flag_text = "\n\n---\n\n".join(all_flags)
1003
+ #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
1004
+ #return all_rows, gr.update(visible=True), gr.update(visible=False)
1005
  return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
mtdna_classifier.py CHANGED
@@ -1,769 +1,769 @@
1
- # mtDNA Location Classifier MVP (Google Colab)
2
- # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
- import os
4
- #import streamlit as st
5
- import subprocess
6
- import re
7
- from Bio import Entrez
8
- import fitz
9
- import spacy
10
- from spacy.cli import download
11
- from NER.PDF import pdf
12
- from NER.WordDoc import wordDoc
13
- from NER.html import extractHTML
14
- from NER.word2Vec import word2vec
15
- from transformers import pipeline
16
- import urllib.parse, requests
17
- from pathlib import Path
18
- from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
- import model
20
- # Set your email (required by NCBI Entrez)
21
- #Entrez.email = "your-email@example.com"
22
- import nltk
23
-
24
- nltk.download("stopwords")
25
- nltk.download("punkt")
26
- nltk.download('punkt_tab')
27
- # Step 1: Get PubMed ID from Accession using EDirect
28
- from Bio import Entrez, Medline
29
- import re
30
-
31
- Entrez.email = "your_email@example.com"
32
-
33
- # --- Helper Functions (Re-organized and Upgraded) ---
34
-
35
- def fetch_ncbi_metadata(accession_number):
36
- """
37
- Fetches metadata directly from NCBI GenBank using Entrez.
38
- Includes robust error handling and improved field extraction.
39
- Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
40
- Also attempts to extract ethnicity and sample_type (ancient/modern).
41
-
42
- Args:
43
- accession_number (str): The NCBI accession number (e.g., "ON792208").
44
-
45
- Returns:
46
- dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
47
- 'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id'.
48
- """
49
- Entrez.email = "your.email@example.com" # Required by NCBI, REPLACE WITH YOUR EMAIL
50
-
51
- country = "unknown"
52
- specific_location = "unknown"
53
- ethnicity = "unknown"
54
- sample_type = "unknown"
55
- collection_date = "unknown"
56
- isolate = "unknown"
57
- title = "unknown"
58
- doi = "unknown"
59
- pubmed_id = None
60
- all_feature = "unknown"
61
-
62
- KNOWN_COUNTRIES = [
63
- "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
64
- "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
65
- "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
66
- "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
67
- "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
68
- "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
69
- "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
70
- "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
71
- "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
72
- "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
73
- "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
74
- "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
75
- "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
76
- "Yemen", "Zambia", "Zimbabwe"
77
- ]
78
- COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)
79
-
80
- try:
81
- handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
82
- record = Entrez.read(handle)
83
- handle.close()
84
-
85
- gb_seq = None
86
- # Validate record structure: It should be a list with at least one element (a dict)
87
- if isinstance(record, list) and len(record) > 0:
88
- if isinstance(record[0], dict):
89
- gb_seq = record[0]
90
- else:
91
- print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
92
- else:
93
- print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
94
-
95
- # If gb_seq is still None, return defaults
96
- if gb_seq is None:
97
- return {"country": "unknown",
98
- "specific_location": "unknown",
99
- "ethnicity": "unknown",
100
- "sample_type": "unknown",
101
- "collection_date": "unknown",
102
- "isolate": "unknown",
103
- "title": "unknown",
104
- "doi": "unknown",
105
- "pubmed_id": None,
106
- "all_features": "unknown"}
107
-
108
-
109
- # If gb_seq is valid, proceed with extraction
110
- collection_date = gb_seq.get("GBSeq_create-date","unknown")
111
-
112
- references = gb_seq.get("GBSeq_references", [])
113
- for ref in references:
114
- if not pubmed_id:
115
- pubmed_id = ref.get("GBReference_pubmed",None)
116
- if title == "unknown":
117
- title = ref.get("GBReference_title","unknown")
118
- for xref in ref.get("GBReference_xref", []):
119
- if xref.get("GBXref_dbname") == "doi":
120
- doi = xref.get("GBXref_id")
121
- break
122
-
123
- features = gb_seq.get("GBSeq_feature-table", [])
124
-
125
- context_for_flagging = "" # Accumulate text for ancient/modern detection
126
- features_context = ""
127
- for feature in features:
128
- if feature.get("GBFeature_key") == "source":
129
- feature_context = ""
130
- qualifiers = feature.get("GBFeature_quals", [])
131
- found_country = "unknown"
132
- found_specific_location = "unknown"
133
- found_ethnicity = "unknown"
134
-
135
- temp_geo_loc_name = "unknown"
136
- temp_note_origin_locality = "unknown"
137
- temp_country_qual = "unknown"
138
- temp_locality_qual = "unknown"
139
- temp_collection_location_qual = "unknown"
140
- temp_isolation_source_qual = "unknown"
141
- temp_env_sample_qual = "unknown"
142
- temp_pop_qual = "unknown"
143
- temp_organism_qual = "unknown"
144
- temp_specimen_qual = "unknown"
145
- temp_strain_qual = "unknown"
146
-
147
- for qual in qualifiers:
148
- qual_name = qual.get("GBQualifier_name")
149
- qual_value = qual.get("GBQualifier_value")
150
- feature_context += qual_name + ": " + qual_value +"\n"
151
- if qual_name == "collection_date":
152
- collection_date = qual_value
153
- elif qual_name == "isolate":
154
- isolate = qual_value
155
- elif qual_name == "population":
156
- temp_pop_qual = qual_value
157
- elif qual_name == "organism":
158
- temp_organism_qual = qual_value
159
- elif qual_name == "specimen_voucher" or qual_name == "specimen":
160
- temp_specimen_qual = qual_value
161
- elif qual_name == "strain":
162
- temp_strain_qual = qual_value
163
- elif qual_name == "isolation_source":
164
- temp_isolation_source_qual = qual_value
165
- elif qual_name == "environmental_sample":
166
- temp_env_sample_qual = qual_value
167
-
168
- if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
169
- elif qual_name == "note":
170
- if qual_value.startswith("origin_locality:"):
171
- temp_note_origin_locality = qual_value
172
- context_for_flagging += qual_value + " " # Capture all notes for flagging
173
- elif qual_name == "country": temp_country_qual = qual_value
174
- elif qual_name == "locality": temp_locality_qual = qual_value
175
- elif qual_name == "collection_location": temp_collection_location_qual = qual_value
176
-
177
-
178
- # --- Aggregate all relevant info into context_for_flagging ---
179
- context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
180
- context_for_flagging = context_for_flagging.strip()
181
-
182
- # --- Determine final country and specific_location based on priority ---
183
- if temp_geo_loc_name != "unknown":
184
- parts = [p.strip() for p in temp_geo_loc_name.split(':')]
185
- if len(parts) > 1:
186
- found_specific_location = parts[-1]; found_country = parts[0]
187
- else: found_country = temp_geo_loc_name; found_specific_location = "unknown"
188
- elif temp_note_origin_locality != "unknown":
189
- match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
190
- if match:
191
- location_string = match.group(1).strip()
192
- parts = [p.strip() for p in location_string.split(':')]
193
- if len(parts) > 1:
194
- #found_country = parts[-1]; found_specific_location = parts[0]
195
- found_country = model.get_country_from_text(temp_note_origin_locality.lower())
196
- if found_country == "unknown":
197
- found_country = parts[0];
198
- found_specific_location = parts[-1]
199
- else: found_country = location_string; found_specific_location = "unknown"
200
- elif temp_locality_qual != "unknown":
201
- found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
202
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
203
- else: found_specific_location = temp_locality_qual; found_country = "unknown"
204
- elif temp_collection_location_qual != "unknown":
205
- found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
206
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
207
- else: found_specific_location = temp_collection_location_qual; found_country = "unknown"
208
- elif temp_isolation_source_qual != "unknown":
209
- found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
210
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
211
- else: found_specific_location = temp_isolation_source_qual; found_country = "unknown"
212
- elif temp_env_sample_qual != "unknown":
213
- found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
214
- if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
215
- else: found_specific_location = temp_env_sample_qual; found_country = "unknown"
216
- if found_country == "unknown" and temp_country_qual != "unknown":
217
- found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
218
- if found_country_match: found_country = found_country_match.group(1)
219
-
220
- country = found_country
221
- specific_location = found_specific_location
222
- # --- Determine final ethnicity ---
223
- if temp_pop_qual != "unknown":
224
- found_ethnicity = temp_pop_qual
225
- elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
226
- found_ethnicity = isolate
227
- elif context_for_flagging != "unknown": # Use the broader context for ethnicity patterns
228
- eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
229
- if eth_match:
230
- found_ethnicity = eth_match.group(1).strip()
231
-
232
- ethnicity = found_ethnicity
233
-
234
- # --- Determine sample_type (ancient/modern) ---
235
- if context_for_flagging:
236
- sample_type, explain = detect_ancient_flag(context_for_flagging)
237
- features_context += feature_context + "\n"
238
- break
239
-
240
- if specific_location != "unknown" and specific_location.lower() == country.lower():
241
- specific_location = "unknown"
242
- if not features_context: features_context = "unknown"
243
- return {"country": country.lower(),
244
- "specific_location": specific_location.lower(),
245
- "ethnicity": ethnicity.lower(),
246
- "sample_type": sample_type.lower(),
247
- "collection_date": collection_date,
248
- "isolate": isolate,
249
- "title": title,
250
- "doi": doi,
251
- "pubmed_id": pubmed_id,
252
- "all_features": features_context}
253
-
254
- except:
255
- print(f"Error fetching NCBI data for {accession_number}")
256
- return {"country": "unknown",
257
- "specific_location": "unknown",
258
- "ethnicity": "unknown",
259
- "sample_type": "unknown",
260
- "collection_date": "unknown",
261
- "isolate": "unknown",
262
- "title": "unknown",
263
- "doi": "unknown",
264
- "pubmed_id": None,
265
- "all_features": "unknown"}
266
-
267
- # --- Helper function for country matching (re-defined from main code to be self-contained) ---
268
- _country_keywords = {
269
- "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
270
- "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
271
- "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
272
- "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
273
- "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
274
- "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
275
- "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
276
- "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
277
- "central india": "India", "east india": "India", "northeast india": "India",
278
- "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
279
- "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
280
- }
281
-
282
- def get_country_from_text(text):
283
- text_lower = text.lower()
284
- for keyword, country in _country_keywords.items():
285
- if keyword in text_lower:
286
- return country
287
- return "unknown"
288
- # The result will be seen as manualLink for the function get_paper_text
289
- # def search_google_custom(query, max_results=3):
290
- # # query should be the title from ncbi or paper/source title
291
- # GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
292
- # GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
293
- # endpoint = os.environ["SEARCH_ENDPOINT"]
294
- # params = {
295
- # "key": GOOGLE_CSE_API_KEY,
296
- # "cx": GOOGLE_CSE_CX,
297
- # "q": query,
298
- # "num": max_results
299
- # }
300
- # try:
301
- # response = requests.get(endpoint, params=params)
302
- # if response.status_code == 429:
303
- # print("Rate limit hit. Try again later.")
304
- # return []
305
- # response.raise_for_status()
306
- # data = response.json().get("items", [])
307
- # return [item.get("link") for item in data if item.get("link")]
308
- # except Exception as e:
309
- # print("Google CSE error:", e)
310
- # return []
311
-
312
- def search_google_custom(query, max_results=3):
313
- # query should be the title from ncbi or paper/source title
314
- GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
315
- GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
316
- endpoint = os.environ["SEARCH_ENDPOINT"]
317
- params = {
318
- "key": GOOGLE_CSE_API_KEY,
319
- "cx": GOOGLE_CSE_CX,
320
- "q": query,
321
- "num": max_results
322
- }
323
- try:
324
- response = requests.get(endpoint, params=params)
325
- if response.status_code == 429:
326
- print("Rate limit hit. Try again later.")
327
- print("try with back up account")
328
- try:
329
- return search_google_custom_backup(query, max_results)
330
- except:
331
- return []
332
- response.raise_for_status()
333
- data = response.json().get("items", [])
334
- return [item.get("link") for item in data if item.get("link")]
335
- except Exception as e:
336
- print("Google CSE error:", e)
337
- return []
338
-
339
- def search_google_custom_backup(query, max_results=3):
340
- # query should be the title from ncbi or paper/source title
341
- GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY_BACKUP"]
342
- GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX_BACKUP"]
343
- endpoint = os.environ["SEARCH_ENDPOINT"]
344
- params = {
345
- "key": GOOGLE_CSE_API_KEY,
346
- "cx": GOOGLE_CSE_CX,
347
- "q": query,
348
- "num": max_results
349
- }
350
- try:
351
- response = requests.get(endpoint, params=params)
352
- if response.status_code == 429:
353
- print("Rate limit hit. Try again later.")
354
- return []
355
- response.raise_for_status()
356
- data = response.json().get("items", [])
357
- return [item.get("link") for item in data if item.get("link")]
358
- except Exception as e:
359
- print("Google CSE error:", e)
360
- return []
361
- # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
362
- # Step 3.1: Extract Text
363
- # sub: download excel file
364
- def download_excel_file(url, save_path="temp.xlsx"):
365
- if "view.officeapps.live.com" in url:
366
- parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
367
- real_url = urllib.parse.unquote(parsed_url["src"][0])
368
- response = requests.get(real_url)
369
- with open(save_path, "wb") as f:
370
- f.write(response.content)
371
- return save_path
372
- elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
373
- response = requests.get(url)
374
- response.raise_for_status() # Raises error if download fails
375
- with open(save_path, "wb") as f:
376
- f.write(response.content)
377
- return save_path
378
- else:
379
- print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
380
- return url
381
- def get_paper_text(doi,id,manualLinks=None):
382
- # create the temporary folder to contain the texts
383
- folder_path = Path("data/"+str(id))
384
- if not folder_path.exists():
385
- cmd = f'mkdir data/{id}'
386
- result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
387
- print("data/"+str(id) +" created.")
388
- else:
389
- print("data/"+str(id) +" already exists.")
390
- saveLinkFolder = "data/"+id
391
-
392
- link = 'https://doi.org/' + doi
393
- '''textsToExtract = { "doiLink":"paperText"
394
- "file1.pdf":"text1",
395
- "file2.doc":"text2",
396
- "file3.xlsx":excelText3'''
397
- textsToExtract = {}
398
- # get the file to create listOfFile for each id
399
- html = extractHTML.HTML("",link)
400
- jsonSM = html.getSupMaterial()
401
- text = ""
402
- links = [link] + sum((jsonSM[key] for key in jsonSM),[])
403
- if manualLinks != None:
404
- links += manualLinks
405
- for l in links:
406
- # get the main paper
407
- name = l.split("/")[-1]
408
- file_path = folder_path / name
409
- if l == link:
410
- text = html.getListSection()
411
- textsToExtract[link] = text
412
- elif l.endswith(".pdf"):
413
- if file_path.is_file():
414
- l = saveLinkFolder + "/" + name
415
- print("File exists.")
416
- p = pdf.PDF(l,saveLinkFolder,doi)
417
- f = p.openPDFFile()
418
- pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
419
- doc = fitz.open(pdf_path)
420
- text = "\n".join([page.get_text() for page in doc])
421
- textsToExtract[l] = text
422
- elif l.endswith(".doc") or l.endswith(".docx"):
423
- d = wordDoc.wordDoc(l,saveLinkFolder)
424
- text = d.extractTextByPage()
425
- textsToExtract[l] = text
426
- elif l.split(".")[-1].lower() in "xlsx":
427
- wc = word2vec.word2Vec()
428
- # download excel file if it not downloaded yet
429
- savePath = saveLinkFolder +"/"+ l.split("/")[-1]
430
- excelPath = download_excel_file(l, savePath)
431
- corpus = wc.tableTransformToCorpusText([],excelPath)
432
- text = ''
433
- for c in corpus:
434
- para = corpus[c]
435
- for words in para:
436
- text += " ".join(words)
437
- textsToExtract[l] = text
438
- # delete folder after finishing getting text
439
- #cmd = f'rm -r data/{id}'
440
- #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
441
- return textsToExtract
442
- # Step 3.2: Extract context
443
- def extract_context(text, keyword, window=500):
444
- # firstly try accession number
445
- idx = text.find(keyword)
446
- if idx == -1:
447
- return "Sample ID not found."
448
- return text[max(0, idx-window): idx+window]
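A sketch of extract_context on a short string — it returns up to `window` characters on each side of the first hit, or a not-found message:
extract_context("isolate BRU18 was collected in Brunei", "BRU18", window=10)
# -> "isolate BRU18 was "
extract_context("no match here", "BRU18")
# -> "Sample ID not found."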
449
- def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
450
- if keep_if is None:
451
- keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
452
-
453
- outputs = ""
454
- text = text.lower()
455
-
456
- # If isolate is provided, prioritize paragraphs that mention it
457
- # If isolate is provided, prioritize paragraphs that mention it
458
- if accession and accession.lower() in text:
459
- if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
460
- outputs += extract_context(text, accession.lower(), window=700)
461
- if isolate and isolate.lower() in text:
462
- if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
463
- outputs += extract_context(text, isolate.lower(), window=700)
464
- for keyword in keep_if:
465
- para = extract_context(text, keyword)
466
- if para and para not in outputs:
467
- outputs += para + "\n"
468
- return outputs
469
- # Step 4: Classification for now (demo purposes)
470
- # 4.1: Using a HuggingFace model (question-answering)
471
- def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
472
- try:
473
- qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
474
- result = qa({"context": context, "question": question})
475
- return result.get("answer", "Unknown")
476
- except Exception as e:
477
- return f"Error: {str(e)}"
478
-
479
- # 4.2: Infer from haplogroup
480
- # Load pre-trained spaCy model for NER
481
- try:
482
- nlp = spacy.load("en_core_web_sm")
483
- except OSError:
484
- download("en_core_web_sm")
485
- nlp = spacy.load("en_core_web_sm")
486
-
487
- # Define the haplogroup-to-region mapping (simple rule-based)
488
- import csv
489
-
490
- def load_haplogroup_mapping(csv_path):
491
- mapping = {}
492
- with open(csv_path) as f:
493
- reader = csv.DictReader(f)
494
- for row in reader:
495
- mapping[row["haplogroup"]] = [row["region"],row["source"]]
496
- return mapping
497
-
498
- # Function to extract haplogroup from the text
499
- def extract_haplogroup(text):
500
- match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
501
- if match:
502
- submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
503
- if submatch:
504
- return submatch.group(0)
505
- else:
506
- return match.group(1) # fallback
507
- fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
508
- if fallback:
509
- return fallback.group(1)
510
- return None
511
-
512
-
513
- # Function to extract location based on NER
514
- def extract_location(text):
515
- doc = nlp(text)
516
- locations = []
517
- for ent in doc.ents:
518
- if ent.label_ == "GPE": # GPE = Geopolitical Entity (location)
519
- locations.append(ent.text)
520
- return locations
521
-
522
- # Function to infer location from haplogroup
523
- def infer_location_from_haplogroup(haplogroup):
524
- haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
525
- return haplo_map.get(haplogroup, ["Unknown","Unknown"])
526
-
527
- # Function to classify the mtDNA sample
528
- def classify_mtDNA_sample_from_haplo(text):
529
- # Extract haplogroup
530
- haplogroup = extract_haplogroup(text)
531
- # Extract location based on NER
532
- locations = extract_location(text)
533
- # Infer location based on haplogroup
534
- inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)[0],infer_location_from_haplogroup(haplogroup)[1]
535
- return {
536
- "source":sourceHaplo,
537
- "locations_found_in_context": locations,
538
- "haplogroup": haplogroup,
539
- "inferred_location": inferred_location
540
-
541
- }
542
- # 4.3 Get from available NCBI
543
- def infer_location_fromNCBI(accession):
544
- try:
545
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
546
- text = handle.read()
547
- handle.close()
548
- match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
549
- if match:
550
- return match.group(2), match.group(0) # This is the value like "Brunei"
551
- return "Not found", "Not found"
552
-
553
- except Exception as e:
554
- print("❌ Entrez error:", e)
555
- return "Not found", "Not found"
556
-
557
- ### ANCIENT/MODERN FLAG
558
- from Bio import Entrez
559
- import re
560
-
561
- def flag_ancient_modern(accession, textsToExtract, isolate=None):
562
- """
563
- Try to classify a sample as Ancient or Modern using:
564
- 1. NCBI accession (if available)
565
- 2. Supplementary text or context fallback
566
- """
567
- context = ""
568
- label, explain = "", ""
569
-
570
- try:
571
- # Check if we can fetch metadata from NCBI using the accession
572
- handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
573
- text = handle.read()
574
- handle.close()
575
-
576
- isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
577
- if isolate_source:
578
- context += isolate_source.group(0) + " "
579
-
580
- specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
581
- if specimen:
582
- context += specimen.group(0) + " "
583
-
584
- if context.strip():
585
- label, explain = detect_ancient_flag(context)
586
- if label!="Unknown":
587
- return label, explain + " from NCBI\n(" + context + ")"
588
-
589
- # If no useful NCBI metadata, check supplementary texts
590
- if textsToExtract:
591
- labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
592
-
593
- for source in textsToExtract:
594
- text_block = textsToExtract[source]
595
- context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
596
- label, explain = detect_ancient_flag(context)
597
-
598
- if label == "Ancient":
599
- labels["ancient"][0] += 1
600
- labels["ancient"][1] += f"{source}:\n{explain}\n\n"
601
- elif label == "Modern":
602
- labels["modern"][0] += 1
603
- labels["modern"][1] += f"{source}:\n{explain}\n\n"
604
- else:
605
- labels["unknown"] += 1
606
-
607
- if max(labels["modern"][0],labels["ancient"][0]) > 0:
608
- if labels["modern"][0] > labels["ancient"][0]:
609
- return "Modern", labels["modern"][1]
610
- else:
611
- return "Ancient", labels["ancient"][1]
612
- else:
613
- return "Unknown", "No strong keywords detected"
614
- else:
615
- print("No DOI or PubMed ID available for inference.")
616
- return "", ""
617
-
618
- except Exception as e:
619
- print("Error:", e)
620
- return "", ""
621
-
622
-
623
- def detect_ancient_flag(context_snippet):
624
- context = context_snippet.lower()
625
-
626
- ancient_keywords = [
627
- "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
628
- "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
629
- "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
630
- "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
631
- ]
632
-
633
- modern_keywords = [
634
- "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
635
- "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
636
- "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
637
- "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
638
- "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
639
- ]
640
-
641
- ancient_hits = [k for k in ancient_keywords if k in context]
642
- modern_hits = [k for k in modern_keywords if k in context]
643
-
644
- if ancient_hits and not modern_hits:
645
- return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
646
- elif modern_hits and not ancient_hits:
647
- return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
648
- elif ancient_hits and modern_hits:
649
- if len(ancient_hits) >= len(modern_hits):
650
- return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
651
- else:
652
- return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
653
-
654
- # Fallback to QA
655
- answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
656
- if answer.startswith("Error"):
657
- return "Unknown", answer
658
- if "ancient" in answer.lower():
659
- return "Ancient", f"Leaning ancient based on QA: {answer}"
660
- elif "modern" in answer.lower():
661
- return "Modern", f"Leaning modern based on QA: {answer}"
662
- else:
663
- return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
664
-
665
- # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
666
- def classify_sample_location(accession):
667
- outputs = {}
668
- keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
669
- # Step 1: get pubmed id and isolate
670
- pubmedID, isolate = get_info_from_accession(accession)
671
- '''if not pubmedID:
672
- return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
673
- if not isolate:
674
- isolate = "UNKNOWN_ISOLATE"
675
- # Step 2: get doi
676
- doi = get_doi_from_pubmed_id(pubmedID)
677
- '''if not doi:
678
- return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
679
- # Step 3: get text
680
- '''textsToExtract = { "doiLink":"paperText"
681
- "file1.pdf":"text1",
682
- "file2.doc":"text2",
683
- "file3.xlsx":excelText3'''
684
- if doi and pubmedID:
685
- textsToExtract = get_paper_text(doi,pubmedID)
686
- else: textsToExtract = {}
687
- '''if not textsToExtract:
688
- return {"error": f"No texts extracted for DOI {doi}"}'''
689
- if isolate not in [None, "UNKNOWN_ISOLATE"]:
690
- label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
691
- else:
692
- label, explain = flag_ancient_modern(accession,textsToExtract)
693
- # Step 4: prediction
694
- outputs[accession] = {}
695
- outputs[isolate] = {}
696
- # 4.0 Infer from NCBI
697
- location, outputNCBI = infer_location_fromNCBI(accession)
698
- NCBI_result = {
699
- "source": "NCBI",
700
- "sample_id": accession,
701
- "predicted_location": location,
702
- "context_snippet": outputNCBI}
703
- outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
704
- if textsToExtract:
705
- long_text = ""
706
- for key in textsToExtract:
707
- text = textsToExtract[key]
708
- # try accession number first
709
- outputs[accession][key] = {}
710
- keyword = accession
711
- context = extract_context(text, keyword, window=500)
712
- # 4.1: Using a HuggingFace model (question-answering)
713
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
714
- qa_result = {
715
- "source": key,
716
- "sample_id": keyword,
717
- "predicted_location": location,
718
- "context_snippet": context
719
- }
720
- outputs[keyword][key]["QAModel"] = qa_result
721
- # 4.2: Infer from haplogroup
722
- haplo_result = classify_mtDNA_sample_from_haplo(context)
723
- outputs[keyword][key]["haplogroup"] = haplo_result
724
- # try isolate
725
- keyword = isolate
726
- outputs[isolate][key] = {}
727
- context = extract_context(text, keyword, window=500)
728
- # 4.1.1: Using a HuggingFace model (question-answering)
729
- location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
730
- qa_result = {
731
- "source": key,
732
- "sample_id": keyword,
733
- "predicted_location": location,
734
- "context_snippet": context
735
- }
736
- outputs[keyword][key]["QAModel"] = qa_result
737
- # 4.2.1: Infer from haplogroup
738
- haplo_result = classify_mtDNA_sample_from_haplo(context)
739
- outputs[keyword][key]["haplogroup"] = haplo_result
740
- # add long text
741
- long_text += text + ". \n"
742
- # 4.3: UpgradeClassify
743
- # try sample_id as accession number
744
- sample_id = accession
745
- if sample_id:
746
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
747
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
748
- if locations!="No clear location found in top matches":
749
- outputs[sample_id]["upgradeClassifier"] = {}
750
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
751
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
752
- "sample_id": sample_id,
753
- "predicted_location": ", ".join(locations),
754
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
755
- }
756
- # try sample_id as isolate name
757
- sample_id = isolate
758
- if sample_id:
759
- filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
760
- locations = infer_location_for_sample(sample_id.upper(), filtered_context)
761
- if locations!="No clear location found in top matches":
762
- outputs[sample_id]["upgradeClassifier"] = {}
763
- outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
764
- "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
765
- "sample_id": sample_id,
766
- "predicted_location": ", ".join(locations),
767
- "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
768
- }
769
  return outputs, label, explain
 
1
+ # mtDNA Location Classifier MVP (Google Colab)
2
+ # Accepts accession number → Fetches PubMed ID + isolate name → Gets abstract → Predicts location
3
+ import os
4
+ #import streamlit as st
5
+ import subprocess
6
+ import re
7
+ from Bio import Entrez
8
+ import fitz
9
+ import spacy
10
+ from spacy.cli import download
11
+ from NER.PDF import pdf
12
+ from NER.WordDoc import wordDoc
13
+ from NER.html import extractHTML
14
+ from NER.word2Vec import word2vec
15
+ from transformers import pipeline
16
+ import urllib.parse, requests
17
+ from pathlib import Path
18
+ from upgradeClassify import filter_context_for_sample, infer_location_for_sample
19
+ import model
20
+ # Set your email (required by NCBI Entrez)
21
+ #Entrez.email = "your-email@example.com"
22
+ import nltk
23
+
24
+ nltk.download("stopwords")
25
+ nltk.download("punkt")
26
+ nltk.download('punkt_tab')
27
+ # Step 1: Get PubMed ID from Accession using EDirect
28
+ from Bio import Entrez, Medline
29
+ import re
30
+
31
+ Entrez.email = "your_email@example.com"
32
+
33
+ # --- Helper Functions (Re-organized and Upgraded) ---
34
+
35
+ def fetch_ncbi_metadata(accession_number):
36
+ """
37
+ Fetches metadata directly from NCBI GenBank using Entrez.
38
+ Includes robust error handling and improved field extraction.
39
+ Prioritizes location extraction from geo_loc_name, then notes, then other qualifiers.
40
+ Also attempts to extract ethnicity and sample_type (ancient/modern).
41
+
42
+ Args:
43
+ accession_number (str): The NCBI accession number (e.g., "ON792208").
44
+
45
+ Returns:
46
+ dict: A dictionary containing 'country', 'specific_location', 'ethnicity',
47
+ 'sample_type', 'collection_date', 'isolate', 'title', 'doi', 'pubmed_id'.
48
+ """
49
+ Entrez.email = "your.email@example.com" # Required by NCBI, REPLACE WITH YOUR EMAIL
50
+
51
+ country = "unknown"
52
+ specific_location = "unknown"
53
+ ethnicity = "unknown"
54
+ sample_type = "unknown"
55
+ collection_date = "unknown"
56
+ isolate = "unknown"
57
+ title = "unknown"
58
+ doi = "unknown"
59
+ pubmed_id = None
60
+ all_feature = "unknown"
61
+
62
+ KNOWN_COUNTRIES = [
63
+ "Afghanistan", "Albania", "Algeria", "Andorra", "Angola", "Antigua and Barbuda", "Argentina", "Armenia", "Australia", "Austria", "Azerbaijan",
64
+ "Bahamas", "Bahrain", "Bangladesh", "Barbados", "Belarus", "Belgium", "Belize", "Benin", "Bhutan", "Bolivia", "Bosnia and Herzegovina", "Botswana", "Brazil", "Brunei", "Bulgaria", "Burkina Faso", "Burundi",
65
+ "Cabo Verde", "Cambodia", "Cameroon", "Canada", "Central African Republic", "Chad", "Chile", "China", "Colombia", "Comoros", "Congo (Brazzaville)", "Congo (Kinshasa)", "Costa Rica", "Croatia", "Cuba", "Cyprus", "Czechia",
66
+ "Denmark", "Djibouti", "Dominica", "Dominican Republic", "Ecuador", "Egypt", "El Salvador", "Equatorial Guinea", "Eritrea", "Estonia", "Eswatini", "Ethiopia",
67
+ "Fiji", "Finland", "France", "Gabon", "Gambia", "Georgia", "Germany", "Ghana", "Greece", "Grenada", "Guatemala", "Guinea", "Guinea-Bissau", "Guyana",
68
+ "Haiti", "Honduras", "Hungary", "Iceland", "India", "Indonesia", "Iran", "Iraq", "Ireland", "Israel", "Italy", "Ivory Coast", "Jamaica", "Japan", "Jordan",
69
+ "Kazakhstan", "Kenya", "Kiribati", "Kosovo", "Kuwait", "Kyrgyzstan", "Laos", "Latvia", "Lebanon", "Lesotho", "Liberia", "Libya", "Liechtenstein", "Lithuania", "Luxembourg",
70
+ "Madagascar", "Malawi", "Malaysia", "Maldives", "Mali", "Malta", "Marshall Islands", "Mauritania", "Mauritius", "Mexico", "Micronesia", "Moldova", "Monaco", "Mongolia", "Montenegro", "Morocco", "Mozambique", "Myanmar",
71
+ "Namibia", "Nauru", "Nepal", "Netherlands", "New Zealand", "Nicaragua", "Niger", "Nigeria", "North Korea", "North Macedonia", "Norway", "Oman",
72
+ "Pakistan", "Palau", "Palestine", "Panama", "Papua New Guinea", "Paraguay", "Peru", "Philippines", "Poland", "Portugal", "Qatar", "Romania", "Russia", "Rwanda",
73
+ "Saint Kitts and Nevis", "Saint Lucia", "Saint Vincent and the Grenadines", "Samoa", "San Marino", "Sao Tome and Principe", "Saudi Arabia", "Senegal", "Serbia", "Seychelles", "Sierra Leone", "Singapore", "Slovakia", "Slovenia", "Solomon Islands", "Somalia", "South Africa", "South Korea", "South Sudan", "Spain", "Sri Lanka", "Sudan", "Suriname", "Sweden", "Switzerland", "Syria",
74
+ "Taiwan", "Tajikistan", "Tanzania", "Thailand", "Timor-Leste", "Togo", "Tonga", "Trinidad and Tobago", "Tunisia", "Turkey", "Turkmenistan", "Tuvalu",
75
+ "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay", "Uzbekistan", "Vanuatu", "Vatican City", "Venezuela", "Vietnam",
76
+ "Yemen", "Zambia", "Zimbabwe"
77
+ ]
78
+ COUNTRY_PATTERN = re.compile(r'\b(' + '|'.join(re.escape(c) for c in KNOWN_COUNTRIES) + r')\b', re.IGNORECASE)
79
+
80
+ try:
81
+ handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
82
+ record = Entrez.read(handle)
83
+ handle.close()
84
+
85
+ gb_seq = None
86
+ # Validate record structure: It should be a list with at least one element (a dict)
87
+ if isinstance(record, list) and len(record) > 0:
88
+ if isinstance(record[0], dict):
89
+ gb_seq = record[0]
90
+ else:
91
+ print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
92
+ else:
93
+ print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
94
+
95
+ # If gb_seq is still None, return defaults
96
+ if gb_seq is None:
97
+ return {"country": "unknown",
98
+ "specific_location": "unknown",
99
+ "ethnicity": "unknown",
100
+ "sample_type": "unknown",
101
+ "collection_date": "unknown",
102
+ "isolate": "unknown",
103
+ "title": "unknown",
104
+ "doi": "unknown",
105
+ "pubmed_id": None,
106
+ "all_features": "unknown"}
107
+
108
+
109
+ # If gb_seq is valid, proceed with extraction
110
+ collection_date = gb_seq.get("GBSeq_create-date","unknown")
111
+
112
+ references = gb_seq.get("GBSeq_references", [])
113
+ for ref in references:
114
+ if not pubmed_id:
115
+ pubmed_id = ref.get("GBReference_pubmed",None)
116
+ if title == "unknown":
117
+ title = ref.get("GBReference_title","unknown")
118
+ for xref in ref.get("GBReference_xref", []):
119
+ if xref.get("GBXref_dbname") == "doi":
120
+ doi = xref.get("GBXref_id")
121
+ break
122
+
123
+ features = gb_seq.get("GBSeq_feature-table", [])
124
+
125
+ context_for_flagging = "" # Accumulate text for ancient/modern detection
126
+ features_context = ""
127
+ for feature in features:
128
+ if feature.get("GBFeature_key") == "source":
129
+ feature_context = ""
130
+ qualifiers = feature.get("GBFeature_quals", [])
131
+ found_country = "unknown"
132
+ found_specific_location = "unknown"
133
+ found_ethnicity = "unknown"
134
+
135
+ temp_geo_loc_name = "unknown"
136
+ temp_note_origin_locality = "unknown"
137
+ temp_country_qual = "unknown"
138
+ temp_locality_qual = "unknown"
139
+ temp_collection_location_qual = "unknown"
140
+ temp_isolation_source_qual = "unknown"
141
+ temp_env_sample_qual = "unknown"
142
+ temp_pop_qual = "unknown"
143
+ temp_organism_qual = "unknown"
144
+ temp_specimen_qual = "unknown"
145
+ temp_strain_qual = "unknown"
146
+
147
+ for qual in qualifiers:
148
+ qual_name = qual.get("GBQualifier_name")
149
+ qual_value = qual.get("GBQualifier_value")
150
+ feature_context += qual_name + ": " + qual_value +"\n"
151
+ if qual_name == "collection_date":
152
+ collection_date = qual_value
153
+ elif qual_name == "isolate":
154
+ isolate = qual_value
155
+ elif qual_name == "population":
156
+ temp_pop_qual = qual_value
157
+ elif qual_name == "organism":
158
+ temp_organism_qual = qual_value
159
+ elif qual_name == "specimen_voucher" or qual_name == "specimen":
160
+ temp_specimen_qual = qual_value
161
+ elif qual_name == "strain":
162
+ temp_strain_qual = qual_value
163
+ elif qual_name == "isolation_source":
164
+ temp_isolation_source_qual = qual_value
165
+ elif qual_name == "environmental_sample":
166
+ temp_env_sample_qual = qual_value
167
+
168
+ if qual_name == "geo_loc_name": temp_geo_loc_name = qual_value
169
+ elif qual_name == "note":
170
+ if qual_value.startswith("origin_locality:"):
171
+ temp_note_origin_locality = qual_value
172
+ context_for_flagging += qual_value + " " # Capture all notes for flagging
173
+ elif qual_name == "country": temp_country_qual = qual_value
174
+ elif qual_name == "locality": temp_locality_qual = qual_value
175
+ elif qual_name == "collection_location": temp_collection_location_qual = qual_value
176
+
177
+
178
+ # --- Aggregate all relevant info into context_for_flagging ---
179
+ context_for_flagging += f" {isolate} {temp_isolation_source_qual} {temp_specimen_qual} {temp_strain_qual} {temp_organism_qual} {temp_geo_loc_name} {temp_collection_location_qual} {temp_env_sample_qual}"
180
+ context_for_flagging = context_for_flagging.strip()
181
+
182
+ # --- Determine final country and specific_location based on priority ---
183
+ if temp_geo_loc_name != "unknown":
184
+ parts = [p.strip() for p in temp_geo_loc_name.split(':')]
185
+ if len(parts) > 1:
186
+ found_specific_location = parts[-1]; found_country = parts[0]
187
+ else: found_country = temp_geo_loc_name; found_specific_location = "unknown"
188
+ elif temp_note_origin_locality != "unknown":
189
+ match = re.search(r"origin_locality:\s*(.*)", temp_note_origin_locality, re.IGNORECASE)
190
+ if match:
191
+ location_string = match.group(1).strip()
192
+ parts = [p.strip() for p in location_string.split(':')]
193
+ if len(parts) > 1:
194
+ #found_country = parts[-1]; found_specific_location = parts[0]
195
+ found_country = model.get_country_from_text(temp_note_origin_locality.lower())
196
+ if found_country == "unknown":
197
+ found_country = parts[0];
198
+ found_specific_location = parts[-1]
199
+ else: found_country = location_string; found_specific_location = "unknown"
200
+ elif temp_locality_qual != "unknown":
201
+ found_country_match = COUNTRY_PATTERN.search(temp_locality_qual)
202
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_locality_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
203
+ else: found_specific_location = temp_locality_qual; found_country = "unknown"
204
+ elif temp_collection_location_qual != "unknown":
205
+ found_country_match = COUNTRY_PATTERN.search(temp_collection_location_qual)
206
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_collection_location_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
207
+ else: found_specific_location = temp_collection_location_qual; found_country = "unknown"
208
+ elif temp_isolation_source_qual != "unknown":
209
+ found_country_match = COUNTRY_PATTERN.search(temp_isolation_source_qual)
210
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_isolation_source_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
211
+ else: found_specific_location = temp_isolation_source_qual; found_country = "unknown"
212
+ elif temp_env_sample_qual != "unknown":
213
+ found_country_match = COUNTRY_PATTERN.search(temp_env_sample_qual)
214
+ if found_country_match: found_country = found_country_match.group(1); temp_loc = re.sub(re.escape(found_country), '', temp_env_sample_qual, flags=re.IGNORECASE).strip().replace(',', '').replace(':', '').replace(';', '').strip(); found_specific_location = temp_loc if temp_loc else "unknown"
215
+ else: found_specific_location = temp_env_sample_qual; found_country = "unknown"
216
+ if found_country == "unknown" and temp_country_qual != "unknown":
217
+ found_country_match = COUNTRY_PATTERN.search(temp_country_qual)
218
+ if found_country_match: found_country = found_country_match.group(1)
219
+
220
+ country = found_country
221
+ specific_location = found_specific_location
222
+ # --- Determine final ethnicity ---
223
+ if temp_pop_qual != "unknown":
224
+ found_ethnicity = temp_pop_qual
225
+ elif isolate != "unknown" and re.fullmatch(r'[A-Za-z\s\-]+', isolate) and get_country_from_text(isolate) == "unknown":
226
+ found_ethnicity = isolate
227
+ elif context_for_flagging != "unknown": # Use the broader context for ethnicity patterns
228
+ eth_match = re.search(r'(?:population|ethnicity|isolate source):\s*([A-Za-z\s\-]+)', context_for_flagging, re.IGNORECASE)
229
+ if eth_match:
230
+ found_ethnicity = eth_match.group(1).strip()
231
+
232
+ ethnicity = found_ethnicity
233
+
234
+ # --- Determine sample_type (ancient/modern) ---
235
+ if context_for_flagging:
236
+ sample_type, explain = detect_ancient_flag(context_for_flagging)
237
+ features_context += feature_context + "\n"
238
+ break
239
+
240
+ if specific_location != "unknown" and specific_location.lower() == country.lower():
241
+ specific_location = "unknown"
242
+ if not features_context: features_context = "unknown"
243
+ return {"country": country.lower(),
244
+ "specific_location": specific_location.lower(),
245
+ "ethnicity": ethnicity.lower(),
246
+ "sample_type": sample_type.lower(),
247
+ "collection_date": collection_date,
248
+ "isolate": isolate,
249
+ "title": title,
250
+ "doi": doi,
251
+ "pubmed_id": pubmed_id,
252
+ "all_features": features_context}
253
+
254
+ except Exception as e:
256
+ print(f"Error fetching NCBI data for {accession_number}: {e}")
256
+ return {"country": "unknown",
257
+ "specific_location": "unknown",
258
+ "ethnicity": "unknown",
259
+ "sample_type": "unknown",
260
+ "collection_date": "unknown",
261
+ "isolate": "unknown",
262
+ "title": "unknown",
263
+ "doi": "unknown",
264
+ "pubmed_id": None,
265
+ "all_features": "unknown"}
266
+
267
+ # --- Helper function for country matching (re-defined from main code to be self-contained) ---
268
+ _country_keywords = {
269
+ "thailand": "Thailand", "laos": "Laos", "cambodia": "Cambodia", "myanmar": "Myanmar",
270
+ "philippines": "Philippines", "indonesia": "Indonesia", "malaysia": "Malaysia",
271
+ "china": "China", "chinese": "China", "india": "India", "taiwan": "Taiwan",
272
+ "vietnam": "Vietnam", "russia": "Russia", "siberia": "Russia", "nepal": "Nepal",
273
+ "japan": "Japan", "sumatra": "Indonesia", "borneu": "Indonesia",
274
+ "yunnan": "China", "tibet": "China", "northern mindanao": "Philippines",
275
+ "west malaysia": "Malaysia", "north thailand": "Thailand", "central thailand": "Thailand",
276
+ "northeast thailand": "Thailand", "east myanmar": "Myanmar", "west thailand": "Thailand",
277
+ "central india": "India", "east india": "India", "northeast india": "India",
278
+ "south sibera": "Russia", "mongolia": "China", "beijing": "China", "south korea": "South Korea",
279
+ "north asia": "unknown", "southeast asia": "unknown", "east asia": "unknown"
280
+ }
281
+
282
+ def get_country_from_text(text):
283
+ text_lower = text.lower()
284
+ for keyword, country in _country_keywords.items():
285
+ if keyword in text_lower:
286
+ return country
287
+ return "unknown"
288
+ # The result will be seen as manualLink for the function get_paper_text
289
+ # def search_google_custom(query, max_results=3):
290
+ # # query should be the title from ncbi or paper/source title
291
+ # GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
292
+ # GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
293
+ # endpoint = os.environ["SEARCH_ENDPOINT"]
294
+ # params = {
295
+ # "key": GOOGLE_CSE_API_KEY,
296
+ # "cx": GOOGLE_CSE_CX,
297
+ # "q": query,
298
+ # "num": max_results
299
+ # }
300
+ # try:
301
+ # response = requests.get(endpoint, params=params)
302
+ # if response.status_code == 429:
303
+ # print("Rate limit hit. Try again later.")
304
+ # return []
305
+ # response.raise_for_status()
306
+ # data = response.json().get("items", [])
307
+ # return [item.get("link") for item in data if item.get("link")]
308
+ # except Exception as e:
309
+ # print("Google CSE error:", e)
310
+ # return []
311
+
312
+ def search_google_custom(query, max_results=3):
313
+ # query should be the title from ncbi or paper/source title
314
+ GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY"]
315
+ GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX"]
316
+ endpoint = os.environ["SEARCH_ENDPOINT"]
317
+ params = {
318
+ "key": GOOGLE_CSE_API_KEY,
319
+ "cx": GOOGLE_CSE_CX,
320
+ "q": query,
321
+ "num": max_results
322
+ }
323
+ try:
324
+ response = requests.get(endpoint, params=params)
325
+ if response.status_code == 429:
326
+ print("Rate limit hit. Try again later.")
327
+ print("try with back up account")
328
+ try:
329
+ return search_google_custom_backup(query, max_results)
330
+ except:
331
+ return []
332
+ response.raise_for_status()
333
+ data = response.json().get("items", [])
334
+ return [item.get("link") for item in data if item.get("link")]
335
+ except Exception as e:
336
+ print("Google CSE error:", e)
337
+ return []
338
+
339
+ def search_google_custom_backup(query, max_results=3):
340
+ # query should be the title from ncbi or paper/source title
341
+ GOOGLE_CSE_API_KEY = os.environ["GOOGLE_CSE_API_KEY_BACKUP"]
342
+ GOOGLE_CSE_CX = os.environ["GOOGLE_CSE_CX_BACKUP"]
343
+ endpoint = os.environ["SEARCH_ENDPOINT"]
344
+ params = {
345
+ "key": GOOGLE_CSE_API_KEY,
346
+ "cx": GOOGLE_CSE_CX,
347
+ "q": query,
348
+ "num": max_results
349
+ }
350
+ try:
351
+ response = requests.get(endpoint, params=params)
352
+ if response.status_code == 429:
353
+ print("Rate limit hit. Try again later.")
354
+ return []
355
+ response.raise_for_status()
356
+ data = response.json().get("items", [])
357
+ return [item.get("link") for item in data if item.get("link")]
358
+ except Exception as e:
359
+ print("Google CSE error:", e)
360
+ return []
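A hedged usage sketch of the CSE wrappers above; it assumes the GOOGLE_CSE_API_KEY, GOOGLE_CSE_CX and SEARCH_ENDPOINT environment variables are set, and that the backup credentials are only consulted after a 429:

```python
import os
from mtdna_classifier import search_google_custom

# The query is typically the GenBank reference title returned by fetch_ncbi_metadata.
assert all(k in os.environ for k in ("GOOGLE_CSE_API_KEY", "GOOGLE_CSE_CX", "SEARCH_ENDPOINT"))
links = search_google_custom("complete mitochondrial genome Homo sapiens isolate", max_results=3)
print(links)  # a list of result URLs, or [] on errors and rate limiting
```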
361
+ # Step 3: Extract Text: Get the paper (html text), sup. materials (pdf, doc, excel) and do text-preprocessing
362
+ # Step 3.1: Extract Text
363
+ # sub: download excel file
364
+ def download_excel_file(url, save_path="temp.xlsx"):
365
+ if "view.officeapps.live.com" in url:
366
+ parsed_url = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
367
+ real_url = urllib.parse.unquote(parsed_url["src"][0])
368
+ response = requests.get(real_url)
369
+ with open(save_path, "wb") as f:
370
+ f.write(response.content)
371
+ return save_path
372
+ elif url.startswith("http") and (url.endswith(".xls") or url.endswith(".xlsx")):
373
+ response = requests.get(url)
374
+ response.raise_for_status() # Raises error if download fails
375
+ with open(save_path, "wb") as f:
376
+ f.write(response.content)
377
+ return save_path
378
+ else:
379
+ print("URL must point directly to an .xls or .xlsx file\n or it already downloaded.")
380
+ return url
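A short sketch of the two URL shapes handled above; the URLs are placeholders, not real files:

```python
from mtdna_classifier import download_excel_file

# 1) A direct .xls/.xlsx link is downloaded and written to save_path.
path = download_excel_file("https://example.org/supplement/tableS1.xlsx", "data/tableS1.xlsx")

# 2) An Office web-viewer link is unwrapped via its ?src= parameter before downloading.
viewer = "https://view.officeapps.live.com/op/view.aspx?src=https%3A%2F%2Fexample.org%2FtableS1.xlsx"
path = download_excel_file(viewer, "data/tableS1.xlsx")

# Anything else (for example a path that is already local) is returned unchanged.
```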
381
+ def get_paper_text(doi,id,manualLinks=None):
382
+ # create the temporary folder to contain the texts
383
+ folder_path = Path("data/"+str(id))
384
+ if not folder_path.exists():
385
+ cmd = f'mkdir data/{id}'
386
+ result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
387
+ print("data/"+str(id) +" created.")
388
+ else:
389
+ print("data/"+str(id) +" already exists.")
390
+ saveLinkFolder = "data/"+id
391
+
392
+ link = 'https://doi.org/' + doi
393
+ '''textsToExtract = { "doiLink": "paperText",
394
+ "file1.pdf": "text1",
395
+ "file2.doc": "text2",
396
+ "file3.xlsx": "excelText3" }'''
397
+ textsToExtract = {}
398
+ # get the file to create listOfFile for each id
399
+ html = extractHTML.HTML("",link)
400
+ jsonSM = html.getSupMaterial()
401
+ text = ""
402
+ links = [link] + sum((jsonSM[key] for key in jsonSM),[])
403
+ if manualLinks != None:
404
+ links += manualLinks
405
+ for l in links:
406
+ # get the main paper
407
+ name = l.split("/")[-1]
408
+ file_path = folder_path / name
409
+ if l == link:
410
+ text = html.getListSection()
411
+ textsToExtract[link] = text
412
+ elif l.endswith(".pdf"):
413
+ if file_path.is_file():
414
+ l = saveLinkFolder + "/" + name
415
+ print("File exists.")
416
+ p = pdf.PDF(l,saveLinkFolder,doi)
417
+ f = p.openPDFFile()
418
+ pdf_path = saveLinkFolder + "/" + l.split("/")[-1]
419
+ doc = fitz.open(pdf_path)
420
+ text = "\n".join([page.get_text() for page in doc])
421
+ textsToExtract[l] = text
422
+ elif l.endswith(".doc") or l.endswith(".docx"):
423
+ d = wordDoc.wordDoc(l,saveLinkFolder)
424
+ text = d.extractTextByPage()
425
+ textsToExtract[l] = text
426
+ elif l.split(".")[-1].lower() in "xlsx":
427
+ wc = word2vec.word2Vec()
428
+ # download excel file if it not downloaded yet
429
+ savePath = saveLinkFolder +"/"+ l.split("/")[-1]
430
+ excelPath = download_excel_file(l, savePath)
431
+ corpus = wc.tableTransformToCorpusText([],excelPath)
432
+ text = ''
433
+ for c in corpus:
434
+ para = corpus[c]
435
+ for words in para:
436
+ text += " ".join(words)
437
+ textsToExtract[l] = text
438
+ # delete folder after finishing getting text
439
+ #cmd = f'rm -r data/{id}'
440
+ #result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
441
+ return textsToExtract
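The return value keeps one text blob per source, keyed by the DOI landing page or the local file path; a sketch with an illustrative DOI and PubMed ID:

```python
from mtdna_classifier import get_paper_text

texts = get_paper_text("10.1000/exampledoi", "12345678")   # doi, pubmed id (illustrative)
for source, text in texts.items():
    print(source, len(text))
# e.g. {"https://doi.org/10.1000/exampledoi": "<main paper text>",
#       "data/12345678/supplement.pdf": "<supplement text>", ...}
```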
442
+ # Step 3.2: Extract context
443
+ def extract_context(text, keyword, window=500):
444
+ # firstly try accession number
445
+ idx = text.find(keyword)
446
+ if idx == -1:
447
+ return "Sample ID not found."
448
+ return text[max(0, idx-window): idx+window]
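A worked example of the window logic: the slice runs from window characters before the first match to window characters after the match position, so only the start of the keyword is guaranteed to fall inside it (the accession is illustrative):

```python
from mtdna_classifier import extract_context

text = "The sample KX123456 was collected in Brunei."
print(extract_context(text, "KX123456", window=10))   # -> "he sample KX123456 w"
print(extract_context(text, "ZZ999999", window=10))   # -> "Sample ID not found."
```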
449
+ def extract_relevant_paragraphs(text, accession, keep_if=None, isolate=None):
450
+ if keep_if is None:
451
+ keep_if = ["sample", "method", "mtdna", "sequence", "collected", "dataset", "supplementary", "table"]
452
+
453
+ outputs = ""
454
+ text = text.lower()
455
+
456
+ # If the accession is provided, prioritize paragraphs that mention it
457
+ # (the same is done for the isolate just below)
458
+ if accession and accession.lower() in text:
459
+ if extract_context(text, accession.lower(), window=700) != "Sample ID not found.":
460
+ outputs += extract_context(text, accession.lower(), window=700)
461
+ if isolate and isolate.lower() in text:
462
+ if extract_context(text, isolate.lower(), window=700) != "Sample ID not found.":
463
+ outputs += extract_context(text, isolate.lower(), window=700)
464
+ for keyword in keep_if:
465
+ para = extract_context(text, keyword)
466
+ if para and para not in outputs:
467
+ outputs += para + "\n"
468
+ return outputs
469
+ # Step 4: Classification for now (demo purposes)
470
+ # 4.1: Using a HuggingFace model (question-answering)
471
+ def infer_fromQAModel(context, question="Where is the mtDNA sample from?"):
472
+ try:
473
+ qa = pipeline("question-answering", model="distilbert-base-uncased-distilled-squad")
474
+ result = qa({"context": context, "question": question})
475
+ return result.get("answer", "Unknown")
476
+ except Exception as e:
477
+ return f"Error: {str(e)}"
478
+
479
+ # 4.2: Infer from haplogroup
480
+ # Load pre-trained spaCy model for NER
481
+ try:
482
+ nlp = spacy.load("en_core_web_sm")
483
+ except OSError:
484
+ download("en_core_web_sm")
485
+ nlp = spacy.load("en_core_web_sm")
486
+
487
+ # Define the haplogroup-to-region mapping (simple rule-based)
488
+ import csv
489
+
490
+ def load_haplogroup_mapping(csv_path):
491
+ mapping = {}
492
+ with open(csv_path) as f:
493
+ reader = csv.DictReader(f)
494
+ for row in reader:
495
+ mapping[row["haplogroup"]] = [row["region"],row["source"]]
496
+ return mapping
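load_haplogroup_mapping assumes a CSV with at least haplogroup, region and source columns; a sketch of the expected layout (rows illustrative):

```python
from mtdna_classifier import load_haplogroup_mapping

# data/haplogroup_regions_extended.csv is assumed to look like:
#   haplogroup,region,source
#   B4a1a,Island Southeast Asia,<citation or URL>
#   U5b,Europe,<citation or URL>
mapping = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
print(mapping.get("B4a1a", ["Unknown", "Unknown"]))  # -> [region, source]
```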
497
+
498
+ # Function to extract haplogroup from the text
499
+ def extract_haplogroup(text):
500
+ match = re.search(r'\bhaplogroup\s+([A-Z][0-9a-z]*)\b', text)
501
+ if match:
502
+ submatch = re.match(r'^[A-Z][0-9]*', match.group(1))
503
+ if submatch:
504
+ return submatch.group(0)
505
+ else:
506
+ return match.group(1) # fallback
507
+ fallback = re.search(r'\b([A-Z][0-9a-z]{1,5})\b', text)
508
+ if fallback:
509
+ return fallback.group(1)
510
+ return None
511
+
512
+
513
+ # Function to extract location based on NER
514
+ def extract_location(text):
515
+ doc = nlp(text)
516
+ locations = []
517
+ for ent in doc.ents:
518
+ if ent.label_ == "GPE": # GPE = Geopolitical Entity (location)
519
+ locations.append(ent.text)
520
+ return locations
521
+
522
+ # Function to infer location from haplogroup
523
+ def infer_location_from_haplogroup(haplogroup):
524
+ haplo_map = load_haplogroup_mapping("data/haplogroup_regions_extended.csv")
525
+ return haplo_map.get(haplogroup, ["Unknown","Unknown"])
526
+
527
+ # Function to classify the mtDNA sample
528
+ def classify_mtDNA_sample_from_haplo(text):
529
+ # Extract haplogroup
530
+ haplogroup = extract_haplogroup(text)
531
+ # Extract location based on NER
532
+ locations = extract_location(text)
533
+ # Infer location based on haplogroup
534
+ inferred_location, sourceHaplo = infer_location_from_haplogroup(haplogroup)[0],infer_location_from_haplogroup(haplogroup)[1]
535
+ return {
536
+ "source":sourceHaplo,
537
+ "locations_found_in_context": locations,
538
+ "haplogroup": haplogroup,
539
+ "inferred_location": inferred_location
540
+
541
+ }
542
+ # 4.3 Get from available NCBI
543
+ def infer_location_fromNCBI(accession):
544
+ try:
545
+ handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
546
+ text = handle.read()
547
+ handle.close()
548
+ match = re.search(r'/(geo_loc_name|country|location)\s*=\s*"([^"]+)"', text)
549
+ if match:
550
+ return match.group(2), match.group(0) # This is the value like "Brunei"
551
+ return "Not found", "Not found"
552
+
553
+ except Exception as e:
554
+ print("❌ Entrez error:", e)
555
+ return "Not found", "Not found"
556
+
557
+ ### ANCIENT/MODERN FLAG
558
+ from Bio import Entrez
559
+ import re
560
+
561
+ def flag_ancient_modern(accession, textsToExtract, isolate=None):
562
+ """
563
+ Try to classify a sample as Ancient or Modern using:
564
+ 1. NCBI accession (if available)
565
+ 2. Supplementary text or context fallback
566
+ """
567
+ context = ""
568
+ label, explain = "", ""
569
+
570
+ try:
571
+ # Check if we can fetch metadata from NCBI using the accession
572
+ handle = Entrez.efetch(db="nuccore", id=accession, rettype="medline", retmode="text")
573
+ text = handle.read()
574
+ handle.close()
575
+
576
+ isolate_source = re.search(r'/(isolation_source)\s*=\s*"([^"]+)"', text)
577
+ if isolate_source:
578
+ context += isolate_source.group(0) + " "
579
+
580
+ specimen = re.search(r'/(specimen|specimen_voucher)\s*=\s*"([^"]+)"', text)
581
+ if specimen:
582
+ context += specimen.group(0) + " "
583
+
584
+ if context.strip():
585
+ label, explain = detect_ancient_flag(context)
586
+ if label!="Unknown":
587
+ return label, explain + " from NCBI\n(" + context + ")"
588
+
589
+ # If no useful NCBI metadata, check supplementary texts
590
+ if textsToExtract:
591
+ labels = {"modern": [0, ""], "ancient": [0, ""], "unknown": 0}
592
+
593
+ for source in textsToExtract:
594
+ text_block = textsToExtract[source]
595
+ context = extract_relevant_paragraphs(text_block, accession, isolate=isolate) # Reduce to informative paragraph(s)
596
+ label, explain = detect_ancient_flag(context)
597
+
598
+ if label == "Ancient":
599
+ labels["ancient"][0] += 1
600
+ labels["ancient"][1] += f"{source}:\n{explain}\n\n"
601
+ elif label == "Modern":
602
+ labels["modern"][0] += 1
603
+ labels["modern"][1] += f"{source}:\n{explain}\n\n"
604
+ else:
605
+ labels["unknown"] += 1
606
+
607
+ if max(labels["modern"][0],labels["ancient"][0]) > 0:
608
+ if labels["modern"][0] > labels["ancient"][0]:
609
+ return "Modern", labels["modern"][1]
610
+ else:
611
+ return "Ancient", labels["ancient"][1]
612
+ else:
613
+ return "Unknown", "No strong keywords detected"
614
+ else:
615
+ print("No DOI or PubMed ID available for inference.")
616
+ return "", ""
617
+
618
+ except Exception as e:
619
+ print("Error:", e)
620
+ return "", ""
621
+
622
+
623
+ def detect_ancient_flag(context_snippet):
624
+ context = context_snippet.lower()
625
+
626
+ ancient_keywords = [
627
+ "ancient", "archaeological", "prehistoric", "neolithic", "mesolithic", "paleolithic",
628
+ "bronze age", "iron age", "burial", "tomb", "skeleton", "14c", "radiocarbon", "carbon dating",
629
+ "postmortem damage", "udg treatment", "adna", "degradation", "site", "excavation",
630
+ "archaeological context", "temporal transect", "population replacement", "cal bp", "calbp", "carbon dated"
631
+ ]
632
+
633
+ modern_keywords = [
634
+ "modern", "hospital", "clinical", "consent","blood","buccal","unrelated", "blood sample","buccal sample","informed consent", "donor", "healthy", "patient",
635
+ "genotyping", "screening", "medical", "cohort", "sequencing facility", "ethics approval",
636
+ "we analysed", "we analyzed", "dataset includes", "new sequences", "published data",
637
+ "control cohort", "sink population", "genbank accession", "sequenced", "pipeline",
638
+ "bioinformatic analysis", "samples from", "population genetics", "genome-wide data", "imr collection"
639
+ ]
640
+
641
+ ancient_hits = [k for k in ancient_keywords if k in context]
642
+ modern_hits = [k for k in modern_keywords if k in context]
643
+
644
+ if ancient_hits and not modern_hits:
645
+ return "Ancient", f"Flagged as ancient due to keywords: {', '.join(ancient_hits)}"
646
+ elif modern_hits and not ancient_hits:
647
+ return "Modern", f"Flagged as modern due to keywords: {', '.join(modern_hits)}"
648
+ elif ancient_hits and modern_hits:
649
+ if len(ancient_hits) >= len(modern_hits):
650
+ return "Ancient", f"Mixed context, leaning ancient due to: {', '.join(ancient_hits)}"
651
+ else:
652
+ return "Modern", f"Mixed context, leaning modern due to: {', '.join(modern_hits)}"
653
+
654
+ # Fallback to QA
655
+ answer = infer_fromQAModel(context, question="Are the mtDNA samples ancient or modern? Explain why.")
656
+ if answer.startswith("Error"):
657
+ return "Unknown", answer
658
+ if "ancient" in answer.lower():
659
+ return "Ancient", f"Leaning ancient based on QA: {answer}"
660
+ elif "modern" in answer.lower():
661
+ return "Modern", f"Leaning modern based on QA: {answer}"
662
+ else:
663
+ return "Unknown", f"No strong keywords or QA clues. QA said: {answer}"
664
+
665
+ # STEP 5: Main pipeline: accession -> 1. get pubmed id and isolate -> 2. get doi -> 3. get text -> 4. prediction -> 5. output: inferred location + explanation + confidence score
666
+ def classify_sample_location(accession):
667
+ outputs = {}
668
+ keyword, context, location, qa_result, haplo_result = "", "", "", "", ""
669
+ # Step 1: get pubmed id and isolate
670
+ pubmedID, isolate = get_info_from_accession(accession)
671
+ '''if not pubmedID:
672
+ return {"error": f"Could not retrieve PubMed ID for accession {accession}"}'''
673
+ if not isolate:
674
+ isolate = "UNKNOWN_ISOLATE"
675
+ # Step 2: get doi
676
+ doi = get_doi_from_pubmed_id(pubmedID)
677
+ '''if not doi:
678
+ return {"error": "DOI not found for this accession. Cannot fetch paper or context."}'''
679
+ # Step 3: get text
680
+ '''textsToExtract = { "doiLink": "paperText",
681
+ "file1.pdf": "text1",
682
+ "file2.doc": "text2",
683
+ "file3.xlsx": "excelText3" }'''
684
+ if doi and pubmedID:
685
+ textsToExtract = get_paper_text(doi,pubmedID)
686
+ else: textsToExtract = {}
687
+ '''if not textsToExtract:
688
+ return {"error": f"No texts extracted for DOI {doi}"}'''
689
+ if isolate not in [None, "UNKNOWN_ISOLATE"]:
690
+ label, explain = flag_ancient_modern(accession,textsToExtract,isolate)
691
+ else:
692
+ label, explain = flag_ancient_modern(accession,textsToExtract)
693
+ # Step 4: prediction
694
+ outputs[accession] = {}
695
+ outputs[isolate] = {}
696
+ # 4.0 Infer from NCBI
697
+ location, outputNCBI = infer_location_fromNCBI(accession)
698
+ NCBI_result = {
699
+ "source": "NCBI",
700
+ "sample_id": accession,
701
+ "predicted_location": location,
702
+ "context_snippet": outputNCBI}
703
+ outputs[accession]["NCBI"]= {"NCBI": NCBI_result}
704
+ if textsToExtract:
705
+ long_text = ""
706
+ for key in textsToExtract:
707
+ text = textsToExtract[key]
708
+ # try accession number first
709
+ outputs[accession][key] = {}
710
+ keyword = accession
711
+ context = extract_context(text, keyword, window=500)
712
+ # 4.1: Using a HuggingFace model (question-answering)
713
+ location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
714
+ qa_result = {
715
+ "source": key,
716
+ "sample_id": keyword,
717
+ "predicted_location": location,
718
+ "context_snippet": context
719
+ }
720
+ outputs[keyword][key]["QAModel"] = qa_result
721
+ # 4.2: Infer from haplogroup
722
+ haplo_result = classify_mtDNA_sample_from_haplo(context)
723
+ outputs[keyword][key]["haplogroup"] = haplo_result
724
+ # try isolate
725
+ keyword = isolate
726
+ outputs[isolate][key] = {}
727
+ context = extract_context(text, keyword, window=500)
728
+ # 4.1.1: Using a HuggingFace model (question-answering)
729
+ location = infer_fromQAModel(context, question=f"Where is the mtDNA sample {keyword} from?")
730
+ qa_result = {
731
+ "source": key,
732
+ "sample_id": keyword,
733
+ "predicted_location": location,
734
+ "context_snippet": context
735
+ }
736
+ outputs[keyword][key]["QAModel"] = qa_result
737
+ # 4.2.1: Infer from haplogroup
738
+ haplo_result = classify_mtDNA_sample_from_haplo(context)
739
+ outputs[keyword][key]["haplogroup"] = haplo_result
740
+ # add long text
741
+ long_text += text + ". \n"
742
+ # 4.3: UpgradeClassify
743
+ # try sample_id as accession number
744
+ sample_id = accession
745
+ if sample_id:
746
+ filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
747
+ locations = infer_location_for_sample(sample_id.upper(), filtered_context)
748
+ if locations!="No clear location found in top matches":
749
+ outputs[sample_id]["upgradeClassifier"] = {}
750
+ outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
751
+ "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
752
+ "sample_id": sample_id,
753
+ "predicted_location": ", ".join(locations),
754
+ "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
755
+ }
756
+ # try sample_id as isolate name
757
+ sample_id = isolate
758
+ if sample_id:
759
+ filtered_context = filter_context_for_sample(sample_id.upper(), long_text, window_size=1)
760
+ locations = infer_location_for_sample(sample_id.upper(), filtered_context)
761
+ if locations!="No clear location found in top matches":
762
+ outputs[sample_id]["upgradeClassifier"] = {}
763
+ outputs[sample_id]["upgradeClassifier"]["upgradeClassifier"] = {
764
+ "source": "From these sources combined: "+ ", ".join(list(textsToExtract.keys())),
765
+ "sample_id": sample_id,
766
+ "predicted_location": ", ".join(locations),
767
+ "context_snippep": "First 1000 words: \n"+ filtered_context[:1000]
768
+ }
769
  return outputs, label, explain
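End to end, the classifier is driven by one call; a hedged sketch (the accession is illustrative, and the run needs network access plus the spaCy/NLTK downloads performed at import time):

```python
from mtdna_classifier import classify_sample_location

outputs, label, explain = classify_sample_location("KX123456")   # illustrative accession
print(label)      # "Ancient", "Modern", "Unknown", or "" when nothing could be fetched
for sample_id, per_source in outputs.items():
    for source, predictions in per_source.items():
        print(sample_id, source, list(predictions))   # e.g. ["QAModel", "haplogroup"]
```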
pipeline.py CHANGED
The diff for this file is too large to render. See raw diff
 
smart_fallback.py CHANGED
@@ -1,402 +1,402 @@
1
- from Bio import Entrez, Medline
2
- #import model
3
- import mtdna_classifier
4
- from NER.html import extractHTML
5
- import data_preprocess
6
- import pipeline
7
- import aiohttp
8
- import asyncio
9
- # Setup
10
- def fetch_ncbi(accession_number):
11
- try:
12
- Entrez.email = "your.email@example.com" # Required by NCBI, REPLACE WITH YOUR EMAIL
13
- handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
14
- record = Entrez.read(handle)
15
- handle.close()
16
- outputs = {"authors":"unknown",
17
- "institution":"unknown",
18
- "isolate":"unknown",
19
- "definition":"unknown",
20
- "title":"unknown",
21
- "seq_comment":"unknown",
22
- "collection_date":"unknown" } #'GBSeq_update-date': '25-OCT-2023', 'GBSeq_create-date'
23
- gb_seq = None
24
- # Validate record structure: It should be a list with at least one element (a dict)
25
- if isinstance(record, list) and len(record) > 0:
26
- if isinstance(record[0], dict):
27
- gb_seq = record[0]
28
- else:
29
- print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
30
- # extract collection date
31
- if "GBSeq_create-date" in gb_seq and outputs["collection_date"]=="unknown":
32
- outputs["collection_date"] = gb_seq["GBSeq_create-date"]
33
- else:
34
- if "GBSeq_update-date" in gb_seq and outputs["collection_date"]=="unknown":
35
- outputs["collection_date"] = gb_seq["GBSeq_update-date"]
36
- # extract definition
37
- if "GBSeq_definition" in gb_seq and outputs["definition"]=="unknown":
38
- outputs["definition"] = gb_seq["GBSeq_definition"]
39
- # extract related-reference things
40
- if "GBSeq_references" in gb_seq:
41
- for ref in gb_seq["GBSeq_references"]:
42
- # extract authors
43
- if "GBReference_authors" in ref and outputs["authors"]=="unknown":
44
- outputs["authors"] = "and ".join(ref["GBReference_authors"])
45
- # extract title
46
- if "GBReference_title" in ref and outputs["title"]=="unknown":
47
- outputs["title"] = ref["GBReference_title"]
48
- # extract submitted journal
49
- if 'GBReference_journal' in ref and outputs["institution"]=="unknown":
50
- outputs["institution"] = ref['GBReference_journal']
51
- # extract seq_comment
52
- if 'GBSeq_comment'in gb_seq and outputs["seq_comment"]=="unknown":
53
- outputs["seq_comment"] = gb_seq["GBSeq_comment"]
54
- # extract isolate
55
- if "GBSeq_feature-table" in gb_seq:
56
- if 'GBFeature_quals' in gb_seq["GBSeq_feature-table"][0]:
57
- for ref in gb_seq["GBSeq_feature-table"][0]["GBFeature_quals"]:
58
- if ref['GBQualifier_name'] == "isolate" and outputs["isolate"]=="unknown":
59
- outputs["isolate"] = ref["GBQualifier_value"]
60
- else:
61
- print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
62
-
63
- # If gb_seq is still None, return defaults
64
- if gb_seq is None:
65
- return {"authors":"unknown",
66
- "institution":"unknown",
67
- "isolate":"unknown",
68
- "definition":"unknown",
69
- "title":"unknown",
70
- "seq_comment":"unknown",
71
- "collection_date":"unknown" }
72
- return outputs
73
- except:
74
- print("error in fetching ncbi data")
75
- return {"authors":"unknown",
76
- "institution":"unknown",
77
- "isolate":"unknown",
78
- "definition":"unknown",
79
- "title":"unknown",
80
- "seq_comment":"unknown",
81
- "collection_date":"unknown" }
82
- # Fallback if NCBI crashed or cannot find accession on NBCI
83
- def google_accession_search(accession_id):
84
- """
85
- Search for metadata by accession ID using Google Custom Search.
86
- Falls back to known biological databases and archives.
87
- """
88
- queries = [
89
- f"{accession_id}",
90
- f"{accession_id} site:ncbi.nlm.nih.gov",
91
- f"{accession_id} site:pubmed.ncbi.nlm.nih.gov",
92
- f"{accession_id} site:europepmc.org",
93
- f"{accession_id} site:researchgate.net",
94
- f"{accession_id} mtDNA",
95
- f"{accession_id} mitochondrial DNA"
96
- ]
97
-
98
- links = []
99
- for query in queries:
100
- search_results = mtdna_classifier.search_google_custom(query, 2)
101
- for link in search_results:
102
- if link not in links:
103
- links.append(link)
104
- return links
105
-
106
- # Method 1: Smarter Google
107
- def smart_google_queries(metadata: dict):
108
- queries = []
109
-
110
- # Extract useful fields
111
- isolate = metadata.get("isolate")
112
- author = metadata.get("authors")
113
- institution = metadata.get("institution")
114
- title = metadata.get("title")
115
- combined = []
116
- # Construct queries
117
- if isolate and isolate!="unknown" and isolate!="Unpublished":
118
- queries.append(f'"{isolate}" mitochondrial DNA')
119
- queries.append(f'"{isolate}" site:ncbi.nlm.nih.gov')
120
-
121
- if author and author!="unknown" and author!="Unpublished":
122
- # try:
123
- # author_name = ".".join(author.split(' ')[0].split(".")[:-1]) # Use last name only
124
- # except:
125
- # try:
126
- # author_name = author.split(',')[0] # Use last name only
127
- # except:
128
- # author_name = author
129
- try:
130
- author_name = author.split(',')[0] # Use last name only
131
- except:
132
- author_name = author
133
- queries.append(f'"{author_name}" mitochondrial DNA')
134
- queries.append(f'"{author_name}" mtDNA site:researchgate.net')
135
-
136
- if institution and institution!="unknown" and institution!="Unpublished":
137
- try:
138
- short_inst = ",".join(institution.split(',')[:2]) # Take first part of institution
139
- except:
140
- try:
141
- short_inst = institution.split(',')[0]
142
- except:
143
- short_inst = institution
144
- queries.append(f'"{short_inst}" mtDNA sequence')
145
- #queries.append(f'"{short_inst}" isolate site:nature.com')
146
- if title and title!='unknown' and title!="Unpublished":
147
- if title!="Direct Submission":
148
- queries.append(title)
149
-
150
- return queries
151
-
152
- # def filter_links_by_metadata(search_results, saveLinkFolder, accession=None, stop_flag=None):
153
- # TRUSTED_DOMAINS = [
154
- # "ncbi.nlm.nih.gov",
155
- # "pubmed.ncbi.nlm.nih.gov",
156
- # "pmc.ncbi.nlm.nih.gov",
157
- # "biorxiv.org",
158
- # "researchgate.net",
159
- # "nature.com",
160
- # "sciencedirect.com"
161
- # ]
162
- # if stop_flag is not None and stop_flag.value:
163
- # print(f"🛑 Stop detected {accession}, aborting early...")
164
- # return []
165
- # def is_trusted_link(link):
166
- # for domain in TRUSTED_DOMAINS:
167
- # if domain in link:
168
- # return True
169
- # return False
170
- # def is_relevant_title_snippet(link, saveLinkFolder, accession=None):
171
- # output = []
172
- # keywords = ["mtDNA", "mitochondrial", "accession", "isolate", "Homo sapiens", "sequence"]
173
- # if accession:
174
- # keywords = [accession] + keywords
175
- # title_snippet = link.lower()
176
- # print("save link folder inside this filter function: ", saveLinkFolder)
177
- # success_process, output_process = pipeline.run_with_timeout(data_preprocess.extract_text,args=(link,saveLinkFolder),timeout=60)
178
- # if stop_flag is not None and stop_flag.value:
179
- # print(f"🛑 Stop detected {accession}, aborting early...")
180
- # return []
181
- # if success_process:
182
- # article_text = output_process
183
- # print("yes succeed for getting article text")
184
- # else:
185
- # print("no suceed, fallback to no link")
186
- # article_text = ""
187
- # #article_text = data_preprocess.extract_text(link,saveLinkFolder)
188
- # print("article text")
189
- # #print(article_text)
190
- # if stop_flag is not None and stop_flag.value:
191
- # print(f"🛑 Stop detected {accession}, aborting early...")
192
- # return []
193
- # try:
194
- # ext = link.split(".")[-1].lower()
195
- # if ext not in ["pdf", "docx", "xlsx"]:
196
- # html = extractHTML.HTML("", link)
197
- # if stop_flag is not None and stop_flag.value:
198
- # print(f"🛑 Stop detected {accession}, aborting early...")
199
- # return []
200
- # jsonSM = html.getSupMaterial()
201
- # if jsonSM:
202
- # output += sum((jsonSM[key] for key in jsonSM), [])
203
- # except Exception:
204
- # pass # continue silently
205
- # for keyword in keywords:
206
- # if keyword.lower() in article_text.lower():
207
- # if link not in output:
208
- # output.append([link,keyword.lower()])
209
- # print("link and keyword for article text: ", link, keyword)
210
- # return output
211
- # if keyword.lower() in title_snippet.lower():
212
- # if link not in output:
213
- # output.append([link,keyword.lower()])
214
- # print("link and keyword for title: ", link, keyword)
215
- # return output
216
- # return output
217
-
218
- # filtered = []
219
- # better_filter = []
220
- # if len(search_results) > 0:
221
- # for link in search_results:
222
- # # if is_trusted_link(link):
223
- # # if link not in filtered:
224
- # # filtered.append(link)
225
- # # else:
226
- # print(link)
227
- # if stop_flag is not None and stop_flag.value:
228
- # print(f"🛑 Stop detected {accession}, aborting early...")
229
- # return []
230
- # if link:
231
- # output_link = is_relevant_title_snippet(link,saveLinkFolder, accession)
232
- # print("output link: ")
233
- # print(output_link)
234
- # for out_link in output_link:
235
- # if isinstance(out_link,list) and len(out_link) > 1:
236
- # print(out_link)
237
- # kw = out_link[1]
238
- # print("kw and acc: ", kw, accession.lower())
239
- # if accession and kw == accession.lower():
240
- # better_filter.append(out_link[0])
241
- # filtered.append(out_link[0])
242
- # else: filtered.append(out_link)
243
- # print("done with link and here is filter: ",filtered)
244
- # if better_filter:
245
- # filtered = better_filter
246
- # return filtered
247
- async def process_link(session, link, saveLinkFolder, keywords, accession):
248
- output = []
249
- title_snippet = link.lower()
250
-
251
- # use async extractor for web, fallback to sync for local files
252
- if link.startswith("http"):
253
- article_text = await data_preprocess.async_extract_text(link, saveLinkFolder)
254
- else:
255
- article_text = data_preprocess.extract_text(link, saveLinkFolder)
256
-
257
- for keyword in keywords:
258
- if article_text and keyword.lower() in article_text.lower():
259
- output.append([link, keyword.lower(), article_text])
260
- return output
261
- if keyword.lower() in title_snippet:
262
- output.append([link, keyword.lower()])
263
- return output
264
- return output
265
-
266
- async def async_filter_links_by_metadata(search_results, saveLinkFolder, accession=None):
267
- TRUSTED_DOMAINS = [
268
- "ncbi.nlm.nih.gov", "pubmed.ncbi.nlm.nih.gov", "pmc.ncbi.nlm.nih.gov",
269
- "biorxiv.org", "researchgate.net", "nature.com", "sciencedirect.com"
270
- ]
271
-
272
- keywords = ["mtDNA", "mitochondrial", "accession", "isolate", "Homo sapiens", "sequence"]
273
- if accession:
274
- keywords = [accession] + keywords
275
-
276
- filtered, better_filter = {}, {}
277
- print("before doing session")
278
- async with aiohttp.ClientSession() as session:
279
- tasks = []
280
- for link in search_results:
281
- if link:
282
- print("link: ", link)
283
- tasks.append(process_link(session, link, saveLinkFolder, keywords, accession))
284
- print("done")
285
- results = await asyncio.gather(*tasks)
286
- print("outside session")
287
- # merge results
288
- for output_link in results:
289
- for out_link in output_link:
290
- if isinstance(out_link, list) and len(out_link) > 1:
291
- kw = out_link[1]
292
- if accession and kw == accession.lower():
293
- if len(out_link) == 2:
294
- better_filter[out_link[0]] = ""
295
- elif len(out_link) == 3:
296
- better_filter[out_link[0]] = out_link[2]
297
- if len(out_link) == 2:
298
- better_filter[out_link[0]] = ""
299
- elif len(out_link) == 3:
300
- better_filter[out_link[0]] = out_link[2]
301
- else:
302
- filtered[out_link] = ""
303
-
304
- return better_filter or filtered
305
-
306
- def filter_links_by_metadata(search_results, saveLinkFolder, accession=None):
307
- TRUSTED_DOMAINS = [
308
- "ncbi.nlm.nih.gov",
309
- "pubmed.ncbi.nlm.nih.gov",
310
- "pmc.ncbi.nlm.nih.gov",
311
- "biorxiv.org",
312
- "researchgate.net",
313
- "nature.com",
314
- "sciencedirect.com"
315
- ]
316
- def is_trusted_link(link):
317
- for domain in TRUSTED_DOMAINS:
318
- if domain in link:
319
- return True
320
- return False
321
- def is_relevant_title_snippet(link, saveLinkFolder, accession=None):
322
- output = []
323
- keywords = ["mtDNA", "mitochondrial", "Homo sapiens"]
324
- #keywords = ["mtDNA", "mitochondrial"]
325
- if accession:
326
- keywords = [accession] + keywords
327
- title_snippet = link.lower()
328
- #print("save link folder inside this filter function: ", saveLinkFolder)
329
- article_text = data_preprocess.extract_text(link,saveLinkFolder)
330
- print("article text done")
331
- #print(article_text)
332
- try:
333
- ext = link.split(".")[-1].lower()
334
- if ext not in ["pdf", "docx", "xlsx"]:
335
- html = extractHTML.HTML("", link)
336
- jsonSM = html.getSupMaterial()
337
- if jsonSM:
338
- output += sum((jsonSM[key] for key in jsonSM), [])
339
- except Exception:
340
- pass # continue silently
341
- for keyword in keywords:
342
- if article_text:
343
- if keyword.lower() in article_text.lower():
344
- if link not in output:
345
- output.append([link,keyword.lower(), article_text])
346
- return output
347
- if keyword.lower() in title_snippet.lower():
348
- if link not in output:
349
- output.append([link,keyword.lower()])
350
- print("link and keyword for title: ", link, keyword)
351
- return output
352
- return output
353
-
354
- filtered = {}
355
- better_filter = {}
356
- if len(search_results) > 0:
357
- print(search_results)
358
- for link in search_results:
359
- # if is_trusted_link(link):
360
- # if link not in filtered:
361
- # filtered.append(link)
362
- # else:
363
- print(link)
364
- if link:
365
- output_link = is_relevant_title_snippet(link,saveLinkFolder, accession)
366
- print("output link: ")
367
- print(output_link)
368
- for out_link in output_link:
369
- if isinstance(out_link,list) and len(out_link) > 1:
370
- print(out_link)
371
- kw = out_link[1]
372
- if accession and kw == accession.lower():
373
- if len(out_link) == 2:
374
- better_filter[out_link[0]] = ""
375
- elif len(out_link) == 3:
376
- # save article
377
- better_filter[out_link[0]] = out_link[2]
378
- if len(out_link) == 2:
379
- better_filter[out_link[0]] = ""
380
- elif len(out_link) == 3:
381
- # save article
382
- better_filter[out_link[0]] = out_link[2]
383
- else: filtered[out_link] = ""
384
- print("done with link and here is filter: ",filtered)
385
- if better_filter:
386
- filtered = better_filter
387
- return filtered
388
-
389
- def smart_google_search(metadata):
390
- queries = smart_google_queries(metadata)
391
- links = []
392
- for q in queries:
393
- #print("\n🔍 Query:", q)
394
- results = mtdna_classifier.search_google_custom(q,2)
395
- for link in results:
396
- #print(f"- {link}")
397
- if link not in links:
398
- links.append(link)
399
- #filter_links = filter_links_by_metadata(links)
400
- return links
401
- # Method 2: Prompt LLM better or better ai search api with all
402
  # the total information from even ncbi and all search
 
1
+ from Bio import Entrez, Medline
2
+ #import model
3
+ import mtdna_classifier
4
+ from NER.html import extractHTML
5
+ import data_preprocess
6
+ import pipeline
7
+ import aiohttp
8
+ import asyncio
9
+ # Setup
10
+ def fetch_ncbi(accession_number):
11
+ try:
12
+ Entrez.email = "your.email@example.com" # Required by NCBI, REPLACE WITH YOUR EMAIL
13
+ handle = Entrez.efetch(db="nucleotide", id=str(accession_number), rettype="gb", retmode="xml")
14
+ record = Entrez.read(handle)
15
+ handle.close()
16
+ outputs = {"authors":"unknown",
17
+ "institution":"unknown",
18
+ "isolate":"unknown",
19
+ "definition":"unknown",
20
+ "title":"unknown",
21
+ "seq_comment":"unknown",
22
+ "collection_date":"unknown" } #'GBSeq_update-date': '25-OCT-2023', 'GBSeq_create-date'
23
+ gb_seq = None
24
+ # Validate record structure: It should be a list with at least one element (a dict)
25
+ if isinstance(record, list) and len(record) > 0:
26
+ if isinstance(record[0], dict):
27
+ gb_seq = record[0]
28
+ else:
29
+ print(f"Warning: record[0] is not a dictionary for {accession_number}. Type: {type(record[0])}")
30
+ # extract collection date
31
+ if "GBSeq_create-date" in gb_seq and outputs["collection_date"]=="unknown":
32
+ outputs["collection_date"] = gb_seq["GBSeq_create-date"]
33
+ else:
34
+ if "GBSeq_update-date" in gb_seq and outputs["collection_date"]=="unknown":
35
+ outputs["collection_date"] = gb_seq["GBSeq_update-date"]
36
+ # extract definition
37
+ if "GBSeq_definition" in gb_seq and outputs["definition"]=="unknown":
38
+ outputs["definition"] = gb_seq["GBSeq_definition"]
39
+ # extract related-reference things
40
+ if "GBSeq_references" in gb_seq:
41
+ for ref in gb_seq["GBSeq_references"]:
42
+ # extract authors
43
+ if "GBReference_authors" in ref and outputs["authors"]=="unknown":
44
+ outputs["authors"] = "and ".join(ref["GBReference_authors"])
45
+ # extract title
46
+ if "GBReference_title" in ref and outputs["title"]=="unknown":
47
+ outputs["title"] = ref["GBReference_title"]
48
+ # extract submitted journal
49
+ if 'GBReference_journal' in ref and outputs["institution"]=="unknown":
50
+ outputs["institution"] = ref['GBReference_journal']
51
+ # extract seq_comment
52
+ if 'GBSeq_comment' in gb_seq and outputs["seq_comment"]=="unknown":
53
+ outputs["seq_comment"] = gb_seq["GBSeq_comment"]
54
+ # extract isolate
55
+ if "GBSeq_feature-table" in gb_seq:
56
+ if 'GBFeature_quals' in gb_seq["GBSeq_feature-table"][0]:
57
+ for ref in gb_seq["GBSeq_feature-table"][0]["GBFeature_quals"]:
58
+ if ref['GBQualifier_name'] == "isolate" and outputs["isolate"]=="unknown":
59
+ outputs["isolate"] = ref["GBQualifier_value"]
60
+ else:
61
+ print(f"Warning: No valid record or empty record list from NCBI for {accession_number}.")
62
+
63
+ # If gb_seq is still None, return defaults
64
+ if gb_seq is None:
65
+ return {"authors":"unknown",
66
+ "institution":"unknown",
67
+ "isolate":"unknown",
68
+ "definition":"unknown",
69
+ "title":"unknown",
70
+ "seq_comment":"unknown",
71
+ "collection_date":"unknown" }
72
+ return outputs
73
+ except Exception as e:
74
+ print(f"error in fetching ncbi data: {e}")
75
+ return {"authors":"unknown",
76
+ "institution":"unknown",
77
+ "isolate":"unknown",
78
+ "definition":"unknown",
79
+ "title":"unknown",
80
+ "seq_comment":"unknown",
81
+ "collection_date":"unknown" }
82
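A minimal usage sketch for fetch_ncbi above, hedged: the module name smart_fallback comes from this repository's file name, and the accession string is a placeholder to be replaced with a real GenBank ID.
# Hypothetical usage (not part of the commit): pull GenBank metadata for one
# accession and inspect the fields that later drive the Google queries.
import smart_fallback

meta = smart_fallback.fetch_ncbi("AB123456")  # placeholder accession ID
for field in ("authors", "institution", "isolate", "definition",
              "title", "seq_comment", "collection_date"):
    # every field defaults to "unknown" when GenBank has no value for it
    print(field, "->", meta.get(field, "unknown"))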
+ # Fallback if NCBI crashed or the accession cannot be found on NCBI
83
+ def google_accession_search(accession_id):
84
+ """
85
+ Search for metadata by accession ID using Google Custom Search.
86
+ Falls back to known biological databases and archives.
87
+ """
88
+ queries = [
89
+ f"{accession_id}",
90
+ f"{accession_id} site:ncbi.nlm.nih.gov",
91
+ f"{accession_id} site:pubmed.ncbi.nlm.nih.gov",
92
+ f"{accession_id} site:europepmc.org",
93
+ f"{accession_id} site:researchgate.net",
94
+ f"{accession_id} mtDNA",
95
+ f"{accession_id} mitochondrial DNA"
96
+ ]
97
+
98
+ links = []
99
+ for query in queries:
100
+ search_results = mtdna_classifier.search_google_custom(query, 2)
101
+ for link in search_results:
102
+ if link not in links:
103
+ links.append(link)
104
+ return links
105
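A sketch of how the two lookups could be chained; the fall-back rule (use the Google accession search only when every GenBank field is "unknown") is an illustration, not something the commit itself enforces.
# Hypothetical fallback chain: Entrez first, plain accession search second.
import smart_fallback

def lookup(accession_id):
    meta = smart_fallback.fetch_ncbi(accession_id)
    if all(value == "unknown" for value in meta.values()):
        # NCBI returned nothing usable; fall back to Google Custom Search links
        return meta, smart_fallback.google_accession_search(accession_id)
    return meta, []

metadata, fallback_links = lookup("AB123456")  # placeholder accession ID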
+
106
+ # Method 1: Smarter Google
107
+ def smart_google_queries(metadata: dict):
108
+ queries = []
109
+
110
+ # Extract useful fields
111
+ isolate = metadata.get("isolate")
112
+ author = metadata.get("authors")
113
+ institution = metadata.get("institution")
114
+ title = metadata.get("title")
115
+ combined = []
116
+ # Construct queries
117
+ if isolate and isolate!="unknown" and isolate!="Unpublished":
118
+ queries.append(f'"{isolate}" mitochondrial DNA')
119
+ queries.append(f'"{isolate}" site:ncbi.nlm.nih.gov')
120
+
121
+ if author and author!="unknown" and author!="Unpublished":
122
+ # try:
123
+ # author_name = ".".join(author.split(' ')[0].split(".")[:-1]) # Use last name only
124
+ # except:
125
+ # try:
126
+ # author_name = author.split(',')[0] # Use last name only
127
+ # except:
128
+ # author_name = author
129
+ try:
130
+ author_name = author.split(',')[0] # Use last name only
131
+ except:
132
+ author_name = author
133
+ queries.append(f'"{author_name}" mitochondrial DNA')
134
+ queries.append(f'"{author_name}" mtDNA site:researchgate.net')
135
+
136
+ if institution and institution!="unknown" and institution!="Unpublished":
137
+ try:
138
+ short_inst = ",".join(institution.split(',')[:2]) # Take first part of institution
139
+ except:
140
+ try:
141
+ short_inst = institution.split(',')[0]
142
+ except:
143
+ short_inst = institution
144
+ queries.append(f'"{short_inst}" mtDNA sequence')
145
+ #queries.append(f'"{short_inst}" isolate site:nature.com')
146
+ if title and title!='unknown' and title!="Unpublished":
147
+ if title!="Direct Submission":
148
+ queries.append(title)
149
+
150
+ return queries
151
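To make the query-construction rules above concrete, a hedged example with an invented metadata dict; it assumes smart_google_queries is called from within this module.
# Example input/output for smart_google_queries (all values invented).
example_metadata = {
    "isolate": "XYZ123",
    "authors": "Doe,J. and Smith,A.",
    "institution": "Submitted (01-JAN-2020) Dept. of Biology, Example University, Example City",
    "title": "Mitochondrial genomes of example populations",
}
for q in smart_google_queries(example_metadata):
    print(q)
# Expected queries, one per line, e.g.:
#   "XYZ123" mitochondrial DNA
#   "XYZ123" site:ncbi.nlm.nih.gov
#   "Doe" mitochondrial DNA
#   "Doe" mtDNA site:researchgate.net
#   ...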
+
152
+ # def filter_links_by_metadata(search_results, saveLinkFolder, accession=None, stop_flag=None):
153
+ # TRUSTED_DOMAINS = [
154
+ # "ncbi.nlm.nih.gov",
155
+ # "pubmed.ncbi.nlm.nih.gov",
156
+ # "pmc.ncbi.nlm.nih.gov",
157
+ # "biorxiv.org",
158
+ # "researchgate.net",
159
+ # "nature.com",
160
+ # "sciencedirect.com"
161
+ # ]
162
+ # if stop_flag is not None and stop_flag.value:
163
+ # print(f"🛑 Stop detected {accession}, aborting early...")
164
+ # return []
165
+ # def is_trusted_link(link):
166
+ # for domain in TRUSTED_DOMAINS:
167
+ # if domain in link:
168
+ # return True
169
+ # return False
170
+ # def is_relevant_title_snippet(link, saveLinkFolder, accession=None):
171
+ # output = []
172
+ # keywords = ["mtDNA", "mitochondrial", "accession", "isolate", "Homo sapiens", "sequence"]
173
+ # if accession:
174
+ # keywords = [accession] + keywords
175
+ # title_snippet = link.lower()
176
+ # print("save link folder inside this filter function: ", saveLinkFolder)
177
+ # success_process, output_process = pipeline.run_with_timeout(data_preprocess.extract_text,args=(link,saveLinkFolder),timeout=60)
178
+ # if stop_flag is not None and stop_flag.value:
179
+ # print(f"🛑 Stop detected {accession}, aborting early...")
180
+ # return []
181
+ # if success_process:
182
+ # article_text = output_process
183
+ # print("yes succeed for getting article text")
184
+ # else:
185
+ # print("no suceed, fallback to no link")
186
+ # article_text = ""
187
+ # #article_text = data_preprocess.extract_text(link,saveLinkFolder)
188
+ # print("article text")
189
+ # #print(article_text)
190
+ # if stop_flag is not None and stop_flag.value:
191
+ # print(f"🛑 Stop detected {accession}, aborting early...")
192
+ # return []
193
+ # try:
194
+ # ext = link.split(".")[-1].lower()
195
+ # if ext not in ["pdf", "docx", "xlsx"]:
196
+ # html = extractHTML.HTML("", link)
197
+ # if stop_flag is not None and stop_flag.value:
198
+ # print(f"🛑 Stop detected {accession}, aborting early...")
199
+ # return []
200
+ # jsonSM = html.getSupMaterial()
201
+ # if jsonSM:
202
+ # output += sum((jsonSM[key] for key in jsonSM), [])
203
+ # except Exception:
204
+ # pass # continue silently
205
+ # for keyword in keywords:
206
+ # if keyword.lower() in article_text.lower():
207
+ # if link not in output:
208
+ # output.append([link,keyword.lower()])
209
+ # print("link and keyword for article text: ", link, keyword)
210
+ # return output
211
+ # if keyword.lower() in title_snippet.lower():
212
+ # if link not in output:
213
+ # output.append([link,keyword.lower()])
214
+ # print("link and keyword for title: ", link, keyword)
215
+ # return output
216
+ # return output
217
+
218
+ # filtered = []
219
+ # better_filter = []
220
+ # if len(search_results) > 0:
221
+ # for link in search_results:
222
+ # # if is_trusted_link(link):
223
+ # # if link not in filtered:
224
+ # # filtered.append(link)
225
+ # # else:
226
+ # print(link)
227
+ # if stop_flag is not None and stop_flag.value:
228
+ # print(f"🛑 Stop detected {accession}, aborting early...")
229
+ # return []
230
+ # if link:
231
+ # output_link = is_relevant_title_snippet(link,saveLinkFolder, accession)
232
+ # print("output link: ")
233
+ # print(output_link)
234
+ # for out_link in output_link:
235
+ # if isinstance(out_link,list) and len(out_link) > 1:
236
+ # print(out_link)
237
+ # kw = out_link[1]
238
+ # print("kw and acc: ", kw, accession.lower())
239
+ # if accession and kw == accession.lower():
240
+ # better_filter.append(out_link[0])
241
+ # filtered.append(out_link[0])
242
+ # else: filtered.append(out_link)
243
+ # print("done with link and here is filter: ",filtered)
244
+ # if better_filter:
245
+ # filtered = better_filter
246
+ # return filtered
247
+ async def process_link(session, link, saveLinkFolder, keywords, accession):
248
+ output = []
249
+ title_snippet = link.lower()
250
+
251
+ # use async extractor for web, fallback to sync for local files
252
+ if link.startswith("http"):
253
+ article_text = await data_preprocess.async_extract_text(link, saveLinkFolder)
254
+ else:
255
+ article_text = data_preprocess.extract_text(link, saveLinkFolder)
256
+
257
+ for keyword in keywords:
258
+ if article_text and keyword.lower() in article_text.lower():
259
+ output.append([link, keyword.lower(), article_text])
260
+ return output
261
+ if keyword.lower() in title_snippet:
262
+ output.append([link, keyword.lower()])
263
+ return output
264
+ return output
265
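A small driver for the process_link coroutine, hedged: the folder, keywords, and URL are placeholders, and the aiohttp session is passed through even though the coroutine currently delegates fetching to data_preprocess.
# Hypothetical single-link check built on the coroutine above.
import asyncio
import aiohttp

async def check_one(link, save_folder="downloads", accession="AB123456"):
    keywords = [accession, "mtDNA", "mitochondrial"]
    async with aiohttp.ClientSession() as session:
        return await process_link(session, link, save_folder, keywords, accession)

# hits = asyncio.run(check_one("https://example.org/article"))  # placeholder URL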
+
266
+ async def async_filter_links_by_metadata(search_results, saveLinkFolder, accession=None):
267
+ TRUSTED_DOMAINS = [
268
+ "ncbi.nlm.nih.gov", "pubmed.ncbi.nlm.nih.gov", "pmc.ncbi.nlm.nih.gov",
269
+ "biorxiv.org", "researchgate.net", "nature.com", "sciencedirect.com"
270
+ ]
271
+
272
+ keywords = ["mtDNA", "mitochondrial", "accession", "isolate", "Homo sapiens", "sequence"]
273
+ if accession:
274
+ keywords = [accession] + keywords
275
+
276
+ filtered, better_filter = {}, {}
277
+ print("before doing session")
278
+ async with aiohttp.ClientSession() as session:
279
+ tasks = []
280
+ for link in search_results:
281
+ if link:
282
+ print("link: ", link)
283
+ tasks.append(process_link(session, link, saveLinkFolder, keywords, accession))
284
+ print("done")
285
+ results = await asyncio.gather(*tasks)
286
+ print("outside session")
287
+ # merge results
288
+ for output_link in results:
289
+ for out_link in output_link:
290
+ if isinstance(out_link, list) and len(out_link) > 1:
291
+ kw = out_link[1]
292
+ if accession and kw == accession.lower():
293
+ if len(out_link) == 2:
294
+ better_filter[out_link[0]] = ""
295
+ elif len(out_link) == 3:
296
+ better_filter[out_link[0]] = out_link[2]
297
+ if len(out_link) == 2:
298
+ filtered[out_link[0]] = ""
299
+ elif len(out_link) == 3:
300
+ filtered[out_link[0]] = out_link[2]
301
+ else:
302
+ filtered[out_link] = ""
303
+
304
+ return better_filter or filtered
305
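A sketch of how the async filter might be driven from synchronous code; the links and folder are placeholders.
# Hypothetical driver for async_filter_links_by_metadata. The result maps each
# kept link to its extracted article text (or "" when only the URL matched);
# links whose text contains the accession itself are preferred when one is given.
import asyncio

candidate_links = [
    "https://example.org/paper-1",   # placeholder URLs
    "https://example.org/paper-2",
]
kept = asyncio.run(
    async_filter_links_by_metadata(candidate_links, "downloads", accession="AB123456")
)
print(list(kept.keys()))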
+
306
+ def filter_links_by_metadata(search_results, saveLinkFolder, accession=None):
307
+ TRUSTED_DOMAINS = [
308
+ "ncbi.nlm.nih.gov",
309
+ "pubmed.ncbi.nlm.nih.gov",
310
+ "pmc.ncbi.nlm.nih.gov",
311
+ "biorxiv.org",
312
+ "researchgate.net",
313
+ "nature.com",
314
+ "sciencedirect.com"
315
+ ]
316
+ def is_trusted_link(link):
317
+ for domain in TRUSTED_DOMAINS:
318
+ if domain in link:
319
+ return True
320
+ return False
321
+ def is_relevant_title_snippet(link, saveLinkFolder, accession=None):
322
+ output = []
323
+ keywords = ["mtDNA", "mitochondrial", "Homo sapiens"]
324
+ #keywords = ["mtDNA", "mitochondrial"]
325
+ if accession:
326
+ keywords = [accession] + keywords
327
+ title_snippet = link.lower()
328
+ #print("save link folder inside this filter function: ", saveLinkFolder)
329
+ article_text = data_preprocess.extract_text(link,saveLinkFolder)
330
+ print("article text done")
331
+ #print(article_text)
332
+ try:
333
+ ext = link.split(".")[-1].lower()
334
+ if ext not in ["pdf", "docx", "xlsx"]:
335
+ html = extractHTML.HTML("", link)
336
+ jsonSM = html.getSupMaterial()
337
+ if jsonSM:
338
+ output += sum((jsonSM[key] for key in jsonSM), [])
339
+ except Exception:
340
+ pass # continue silently
341
+ for keyword in keywords:
342
+ if article_text:
343
+ if keyword.lower() in article_text.lower():
344
+ if link not in output:
345
+ output.append([link,keyword.lower(), article_text])
346
+ return output
347
+ if keyword.lower() in title_snippet.lower():
348
+ if link not in output:
349
+ output.append([link,keyword.lower()])
350
+ print("link and keyword for title: ", link, keyword)
351
+ return output
352
+ return output
353
+
354
+ filtered = {}
355
+ better_filter = {}
356
+ if len(search_results) > 0:
357
+ print(search_results)
358
+ for link in search_results:
359
+ # if is_trusted_link(link):
360
+ # if link not in filtered:
361
+ # filtered.append(link)
362
+ # else:
363
+ print(link)
364
+ if link:
365
+ output_link = is_relevant_title_snippet(link,saveLinkFolder, accession)
366
+ print("output link: ")
367
+ print(output_link)
368
+ for out_link in output_link:
369
+ if isinstance(out_link,list) and len(out_link) > 1:
370
+ print(out_link)
371
+ kw = out_link[1]
372
+ if accession and kw == accession.lower():
373
+ if len(out_link) == 2:
374
+ better_filter[out_link[0]] = ""
375
+ elif len(out_link) == 3:
376
+ # save article
377
+ better_filter[out_link[0]] = out_link[2]
378
+ if len(out_link) == 2:
379
+ filtered[out_link[0]] = ""
380
+ elif len(out_link) == 3:
381
+ # save article
382
+ filtered[out_link[0]] = out_link[2]
383
+ else: filtered[out_link] = ""
384
+ print("done with link and here is filter: ",filtered)
385
+ if better_filter:
386
+ filtered = better_filter
387
+ return filtered
388
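The synchronous variant above returns the same {link: article_text} mapping; a minimal sketch with placeholder arguments.
# Hypothetical synchronous call, mirroring the async driver above.
links = ["https://example.org/paper-1"]  # placeholder URL
kept = filter_links_by_metadata(links, "downloads", accession="AB123456")
for link, article_text in kept.items():
    print(link, len(article_text))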
+
389
+ def smart_google_search(metadata):
390
+ queries = smart_google_queries(metadata)
391
+ links = []
392
+ for q in queries:
393
+ #print("\n🔍 Query:", q)
394
+ results = mtdna_classifier.search_google_custom(q,2)
395
+ for link in results:
396
+ #print(f"- {link}")
397
+ if link not in links:
398
+ links.append(link)
399
+ #filter_links = filter_links_by_metadata(links)
400
+ return links
401
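Putting the pieces together, a hedged end-to-end sketch of the fallback path this module appears to implement; the accession and folder name are placeholders.
# Hypothetical end-to-end flow: GenBank metadata -> smart queries -> filtered links.
import smart_fallback

accession = "AB123456"                      # placeholder accession ID
meta = smart_fallback.fetch_ncbi(accession)
links = smart_fallback.smart_google_search(meta)
if not links:
    # metadata-based queries found nothing; fall back to the raw accession search
    links = smart_fallback.google_accession_search(accession)
kept = smart_fallback.filter_links_by_metadata(links, "downloads", accession=accession)
print(f"{len(kept)} candidate sources kept for {accession}")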
+ # Method 2: Prompt LLM better or better ai search api with all
402
  # the total information from even ncbi and all search