VyLala committed on
Commit
9be33a8
·
verified ·
1 Parent(s): d9ef8dd

Update mtdna_backend.py

Files changed (1)
  1. mtdna_backend.py +1032 -944
mtdna_backend.py CHANGED
@@ -1,945 +1,1033 @@
1
- import gradio as gr
2
- from collections import Counter
3
- import csv
4
- import os
5
- from functools import lru_cache
6
- #import app
7
- from mtdna_classifier import classify_sample_location
8
- import data_preprocess, model, pipeline
9
- import subprocess
10
- import json
11
- import pandas as pd
12
- import io
13
- import re
14
- import tempfile
15
- import gspread
16
- from oauth2client.service_account import ServiceAccountCredentials
17
- from io import StringIO
18
- import hashlib
19
- import threading
20
-
21
- # @lru_cache(maxsize=3600)
22
- # def classify_sample_location_cached(accession):
23
- # return classify_sample_location(accession)
24
-
25
- #@lru_cache(maxsize=3600)
26
- def pipeline_classify_sample_location_cached(accession,stop_flag=None, save_df=None):
27
- print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
28
- print("len of save df: ", len(save_df))
29
- return pipeline.pipeline_with_gemini([accession],stop_flag=stop_flag, save_df=save_df)
30
-
31
- # Count and suggest final location
32
- # def compute_final_suggested_location(rows):
33
- # candidates = [
34
- # row.get("Predicted Location", "").strip()
35
- # for row in rows
36
- # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
37
- # ] + [
38
- # row.get("Inferred Region", "").strip()
39
- # for row in rows
40
- # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
41
- # ]
42
-
43
- # if not candidates:
44
- # return Counter(), ("Unknown", 0)
45
- # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
46
- # tokens = []
47
- # for item in candidates:
48
- # # Split by comma, whitespace, and newlines
49
- # parts = re.split(r'[\s,]+', item)
50
- # tokens.extend(parts)
51
-
52
- # # Step 2: Clean and normalize tokens
53
- # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
54
-
55
- # # Step 3: Count
56
- # counts = Counter(tokens)
57
-
58
- # # Step 4: Get most common
59
- # top_location, count = counts.most_common(1)[0]
60
- # return counts, (top_location, count)
61
-
62
- # Store feedback (with required fields)
63
-
64
- def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
65
- if not answer1.strip() or not answer2.strip():
66
- return "⚠️ Please answer both questions before submitting."
67
-
68
- try:
69
- # ✅ Step: Load credentials from Hugging Face secret
70
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
71
- scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
72
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
73
-
74
- # Connect to Google Sheet
75
- client = gspread.authorize(creds)
76
- sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
77
-
78
- # Append feedback
79
- sheet.append_row([accession, answer1, answer2, contact])
80
- return "βœ… Feedback submitted. Thank you!"
81
-
82
- except Exception as e:
83
- return f"❌ Error submitting feedback: {e}"
84
-
85
- import re
86
-
87
- ACCESSION_REGEX = re.compile(r'^[A-Z]{1,4}_?\d{6}(\.\d+)?$')
88
-
89
- def is_valid_accession(acc):
90
- return bool(ACCESSION_REGEX.match(acc))
91
-
92
- # helper function to extract accessions
93
- def extract_accessions_from_input(file=None, raw_text=""):
94
- print(f"RAW TEXT RECEIVED: {raw_text}")
95
- accessions, invalid_accessions = [], []
96
- seen = set()
97
- if file:
98
- try:
99
- if file.name.endswith(".csv"):
100
- df = pd.read_csv(file)
101
- elif file.name.endswith(".xlsx"):
102
- df = pd.read_excel(file)
103
- else:
104
- return [], "Unsupported file format. Please upload CSV or Excel."
105
- for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
106
- if acc not in seen:
107
- if is_valid_accession(acc):
108
- accessions.append(acc)
109
- seen.add(acc)
110
- else:
111
- invalid_accessions.append(acc)
112
-
113
- except Exception as e:
114
- return [],[], f"Failed to read file: {e}"
115
-
116
- if raw_text:
117
- try:
118
- text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
119
- for acc in text_ids:
120
- if acc not in seen:
121
- if is_valid_accession(acc):
122
- accessions.append(acc)
123
- seen.add(acc)
124
- else:
125
- invalid_accessions.append(acc)
126
- except Exception as e:
127
- return [],[], f"Failed to read file: {e}"
128
-
129
- return list(accessions), list(invalid_accessions), None
130
- # ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
131
- def get_incomplete_accessions(file_path):
132
- df = pd.read_excel(file_path)
133
-
134
- incomplete_accessions = []
135
- for _, row in df.iterrows():
136
- sample_id = str(row.get("Sample ID", "")).strip()
137
-
138
- # Skip if no sample ID
139
- if not sample_id:
140
- continue
141
-
142
- # Drop the Sample ID and check if the rest is empty
143
- other_cols = row.drop(labels=["Sample ID"], errors="ignore")
144
- if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
145
- # Extract the accession number from the sample ID using regex
146
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
147
- if match:
148
- incomplete_accessions.append(match.group(0))
149
- print(len(incomplete_accessions))
150
- return incomplete_accessions
151
-
152
- # GOOGLE_SHEET_NAME = "known_samples"
153
- # USAGE_DRIVE_FILENAME = "user_usage_log.json"
154
-
155
- def summarize_results(accession, stop_flag=None):
156
- # Early bail
157
- if stop_flag is not None and stop_flag.value:
158
- print(f"πŸ›‘ Skipping {accession} before starting.")
159
- return []
160
- # try cache first
161
- cached = check_known_output(accession)
162
- if cached:
163
- print(f"βœ… Using cached result for {accession}")
164
- return [[
165
- cached["Sample ID"] or "unknown",
166
- cached["Predicted Country"] or "unknown",
167
- cached["Country Explanation"] or "unknown",
168
- cached["Predicted Sample Type"] or "unknown",
169
- cached["Sample Type Explanation"] or "unknown",
170
- cached["Sources"] or "No Links",
171
- cached["Time cost"]
172
- ]]
173
- # only run when nothing in the cache
174
- try:
175
- print("try gemini pipeline: ",accession)
176
- # ✅ Load credentials from Hugging Face secret
177
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
178
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
179
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
180
- client = gspread.authorize(creds)
181
-
182
- spreadsheet = client.open("known_samples")
183
- sheet = spreadsheet.sheet1
184
-
185
- data = sheet.get_all_values()
186
- if not data:
187
- print("⚠️ Google Sheet 'known_samples' is empty.")
188
- return None
189
-
190
- save_df = pd.DataFrame(data[1:], columns=data[0])
191
- print("before pipeline, len of save df: ", len(save_df))
192
- outputs = pipeline_classify_sample_location_cached(accession, stop_flag, save_df)
193
- if stop_flag is not None and stop_flag.value:
194
- print(f"πŸ›‘ Skipped {accession} mid-pipeline.")
195
- return []
196
- # outputs = {'KU131308': {'isolate':'BRU18',
197
- # 'country': {'brunei': ['ncbi',
198
- # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
199
- # 'sample_type': {'modern':
200
- # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
201
- # 'query_cost': 9.754999999999999e-05,
202
- # 'time_cost': '24.776 seconds',
203
- # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
204
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
205
- # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
206
- except Exception as e:
207
- return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
208
-
209
- if accession not in outputs:
210
- print("no accession in output ", accession)
211
- return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
212
-
213
- row_score = []
214
- rows = []
215
- save_rows = []
216
- for key in outputs:
217
- pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
218
- for section, results in outputs[key].items():
219
- if section == "country" or section =="sample_type":
220
- pred_output = []#"\n".join(list(results.keys()))
221
- output_explanation = ""
222
- for result, content in results.items():
223
- if len(result) == 0: result = "unknown"
224
- if len(content) == 0: output_explanation = "unknown"
225
- else:
226
- output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
227
- pred_output.append(result)
228
- pred_output = "\n".join(pred_output)
229
- if section == "country":
230
- pred_country, country_explanation = pred_output, output_explanation
231
- elif section == "sample_type":
232
- pred_sample, sample_explanation = pred_output, output_explanation
233
- if outputs[key]["isolate"].lower()!="unknown":
234
- label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
235
- else: label = key
236
- if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
237
- row = {
238
- "Sample ID": label or "unknown",
239
- "Predicted Country": pred_country or "unknown",
240
- "Country Explanation": country_explanation or "unknown",
241
- "Predicted Sample Type":pred_sample or "unknown",
242
- "Sample Type Explanation":sample_explanation or "unknown",
243
- "Sources": "\n".join(outputs[key]["source"]) or "No Links",
244
- "Time cost": outputs[key]["time_cost"]
245
- }
246
- #row_score.append(row)
247
- rows.append(list(row.values()))
248
-
249
- save_row = {
250
- "Sample ID": label or "unknown",
251
- "Predicted Country": pred_country or "unknown",
252
- "Country Explanation": country_explanation or "unknown",
253
- "Predicted Sample Type":pred_sample or "unknown",
254
- "Sample Type Explanation":sample_explanation or "unknown",
255
- "Sources": "\n".join(outputs[key]["source"]) or "No Links",
256
- "Query_cost": outputs[key]["query_cost"] or "",
257
- "Time cost": outputs[key]["time_cost"] or "",
258
- "file_chunk":outputs[key]["file_chunk"] or "",
259
- "file_all_output":outputs[key]["file_all_output"] or ""
260
- }
261
- #row_score.append(row)
262
- save_rows.append(list(save_row.values()))
263
-
264
- # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
265
- # summary_lines = [f"### 🧭 Location Summary:\n"]
266
- # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
267
- # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
268
- # summary = "\n".join(summary_lines)
269
-
270
- # save the new running sample to known excel file
271
- # try:
272
- # df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
273
- # if os.path.exists(KNOWN_OUTPUT_PATH):
274
- # df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
275
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
276
- # else:
277
- # df_combined = df_new
278
- # df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
279
- # except Exception as e:
280
- # print(f"⚠️ Failed to save known output: {e}")
281
- # try:
282
- # df_new = pd.DataFrame(save_rows, columns=[
283
- # "Sample ID", "Predicted Country", "Country Explanation",
284
- # "Predicted Sample Type", "Sample Type Explanation",
285
- # "Sources", "Query_cost", "Time cost"
286
- # ])
287
-
288
- # # ✅ Google Sheets API setup
289
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
290
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
291
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
292
- # client = gspread.authorize(creds)
293
-
294
- # # ✅ Open the known_samples sheet
295
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
296
- # sheet = spreadsheet.sheet1
297
-
298
- # # ✅ Read old data
299
- # existing_data = sheet.get_all_values()
300
- # if existing_data:
301
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
302
- # else:
303
- # df_old = pd.DataFrame(columns=df_new.columns)
304
-
305
- # # ✅ Combine and remove duplicates
306
- # df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
307
-
308
- # # ✅ Clear and write back
309
- # sheet.clear()
310
- # sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
311
-
312
- # except Exception as e:
313
- # print(f"⚠️ Failed to save known output to Google Sheets: {e}")
314
- try:
315
- # Prepare as DataFrame
316
- df_new = pd.DataFrame(save_rows, columns=[
317
- "Sample ID", "Predicted Country", "Country Explanation",
318
- "Predicted Sample Type", "Sample Type Explanation",
319
- "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
320
- ])
321
-
322
- # ✅ Setup Google Sheets
323
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
324
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
325
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
326
- client = gspread.authorize(creds)
327
- spreadsheet = client.open("known_samples")
328
- sheet = spreadsheet.sheet1
329
-
330
- # ✅ Read existing data
331
- existing_data = sheet.get_all_values()
332
-
333
- if existing_data:
334
- df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
335
-
336
- else:
337
-
338
- df_old = pd.DataFrame(columns=[
339
- "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
340
- "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
341
- "Query_cost", "Sample Type Explanation", "Sources", "Time cost", "file_chunk", "file_all_output"
342
- ])
343
-
344
-
345
- # ✅ Index by Sample ID
346
- df_old.set_index("Sample ID", inplace=True)
347
- df_new.set_index("Sample ID", inplace=True)
348
-
349
- # ✅ Update only matching fields
350
- update_columns = [
351
- "Predicted Country", "Predicted Sample Type", "Country Explanation",
352
- "Sample Type Explanation", "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
353
- ]
354
- for idx, row in df_new.iterrows():
355
- if idx not in df_old.index:
356
- df_old.loc[idx] = "" # new row, fill empty first
357
- for col in update_columns:
358
- if pd.notna(row[col]) and row[col] != "":
359
- df_old.at[idx, col] = row[col]
360
-
361
- # ✅ Reset and write back
362
- df_old.reset_index(inplace=True)
363
- sheet.clear()
364
- sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
365
- print("βœ… Match results saved to known_samples.")
366
-
367
- except Exception as e:
368
- print(f"❌ Failed to update known_samples: {e}")
369
-
370
-
371
- return rows#, summary, labelAncient_Modern, explain_label
372
-
373
- # save the batch input in excel file
374
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
375
- # with pd.ExcelWriter(filename) as writer:
376
- # # Save table
377
- # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
378
- # df.to_excel(writer, sheet_name="Detailed Results", index=False)
379
- # try:
380
- # df_old = pd.read_excel(filename)
381
- # except:
382
- # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
383
- # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
384
- # # if os.path.exists(filename):
385
- # # df_old = pd.read_excel(filename)
386
- # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
387
- # # else:
388
- # # df_combined = df_new
389
- # df_combined.to_excel(filename, index=False)
390
- # # # Save summary
391
- # # summary_df = pd.DataFrame({"Summary": [summary_text]})
392
- # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
393
-
394
- # # # Save flag
395
- # # flag_df = pd.DataFrame({"Flag": [flag_text]})
396
- # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
397
- # def save_to_excel(all_rows, summary_text, flag_text, filename):
398
- # df_new = pd.DataFrame(all_rows, columns=[
399
- # "Sample ID", "Predicted Country", "Country Explanation",
400
- # "Predicted Sample Type", "Sample Type Explanation",
401
- # "Sources", "Time cost"
402
- # ])
403
-
404
- # try:
405
- # if os.path.exists(filename):
406
- # df_old = pd.read_excel(filename)
407
- # else:
408
- # df_old = pd.DataFrame(columns=df_new.columns)
409
- # except Exception as e:
410
- # print(f"⚠️ Warning reading old Excel file: {e}")
411
- # df_old = pd.DataFrame(columns=df_new.columns)
412
-
413
- # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
414
- # df_old.set_index("Sample ID", inplace=True)
415
- # df_new.set_index("Sample ID", inplace=True)
416
-
417
- # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
418
-
419
- # df_combined = df_old.reset_index()
420
-
421
- # try:
422
- # df_combined.to_excel(filename, index=False)
423
- # except Exception as e:
424
- # print(f"❌ Failed to write Excel file {filename}: {e}")
425
- def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
426
- df_new = pd.DataFrame(all_rows, columns=[
427
- "Sample ID", "Predicted Country", "Country Explanation",
428
- "Predicted Sample Type", "Sample Type Explanation",
429
- "Sources", "Time cost"
430
- ])
431
-
432
- if is_resume and os.path.exists(filename):
433
- try:
434
- df_old = pd.read_excel(filename)
435
- except Exception as e:
436
- print(f"⚠️ Warning reading old Excel file: {e}")
437
- df_old = pd.DataFrame(columns=df_new.columns)
438
-
439
- # Set index and update existing rows
440
- df_old.set_index("Sample ID", inplace=True)
441
- df_new.set_index("Sample ID", inplace=True)
442
- df_old.update(df_new)
443
-
444
- df_combined = df_old.reset_index()
445
- else:
446
- # If not resuming or file doesn't exist, just use new rows
447
- df_combined = df_new
448
-
449
- try:
450
- df_combined.to_excel(filename, index=False)
451
- except Exception as e:
452
- print(f"❌ Failed to write Excel file {filename}: {e}")
453
-
454
-
455
- # save the batch input in JSON file
456
- def save_to_json(all_rows, summary_text, flag_text, filename):
457
- output_dict = {
458
- "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
459
- # "Summary_Text": summary_text,
460
- # "Ancient_Modern_Flag": flag_text
461
- }
462
-
463
- # If all_rows is a DataFrame, convert it
464
- if isinstance(all_rows, pd.DataFrame):
465
- output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
466
-
467
- with open(filename, "w") as external_file:
468
- json.dump(output_dict, external_file, indent=2)
469
-
470
- # save the batch input in Text file
471
- def save_to_txt(all_rows, summary_text, flag_text, filename):
472
- if isinstance(all_rows, pd.DataFrame):
473
- detailed_results = all_rows.to_dict(orient="records")
474
- output = ""
475
- #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
476
- output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
477
- for r in detailed_results:
478
- output += ",".join([str(v) for v in r.values()]) + "\n\n"
479
- with open(filename, "w") as f:
480
- f.write("=== Detailed Results ===\n")
481
- f.write(output + "\n")
482
-
483
- # f.write("\n=== Summary ===\n")
484
- # f.write(summary_text + "\n")
485
-
486
- # f.write("\n=== Ancient/Modern Flag ===\n")
487
- # f.write(flag_text + "\n")
488
-
489
- def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
490
- tmp_dir = tempfile.mkdtemp()
491
-
492
- #html_table = all_rows.value # assuming this is stored somewhere
493
-
494
- # Parse back to DataFrame
495
- #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
496
- all_rows = pd.read_html(StringIO(all_rows))[0]
497
- print(all_rows)
498
-
499
- if output_type == "Excel":
500
- file_path = f"{tmp_dir}/batch_output.xlsx"
501
- save_to_excel(all_rows, summary_text, flag_text, file_path)
502
- elif output_type == "JSON":
503
- file_path = f"{tmp_dir}/batch_output.json"
504
- save_to_json(all_rows, summary_text, flag_text, file_path)
505
- print("Done with JSON")
506
- elif output_type == "TXT":
507
- file_path = f"{tmp_dir}/batch_output.txt"
508
- save_to_txt(all_rows, summary_text, flag_text, file_path)
509
- else:
510
- return gr.update(visible=False) # invalid option
511
-
512
- return gr.update(value=file_path, visible=True)
513
- # save cost by checking the known outputs
514
-
515
- # def check_known_output(accession):
516
- # if not os.path.exists(KNOWN_OUTPUT_PATH):
517
- # return None
518
-
519
- # try:
520
- # df = pd.read_excel(KNOWN_OUTPUT_PATH)
521
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
522
- # if match:
523
- # accession = match.group(0)
524
-
525
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
526
- # if not matched.empty:
527
- # return matched.iloc[0].to_dict() # Return the cached row
528
- # except Exception as e:
529
- # print(f"⚠️ Failed to load known samples: {e}")
530
- # return None
531
-
532
- # def check_known_output(accession):
533
- # try:
534
- # # ✅ Load credentials from Hugging Face secret
535
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
536
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
537
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
538
- # client = gspread.authorize(creds)
539
-
540
- # # ✅ Open the known_samples sheet
541
- # spreadsheet = client.open("known_samples") # Replace with your sheet name
542
- # sheet = spreadsheet.sheet1
543
-
544
- # # ✅ Read all rows
545
- # data = sheet.get_all_values()
546
- # if not data:
547
- # return None
548
-
549
- # df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
550
-
551
- # # ✅ Normalize accession pattern
552
- # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
553
- # if match:
554
- # accession = match.group(0)
555
-
556
- # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
557
- # if not matched.empty:
558
- # return matched.iloc[0].to_dict()
559
-
560
- # except Exception as e:
561
- # print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
562
- # return None
563
- def check_known_output(accession):
564
- print("inside check known output function")
565
- try:
566
- # ✅ Load credentials from Hugging Face secret
567
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
568
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
569
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
570
- client = gspread.authorize(creds)
571
-
572
- spreadsheet = client.open("known_samples")
573
- sheet = spreadsheet.sheet1
574
-
575
- data = sheet.get_all_values()
576
- if not data:
577
- print("⚠️ Google Sheet 'known_samples' is empty.")
578
- return None
579
-
580
- df = pd.DataFrame(data[1:], columns=data[0])
581
- if "Sample ID" not in df.columns:
582
- print("❌ Column 'Sample ID' not found in Google Sheet.")
583
- return None
584
-
585
- match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
586
- if match:
587
- accession = match.group(0)
588
-
589
- matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
590
- if not matched.empty:
591
- #return matched.iloc[0].to_dict()
592
- row = matched.iloc[0]
593
- country = row.get("Predicted Country", "").strip().lower()
594
- sample_type = row.get("Predicted Sample Type", "").strip().lower()
595
-
596
- if country and country != "unknown" and sample_type and sample_type != "unknown":
597
- return row.to_dict()
598
- else:
599
- print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
600
- return None
601
- else:
602
- print(f"πŸ” Accession {accession} not found in known_samples.")
603
- return None
604
-
605
- except Exception as e:
606
- import traceback
607
- print("❌ Exception occurred during check_known_output:")
608
- traceback.print_exc()
609
- return None
610
-
611
-
612
- def hash_user_id(user_input):
613
- return hashlib.sha256(user_input.encode()).hexdigest()
614
-
615
- # ✅ Load and save usage count
616
-
617
- # def load_user_usage():
618
- # if not os.path.exists(USER_USAGE_TRACK_FILE):
619
- # return {}
620
-
621
- # try:
622
- # with open(USER_USAGE_TRACK_FILE, "r") as f:
623
- # content = f.read().strip()
624
- # if not content:
625
- # return {} # file is empty
626
- # return json.loads(content)
627
- # except (json.JSONDecodeError, ValueError):
628
- # print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
629
- # return {} # fallback to empty dict
630
- # def load_user_usage():
631
- # try:
632
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
633
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
634
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
635
- # client = gspread.authorize(creds)
636
-
637
- # sheet = client.open("user_usage_log").sheet1
638
- # data = sheet.get_all_records() # Assumes columns: email, usage_count
639
-
640
- # usage = {}
641
- # for row in data:
642
- # email = row.get("email", "").strip().lower()
643
- # count = int(row.get("usage_count", 0))
644
- # if email:
645
- # usage[email] = count
646
- # return usage
647
- # except Exception as e:
648
- # print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
649
- # return {}
650
- # def load_user_usage():
651
- # try:
652
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
653
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
654
-
655
- # found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
656
- # if not found:
657
- # return {} # not found, start fresh
658
-
659
- # #file_id = found[0]["id"]
660
- # file_id = found
661
- # content = pipeline.download_drive_file_content(file_id)
662
- # return json.loads(content.strip()) if content.strip() else {}
663
-
664
- # except Exception as e:
665
- # print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
666
- # return {}
667
- def load_user_usage():
668
- try:
669
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
670
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
671
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
672
- client = gspread.authorize(creds)
673
-
674
- sheet = client.open("user_usage_log").sheet1
675
- data = sheet.get_all_values()
676
- print("data: ", data)
677
- print("πŸ§ͺ Raw header row from sheet:", data[0])
678
- print("πŸ§ͺ Character codes in each header:")
679
- for h in data[0]:
680
- print([ord(c) for c in h])
681
-
682
- if not data or len(data) < 2:
683
- print("⚠️ Sheet is empty or missing rows.")
684
- return {}
685
-
686
- headers = [h.strip().lower() for h in data[0]]
687
- if "email" not in headers or "usage_count" not in headers:
688
- print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
689
- return {}
690
-
691
- permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
692
- df = pd.DataFrame(data[1:], columns=headers)
693
-
694
- usage = {}
695
- permitted = {}
696
- for _, row in df.iterrows():
697
- email = row.get("email", "").strip().lower()
698
- try:
699
- #count = int(row.get("usage_count", 0))
700
- try:
701
- count = int(float(row.get("usage_count", 0)))
702
- except Exception:
703
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
704
- count = 0
705
-
706
- if email:
707
- usage[email] = count
708
- if permitted_index is not None:
709
- try:
710
- permitted_count = int(float(row.get("permitted_samples", 50)))
711
- permitted[email] = permitted_count
712
- except:
713
- permitted[email] = 50
714
-
715
- except ValueError:
716
- print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
717
- return usage, permitted
718
-
719
- except Exception as e:
720
- print(f"❌ Error in load_user_usage: {e}")
721
- return {}, {}
722
-
723
-
724
-
725
- # def save_user_usage(usage):
726
- # with open(USER_USAGE_TRACK_FILE, "w") as f:
727
- # json.dump(usage, f, indent=2)
728
-
729
- # def save_user_usage(usage_dict):
730
- # try:
731
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
732
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
733
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
734
- # client = gspread.authorize(creds)
735
-
736
- # sheet = client.open("user_usage_log").sheet1
737
- # sheet.clear() # clear old contents first
738
-
739
- # # Write header + rows
740
- # rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
741
- # sheet.update(rows)
742
- # except Exception as e:
743
- # print(f"❌ Failed to save user usage to Google Sheets: {e}")
744
- # def save_user_usage(usage_dict):
745
- # try:
746
- # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
747
- # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
748
-
749
- # import tempfile
750
- # tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
751
- # print("πŸ’Ύ Saving this usage dict:", usage_dict)
752
- # with open(tmp_path, "w") as f:
753
- # json.dump(usage_dict, f, indent=2)
754
-
755
- # pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
756
-
757
- # except Exception as e:
758
- # print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
759
- # def save_user_usage(usage_dict):
760
- # try:
761
- # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
762
- # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
763
- # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
764
- # client = gspread.authorize(creds)
765
-
766
- # spreadsheet = client.open("user_usage_log")
767
- # sheet = spreadsheet.sheet1
768
-
769
- # # Step 1: Convert new usage to DataFrame
770
- # df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
771
- # df_new["email"] = df_new["email"].str.strip().str.lower()
772
-
773
- # # Step 2: Load existing data
774
- # existing_data = sheet.get_all_values()
775
- # print("πŸ§ͺ Sheet existing_data:", existing_data)
776
-
777
- # # Try to load old data
778
- # if existing_data and len(existing_data[0]) >= 1:
779
- # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
780
-
781
- # # Fix missing columns
782
- # if "email" not in df_old.columns:
783
- # df_old["email"] = ""
784
- # if "usage_count" not in df_old.columns:
785
- # df_old["usage_count"] = 0
786
-
787
- # df_old["email"] = df_old["email"].str.strip().str.lower()
788
- # df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
789
- # else:
790
- # df_old = pd.DataFrame(columns=["email", "usage_count"])
791
-
792
- # # Step 3: Merge
793
- # df_combined = pd.concat([df_old, df_new], ignore_index=True)
794
- # df_combined = df_combined.groupby("email", as_index=False).sum()
795
-
796
- # # Step 4: Write back
797
- # sheet.clear()
798
- # sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
799
- # print("βœ… Saved user usage to user_usage_log sheet.")
800
-
801
- # except Exception as e:
802
- # print(f"❌ Failed to save user usage to Google Sheets: {e}")
803
- def save_user_usage(usage_dict):
804
- try:
805
- creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
806
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
807
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
808
- client = gspread.authorize(creds)
809
-
810
- spreadsheet = client.open("user_usage_log")
811
- sheet = spreadsheet.sheet1
812
-
813
- # Build new df
814
- df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
815
- df_new["email"] = df_new["email"].str.strip().str.lower()
816
- df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
817
-
818
- # Read existing data
819
- existing_data = sheet.get_all_values()
820
- if existing_data and len(existing_data[0]) >= 2:
821
- df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
822
- df_old["email"] = df_old["email"].str.strip().str.lower()
823
- df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
824
- else:
825
- df_old = pd.DataFrame(columns=["email", "usage_count"])
826
-
827
- # ✅ Overwrite specific emails only
828
- df_old = df_old.set_index("email")
829
- for email, count in usage_dict.items():
830
- email = email.strip().lower()
831
- df_old.loc[email, "usage_count"] = count
832
- df_old = df_old.reset_index()
833
-
834
- # Save
835
- sheet.clear()
836
- sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
837
- print("βœ… Saved user usage to user_usage_log sheet.")
838
-
839
- except Exception as e:
840
- print(f"❌ Failed to save user usage to Google Sheets: {e}")
841
-
842
-
843
-
844
-
845
- # def increment_usage(user_id, num_samples=1):
846
- # usage = load_user_usage()
847
- # if user_id not in usage:
848
- # usage[user_id] = 0
849
- # usage[user_id] += num_samples
850
- # save_user_usage(usage)
851
- # return usage[user_id]
852
- # def increment_usage(email: str, count: int):
853
- # usage = load_user_usage()
854
- # email_key = email.strip().lower()
855
- # usage[email_key] = usage.get(email_key, 0) + count
856
- # save_user_usage(usage)
857
- # return usage[email_key]
858
- def increment_usage(email: str, count: int = 1):
859
- usage, permitted = load_user_usage()
860
- email_key = email.strip().lower()
861
- #usage[email_key] = usage.get(email_key, 0) + count
862
- current = usage.get(email_key, 0)
863
- new_value = current + count
864
- max_allowed = permitted.get(email_key) or 50
865
- usage[email_key] = max(current, new_value) # ✅ Prevent overwrite with lower
866
- print(f"πŸ§ͺ increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
867
- print("max allow is: ", max_allowed)
868
- save_user_usage(usage)
869
- return usage[email_key], max_allowed
870
-
871
-
872
- # run the batch
873
- def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
874
- stop_flag=None, output_file_path=None,
875
- limited_acc=50, yield_callback=None):
876
- if user_email:
877
- limited_acc += 10
878
- accessions, error = extract_accessions_from_input(file, raw_text)
879
- if error:
880
- #return [], "", "", f"Error: {error}"
881
- return [], f"Error: {error}", 0, "", ""
882
- if resume_file:
883
- accessions = get_incomplete_accessions(resume_file)
884
- tmp_dir = tempfile.mkdtemp()
885
- if not output_file_path:
886
- if resume_file:
887
- output_file_path = os.path.join(tmp_dir, resume_file)
888
- else:
889
- output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
890
-
891
- all_rows = []
892
- # all_summaries = []
893
- # all_flags = []
894
- progress_lines = []
895
- warning = ""
896
- if len(accessions) > limited_acc:
897
- accessions = accessions[:limited_acc]
898
- warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
899
- for i, acc in enumerate(accessions):
900
- if stop_flag and stop_flag.value:
901
- line = f"πŸ›‘ Stopped at {acc} ({i+1}/{len(accessions)})"
902
- progress_lines.append(line)
903
- if yield_callback:
904
- yield_callback(line)
905
- print("πŸ›‘ User requested stop.")
906
- break
907
- print(f"[{i+1}/{len(accessions)}] Processing {acc}")
908
- try:
909
- # rows, summary, label, explain = summarize_results(acc)
910
- rows = summarize_results(acc)
911
- all_rows.extend(rows)
912
- # all_summaries.append(f"**{acc}**\n{summary}")
913
- # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
914
- #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
915
- save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
916
- line = f"βœ… Processed {acc} ({i+1}/{len(accessions)})"
917
- progress_lines.append(line)
918
- if yield_callback:
919
- yield_callback(f"βœ… Processed {acc} ({i+1}/{len(accessions)})")
920
- except Exception as e:
921
- print(f"❌ Failed to process {acc}: {e}")
922
- continue
923
- #all_summaries.append(f"**{acc}**: Failed - {e}")
924
- #progress_lines.append(f"βœ… Processed {acc} ({i+1}/{len(accessions)})")
925
- limited_acc -= 1
926
- """for row in all_rows:
927
- source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
928
-
929
- if source_column.startswith("http"): # Check if the source is a URL
930
- # Wrap it with HTML anchor tags to make it clickable
931
- row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
932
- if not warning:
933
- warning = f"You only have {limited_acc} left"
934
- if user_email.strip():
935
- user_hash = hash_user_id(user_email)
936
- total_queries = increment_usage(user_hash, len(all_rows))
937
- else:
938
- total_queries = 0
939
- yield_callback("βœ… Finished!")
940
-
941
- # summary_text = "\n\n---\n\n".join(all_summaries)
942
- # flag_text = "\n\n---\n\n".join(all_flags)
943
- #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
944
- #return all_rows, gr.update(visible=True), gr.update(visible=False)
945
  return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
 
1
+ import gradio as gr
2
+ from collections import Counter
3
+ import csv
4
+ import os
5
+ from functools import lru_cache
6
+ #import app
7
+ from mtdna_classifier import classify_sample_location
8
+ import data_preprocess, model, pipeline
9
+ import subprocess
10
+ import json
11
+ import pandas as pd
12
+ import io
13
+ import re
14
+ import tempfile
15
+ import gspread
16
+ from oauth2client.service_account import ServiceAccountCredentials
17
+ from io import StringIO
18
+ import hashlib
19
+ import threading
20
+
21
+ # @lru_cache(maxsize=3600)
22
+ # def classify_sample_location_cached(accession):
23
+ # return classify_sample_location(accession)
24
+
25
+ #@lru_cache(maxsize=3600)
26
+ def pipeline_classify_sample_location_cached(accession,stop_flag=None, save_df=None):
27
+ print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
28
+ print("len of save df: ", len(save_df))
29
+ return pipeline.pipeline_with_gemini([accession],stop_flag=stop_flag, save_df=save_df)
30
+
31
+ # Count and suggest final location
32
+ # def compute_final_suggested_location(rows):
33
+ # candidates = [
34
+ # row.get("Predicted Location", "").strip()
35
+ # for row in rows
36
+ # if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
37
+ # ] + [
38
+ # row.get("Inferred Region", "").strip()
39
+ # for row in rows
40
+ # if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
41
+ # ]
42
+
43
+ # if not candidates:
44
+ # return Counter(), ("Unknown", 0)
45
+ # # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
46
+ # tokens = []
47
+ # for item in candidates:
48
+ # # Split by comma, whitespace, and newlines
49
+ # parts = re.split(r'[\s,]+', item)
50
+ # tokens.extend(parts)
51
+
52
+ # # Step 2: Clean and normalize tokens
53
+ # tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
54
+
55
+ # # Step 3: Count
56
+ # counts = Counter(tokens)
57
+
58
+ # # Step 4: Get most common
59
+ # top_location, count = counts.most_common(1)[0]
60
+ # return counts, (top_location, count)
61
+
62
+ # Store feedback (with required fields)
63
+
64
+ def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
65
+ if not answer1.strip() or not answer2.strip():
66
+ return "⚠️ Please answer both questions before submitting."
67
+
68
+ try:
69
+ # ✅ Step: Load credentials from Hugging Face secret
70
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
71
+ scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
72
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
73
+
74
+ # Connect to Google Sheet
75
+ client = gspread.authorize(creds)
76
+ sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
77
+
78
+ # Append feedback
79
+ sheet.append_row([accession, answer1, answer2, contact])
80
+ return "βœ… Feedback submitted. Thank you!"
81
+
82
+ except Exception as e:
83
+ return f"❌ Error submitting feedback: {e}"
84
+
85
+ import re
86
+
87
+ ACCESSION_REGEX = re.compile(r'^[A-Z]{1,4}_?\d{6}(\.\d+)?$')
88
+
89
+ def is_valid_accession(acc):
90
+ return bool(ACCESSION_REGEX.match(acc))
91
+
92
+ # helper function to extract accessions
93
+ def extract_accessions_from_input(file=None, raw_text=""):
94
+ print(f"RAW TEXT RECEIVED: {raw_text}")
95
+ accessions, invalid_accessions = [], []
96
+ seen = set()
97
+ if file:
98
+ try:
99
+ if file.name.endswith(".csv"):
100
+ df = pd.read_csv(file)
101
+ elif file.name.endswith(".xlsx"):
102
+ df = pd.read_excel(file)
103
+ else:
104
+ return [], "Unsupported file format. Please upload CSV or Excel."
105
+ for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
106
+ if acc not in seen:
107
+ if is_valid_accession(acc):
108
+ accessions.append(acc)
109
+ seen.add(acc)
110
+ else:
111
+ invalid_accessions.append(acc)
112
+
113
+ except Exception as e:
114
+ return [],[], f"Failed to read file: {e}"
115
+
116
+ if raw_text:
117
+ try:
118
+ text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
119
+ for acc in text_ids:
120
+ if acc not in seen:
121
+ if is_valid_accession(acc):
122
+ accessions.append(acc)
123
+ seen.add(acc)
124
+ else:
125
+ invalid_accessions.append(acc)
126
+ except Exception as e:
127
+ return [],[], f"Failed to read file: {e}"
128
+
129
+ return list(accessions), list(invalid_accessions), None
130
+ # ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
131
+ def get_incomplete_accessions(file_path):
132
+ df = pd.read_excel(file_path)
133
+
134
+ incomplete_accessions = []
135
+ for _, row in df.iterrows():
136
+ sample_id = str(row.get("Sample ID", "")).strip()
137
+
138
+ # Skip if no sample ID
139
+ if not sample_id:
140
+ continue
141
+
142
+ # Drop the Sample ID and check if the rest is empty
143
+ other_cols = row.drop(labels=["Sample ID"], errors="ignore")
144
+ if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
145
+ # Extract the accession number from the sample ID using regex
146
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
147
+ if match:
148
+ incomplete_accessions.append(match.group(0))
149
+ print(len(incomplete_accessions))
150
+ return incomplete_accessions
151
+
152
+ # GOOGLE_SHEET_NAME = "known_samples"
153
+ # USAGE_DRIVE_FILENAME = "user_usage_log.json"
154
+
155
+ def summarize_results(accession, stop_flag=None):
156
+ # Early bail
157
+ if stop_flag is not None and stop_flag.value:
158
+ print(f"πŸ›‘ Skipping {accession} before starting.")
159
+ return []
160
+ # try cache first
161
+ cached = check_known_output(accession)
162
+ if cached:
163
+ print(f"βœ… Using cached result for {accession}")
164
+ return [[
165
+ cached["Sample ID"] or "unknown",
166
+ cached["Predicted Country"] or "unknown",
167
+ cached["Country Explanation"] or "unknown",
168
+ cached["Predicted Sample Type"] or "unknown",
169
+ cached["Sample Type Explanation"] or "unknown",
170
+ cached["Sources"] or "No Links",
171
+ cached["Time cost"]
172
+ ]]
173
+ # only run when nothing in the cache
174
+ try:
175
+ print("try gemini pipeline: ",accession)
176
+ # ✅ Load credentials from Hugging Face secret
177
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
178
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
179
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
180
+ client = gspread.authorize(creds)
181
+
182
+ spreadsheet = client.open("known_samples")
183
+ sheet = spreadsheet.sheet1
184
+
185
+ data = sheet.get_all_values()
186
+ if not data:
187
+ print("⚠️ Google Sheet 'known_samples' is empty.")
188
+ return None
189
+
190
+ save_df = pd.DataFrame(data[1:], columns=data[0])
191
+ print("before pipeline, len of save df: ", len(save_df))
192
+ outputs = pipeline_classify_sample_location_cached(accession, stop_flag, save_df)
193
+ if stop_flag is not None and stop_flag.value:
194
+ print(f"πŸ›‘ Skipped {accession} mid-pipeline.")
195
+ return []
196
+ # outputs = {'KU131308': {'isolate':'BRU18',
197
+ # 'country': {'brunei': ['ncbi',
198
+ # 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
199
+ # 'sample_type': {'modern':
200
+ # ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
201
+ # 'query_cost': 9.754999999999999e-05,
202
+ # 'time_cost': '24.776 seconds',
203
+ # 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
204
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
205
+ # 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
206
+ except Exception as e:
207
+ return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
208
+
209
+ if accession not in outputs:
210
+ print("no accession in output ", accession)
211
+ return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
212
+
213
+ row_score = []
214
+ rows = []
215
+ save_rows = []
216
+ for key in outputs:
217
+ pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
218
+ for section, results in outputs[key].items():
219
+ if section == "country" or section =="sample_type":
220
+ pred_output = []#"\n".join(list(results.keys()))
221
+ output_explanation = ""
222
+ for result, content in results.items():
223
+ if len(result) == 0: result = "unknown"
224
+ if len(content) == 0: output_explanation = "unknown"
225
+ else:
226
+ output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
227
+ pred_output.append(result)
228
+ pred_output = "\n".join(pred_output)
229
+ if section == "country":
230
+ pred_country, country_explanation = pred_output, output_explanation
231
+ elif section == "sample_type":
232
+ pred_sample, sample_explanation = pred_output, output_explanation
233
+ if outputs[key]["isolate"].lower()!="unknown":
234
+ label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
235
+ else: label = key
236
+ if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
237
+ row = {
238
+ "Sample ID": label or "unknown",
239
+ "Predicted Country": pred_country or "unknown",
240
+ "Country Explanation": country_explanation or "unknown",
241
+ "Predicted Sample Type":pred_sample or "unknown",
242
+ "Sample Type Explanation":sample_explanation or "unknown",
243
+ "Sources": "\n".join(outputs[key]["source"]) or "No Links",
244
+ "Time cost": outputs[key]["time_cost"]
245
+ }
246
+ #row_score.append(row)
247
+ rows.append(list(row.values()))
248
+
249
+ save_row = {
250
+ "Sample ID": label or "unknown",
251
+ "Predicted Country": pred_country or "unknown",
252
+ "Country Explanation": country_explanation or "unknown",
253
+ "Predicted Sample Type":pred_sample or "unknown",
254
+ "Sample Type Explanation":sample_explanation or "unknown",
255
+ "Sources": "\n".join(outputs[key]["source"]) or "No Links",
256
+ "Query_cost": outputs[key]["query_cost"] or "",
257
+ "Time cost": outputs[key]["time_cost"] or "",
258
+ "file_chunk":outputs[key]["file_chunk"] or "",
259
+ "file_all_output":outputs[key]["file_all_output"] or ""
260
+ }
261
+ #row_score.append(row)
262
+ save_rows.append(list(save_row.values()))
263
+
264
+ # #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
265
+ # summary_lines = [f"### 🧭 Location Summary:\n"]
266
+ # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
267
+ # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
268
+ # summary = "\n".join(summary_lines)
269
+
270
+ # save the new running sample to known excel file
271
+ # try:
272
+ # df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
273
+ # if os.path.exists(KNOWN_OUTPUT_PATH):
274
+ # df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
275
+ # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
276
+ # else:
277
+ # df_combined = df_new
278
+ # df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
279
+ # except Exception as e:
280
+ # print(f"⚠️ Failed to save known output: {e}")
281
+ # try:
282
+ # df_new = pd.DataFrame(save_rows, columns=[
283
+ # "Sample ID", "Predicted Country", "Country Explanation",
284
+ # "Predicted Sample Type", "Sample Type Explanation",
285
+ # "Sources", "Query_cost", "Time cost"
286
+ # ])
287
+
288
+ # # ✅ Google Sheets API setup
289
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
290
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
291
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
292
+ # client = gspread.authorize(creds)
293
+
294
+ # # ✅ Open the known_samples sheet
295
+ # spreadsheet = client.open("known_samples") # Replace with your sheet name
296
+ # sheet = spreadsheet.sheet1
297
+
298
+ # # ✅ Read old data
299
+ # existing_data = sheet.get_all_values()
300
+ # if existing_data:
301
+ # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
302
+ # else:
303
+ # df_old = pd.DataFrame(columns=df_new.columns)
304
+
305
+ # # ✅ Combine and remove duplicates
306
+ # df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
307
+
308
+ # # ✅ Clear and write back
309
+ # sheet.clear()
310
+ # sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
311
+
312
+ # except Exception as e:
313
+ # print(f"⚠️ Failed to save known output to Google Sheets: {e}")
314
+ try:
315
+ # Prepare as DataFrame
316
+ df_new = pd.DataFrame(save_rows, columns=[
317
+ "Sample ID", "Predicted Country", "Country Explanation",
318
+ "Predicted Sample Type", "Sample Type Explanation",
319
+ "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
320
+ ])
321
+
322
+ # ✅ Setup Google Sheets
323
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
324
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
325
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
326
+ client = gspread.authorize(creds)
327
+ spreadsheet = client.open("known_samples")
328
+ sheet = spreadsheet.sheet1
329
+
330
+ # ✅ Read existing data
331
+ existing_data = sheet.get_all_values()
332
+
333
+ if existing_data:
334
+ df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
335
+
336
+ else:
337
+
338
+ df_old = pd.DataFrame(columns=[
339
+ "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
340
+ "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
341
+ "Query_cost", "Sample Type Explanation", "Sources", "Time cost", "file_chunk", "file_all_output"
342
+ ])
343
+
344
+
345
+ # ✅ Index by Sample ID
346
+ df_old.set_index("Sample ID", inplace=True)
347
+ df_new.set_index("Sample ID", inplace=True)
348
+
349
+ # ✅ Update only matching fields
350
+ update_columns = [
351
+ "Predicted Country", "Predicted Sample Type", "Country Explanation",
352
+ "Sample Type Explanation", "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
353
+ ]
354
+ for idx, row in df_new.iterrows():
355
+ if idx not in df_old.index:
356
+ df_old.loc[idx] = "" # new row, fill empty first
357
+ for col in update_columns:
358
+ if pd.notna(row[col]) and row[col] != "":
359
+ df_old.at[idx, col] = row[col]
360
+
361
+ # ✅ Reset and write back
362
+ df_old.reset_index(inplace=True)
363
+ sheet.clear()
364
+ sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
365
+ print("βœ… Match results saved to known_samples.")
366
+
367
+ except Exception as e:
368
+ print(f"❌ Failed to update known_samples: {e}")
369
+
370
+
371
+ return rows#, summary, labelAncient_Modern, explain_label
372
+
373
+ # save the batch input in excel file
374
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
375
+ # with pd.ExcelWriter(filename) as writer:
376
+ # # Save table
377
+ # df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
378
+ # df.to_excel(writer, sheet_name="Detailed Results", index=False)
379
+ # try:
380
+ # df_old = pd.read_excel(filename)
381
+ # except:
382
+ # df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
383
+ # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
384
+ # # if os.path.exists(filename):
385
+ # # df_old = pd.read_excel(filename)
386
+ # # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
387
+ # # else:
388
+ # # df_combined = df_new
389
+ # df_combined.to_excel(filename, index=False)
390
+ # # # Save summary
391
+ # # summary_df = pd.DataFrame({"Summary": [summary_text]})
392
+ # # summary_df.to_excel(writer, sheet_name="Summary", index=False)
393
+
394
+ # # # Save flag
395
+ # # flag_df = pd.DataFrame({"Flag": [flag_text]})
396
+ # # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
397
+ # def save_to_excel(all_rows, summary_text, flag_text, filename):
398
+ # df_new = pd.DataFrame(all_rows, columns=[
399
+ # "Sample ID", "Predicted Country", "Country Explanation",
400
+ # "Predicted Sample Type", "Sample Type Explanation",
401
+ # "Sources", "Time cost"
402
+ # ])
403
+
404
+ # try:
405
+ # if os.path.exists(filename):
406
+ # df_old = pd.read_excel(filename)
407
+ # else:
408
+ # df_old = pd.DataFrame(columns=df_new.columns)
409
+ # except Exception as e:
410
+ # print(f"⚠️ Warning reading old Excel file: {e}")
411
+ # df_old = pd.DataFrame(columns=df_new.columns)
412
+
413
+ # #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
414
+ # df_old.set_index("Sample ID", inplace=True)
415
+ # df_new.set_index("Sample ID", inplace=True)
416
+
417
+ # df_old.update(df_new) # <-- update matching rows in df_old with df_new content
418
+
419
+ # df_combined = df_old.reset_index()
420
+
421
+ # try:
422
+ # df_combined.to_excel(filename, index=False)
423
+ # except Exception as e:
424
+ # print(f"❌ Failed to write Excel file {filename}: {e}")
425
+ def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
426
+ df_new = pd.DataFrame(all_rows, columns=[
427
+ "Sample ID", "Predicted Country", "Country Explanation",
428
+ "Predicted Sample Type", "Sample Type Explanation",
429
+ "Sources", "Time cost"
430
+ ])
431
+
432
+ if is_resume and os.path.exists(filename):
433
+ try:
434
+ df_old = pd.read_excel(filename)
435
+ except Exception as e:
436
+ print(f"⚠️ Warning reading old Excel file: {e}")
437
+ df_old = pd.DataFrame(columns=df_new.columns)
438
+
439
+ # Set index and update existing rows
440
+ df_old.set_index("Sample ID", inplace=True)
441
+ df_new.set_index("Sample ID", inplace=True)
442
+ df_old.update(df_new)
443
+
444
+ df_combined = df_old.reset_index()
445
+ else:
446
+ # If not resuming or file doesn't exist, just use new rows
447
+ df_combined = df_new
448
+
449
+ try:
450
+ df_combined.to_excel(filename, index=False)
451
+ except Exception as e:
452
+ print(f"❌ Failed to write Excel file {filename}: {e}")
453
+
454
+
455
+ # save the batch input in JSON file
456
+ def save_to_json(all_rows, summary_text, flag_text, filename):
457
+ output_dict = {
458
+ "Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
459
+ # "Summary_Text": summary_text,
460
+ # "Ancient_Modern_Flag": flag_text
461
+ }
462
+
463
+ # If all_rows is a DataFrame, convert it
464
+ if isinstance(all_rows, pd.DataFrame):
465
+ output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
466
+
467
+ with open(filename, "w") as external_file:
468
+ json.dump(output_dict, external_file, indent=2)
469
+
470
+ # save the batch input in Text file
471
+ def save_to_txt(all_rows, summary_text, flag_text, filename):
472
+ if isinstance(all_rows, pd.DataFrame):
473
+ detailed_results = all_rows.to_dict(orient="records")
474
+ output = ""
475
+ #output += ",".join(list(detailed_results[0].keys())) + "\n\n"
476
+ output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
477
+ for r in detailed_results:
478
+ output += ",".join([str(v) for v in r.values()]) + "\n\n"
479
+ with open(filename, "w") as f:
480
+ f.write("=== Detailed Results ===\n")
481
+ f.write(output + "\n")
482
+
483
+ # f.write("\n=== Summary ===\n")
484
+ # f.write(summary_text + "\n")
485
+
486
+ # f.write("\n=== Ancient/Modern Flag ===\n")
487
+ # f.write(flag_text + "\n")
488
+
489
+ def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
490
+ tmp_dir = tempfile.mkdtemp()
491
+
492
+ #html_table = all_rows.value # assuming this is stored somewhere
493
+
494
+ # Parse back to DataFrame
495
+ #all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
496
+ all_rows = pd.read_html(StringIO(all_rows))[0]
497
+ print(all_rows)
498
+
499
+ if output_type == "Excel":
500
+ file_path = f"{tmp_dir}/batch_output.xlsx"
501
+ save_to_excel(all_rows, summary_text, flag_text, file_path)
502
+ elif output_type == "JSON":
503
+ file_path = f"{tmp_dir}/batch_output.json"
504
+ save_to_json(all_rows, summary_text, flag_text, file_path)
505
+ print("Done with JSON")
506
+ elif output_type == "TXT":
507
+ file_path = f"{tmp_dir}/batch_output.txt"
508
+ save_to_txt(all_rows, summary_text, flag_text, file_path)
509
+ else:
510
+ return gr.update(visible=False) # invalid option
511
+
512
+ return gr.update(value=file_path, visible=True)
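+
+ # Illustrative wiring sketch (assumed component names, not defined in this module): in the
+ # Gradio layer a download button could call this roughly as
+ #   download_btn.click(save_batch_output, inputs=[results_html, output_type], outputs=[download_file])
+ # where results_html holds the HTML table that pd.read_html() parses back into a DataFrame.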
513
+ # save cost by checking the known outputs
514
+
515
+ # def check_known_output(accession):
516
+ # if not os.path.exists(KNOWN_OUTPUT_PATH):
517
+ # return None
518
+
519
+ # try:
520
+ # df = pd.read_excel(KNOWN_OUTPUT_PATH)
521
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
522
+ # if match:
523
+ # accession = match.group(0)
524
+
525
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
526
+ # if not matched.empty:
527
+ # return matched.iloc[0].to_dict() # Return the cached row
528
+ # except Exception as e:
529
+ # print(f"⚠️ Failed to load known samples: {e}")
530
+ # return None
531
+
532
+ # def check_known_output(accession):
533
+ # try:
534
+ # # βœ… Load credentials from Hugging Face secret
535
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
536
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
537
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
538
+ # client = gspread.authorize(creds)
539
+
540
+ # # βœ… Open the known_samples sheet
541
+ # spreadsheet = client.open("known_samples") # Replace with your sheet name
542
+ # sheet = spreadsheet.sheet1
543
+
544
+ # # βœ… Read all rows
545
+ # data = sheet.get_all_values()
546
+ # if not data:
547
+ # return None
548
+
549
+ # df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
550
+
551
+ # # βœ… Normalize accession pattern
552
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
553
+ # if match:
554
+ # accession = match.group(0)
555
+
556
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
557
+ # if not matched.empty:
558
+ # return matched.iloc[0].to_dict()
559
+
560
+ # except Exception as e:
561
+ # print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
562
+ # return None
563
+ # def check_known_output(accession):
564
+ # print("inside check known output function")
565
+ # try:
566
+ # # βœ… Load credentials from Hugging Face secret
567
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
568
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
569
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
570
+ # client = gspread.authorize(creds)
571
+
572
+ # spreadsheet = client.open("known_samples")
573
+ # sheet = spreadsheet.sheet1
574
+
575
+ # data = sheet.get_all_values()
576
+ # if not data:
577
+ # print("⚠️ Google Sheet 'known_samples' is empty.")
578
+ # return None
579
+
580
+ # df = pd.DataFrame(data[1:], columns=data[0])
581
+ # if "Sample ID" not in df.columns:
582
+ # print("❌ Column 'Sample ID' not found in Google Sheet.")
583
+ # return None
584
+
585
+ # match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
586
+ # if match:
587
+ # accession = match.group(0)
588
+
589
+ # matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
590
+ # if not matched.empty:
591
+ # #return matched.iloc[0].to_dict()
592
+ # row = matched.iloc[0]
593
+ # country = row.get("Predicted Country", "").strip().lower()
594
+ # sample_type = row.get("Predicted Sample Type", "").strip().lower()
595
+
596
+ # if country and country != "unknown" and sample_type and sample_type != "unknown":
597
+ # return row.to_dict()
598
+ # else:
599
+ # print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
600
+ # return None
601
+ # else:
602
+ # print(f"πŸ” Accession {accession} not found in known_samples.")
603
+ # return None
604
+
605
+ # except Exception as e:
606
+ # import traceback
607
+ # print("❌ Exception occurred during check_known_output:")
608
+ # traceback.print_exc()
609
+ # return None
610
+
611
+ import os
612
+ import re
613
+ import json
614
+ import time
615
+ import gspread
616
+ import pandas as pd
617
+ from oauth2client.service_account import ServiceAccountCredentials
618
+ from gspread.exceptions import APIError
619
+
620
+ # --- Global cache ---
621
+ _known_samples_cache = None
622
+
623
+ def load_known_samples():
624
+ """Load the Google Sheet 'known_samples' into a Pandas DataFrame and cache it."""
625
+ global _known_samples_cache
626
+ try:
627
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
628
+ scope = [
629
+ 'https://spreadsheets.google.com/feeds',
630
+ 'https://www.googleapis.com/auth/drive'
631
+ ]
632
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
633
+ client = gspread.authorize(creds)
634
+
635
+ sheet = client.open("known_samples").sheet1
636
+ data = sheet.get_all_values()
637
+
638
+ if not data:
639
+ print("⚠️ Google Sheet 'known_samples' is empty.")
640
+ _known_samples_cache = pd.DataFrame()
641
+ else:
642
+ _known_samples_cache = pd.DataFrame(data[1:], columns=data[0])
643
+ print(f"βœ… Cached {_known_samples_cache.shape[0]} rows from known_samples")
644
+
645
+ except APIError as e:
646
+ print(f"❌ APIError while loading known_samples: {e}")
647
+ _known_samples_cache = pd.DataFrame()
648
+ except Exception as e:
649
+ import traceback
650
+ print("❌ Exception occurred while loading known_samples:")
651
+ traceback.print_exc()
652
+ _known_samples_cache = pd.DataFrame()
653
+
654
+ def check_known_output(accession):
655
+ """Check if an accession exists in the cached 'known_samples' sheet."""
656
+ global _known_samples_cache
657
+ print("inside check known output function")
658
+
659
+ try:
660
+ # Load cache if not already loaded
661
+ if _known_samples_cache is None:
662
+ load_known_samples()
663
+
664
+ if _known_samples_cache.empty:
665
+ print("⚠️ No cached data available.")
666
+ return None
667
+
668
+ # Extract proper accession format (e.g. AB12345)
669
+ match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
670
+ if match:
671
+ accession = match.group(0)
672
+
673
+ matched = _known_samples_cache[
674
+ _known_samples_cache["Sample ID"].str.contains(accession, case=False, na=False)
675
+ ]
676
+
677
+ if not matched.empty:
678
+ row = matched.iloc[0]
679
+ country = row.get("Predicted Country", "").strip().lower()
680
+ sample_type = row.get("Predicted Sample Type", "").strip().lower()
681
+
682
+ if country and country != "unknown" and sample_type and sample_type != "unknown":
683
+ print(f"🎯 Found {accession} in cache")
684
+ return row.to_dict()
685
+ else:
686
+ print(f"⚠️ Accession {accession} found but country/sample_type unknown or empty.")
687
+ return None
688
+ else:
689
+ print(f"πŸ” Accession {accession} not found in cache.")
690
+ return None
691
+
692
+ except Exception as e:
693
+ import traceback
694
+ print("❌ Exception occurred during check_known_output:")
695
+ traceback.print_exc()
696
+ return None
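+
+ # Usage sketch (illustrative accession): the cache loads lazily on the first lookup, so a
+ # typical call is simply
+ #   cached = check_known_output("KU131308")
+ #   if cached: reuse cached["Predicted Country"] and cached["Predicted Sample Type"]
+ # The cache is module-global and never expires on its own; call load_known_samples() again
+ # if the Google Sheet changes mid-session.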
697
+
698
+
699
+
700
+ def hash_user_id(user_input):
701
+ return hashlib.sha256(user_input.encode()).hexdigest()
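+
+ # Example (illustrative): hash_user_id("user@example.com") returns a 64-character SHA-256
+ # hex digest; summarize_batch passes this hash, rather than the raw email, to increment_usage.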
702
+
703
+ # βœ… Load and save usage count
704
+
705
+ # def load_user_usage():
706
+ # if not os.path.exists(USER_USAGE_TRACK_FILE):
707
+ # return {}
708
+
709
+ # try:
710
+ # with open(USER_USAGE_TRACK_FILE, "r") as f:
711
+ # content = f.read().strip()
712
+ # if not content:
713
+ # return {} # file is empty
714
+ # return json.loads(content)
715
+ # except (json.JSONDecodeError, ValueError):
716
+ # print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
717
+ # return {} # fallback to empty dict
718
+ # def load_user_usage():
719
+ # try:
720
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
721
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
722
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
723
+ # client = gspread.authorize(creds)
724
+
725
+ # sheet = client.open("user_usage_log").sheet1
726
+ # data = sheet.get_all_records() # Assumes columns: email, usage_count
727
+
728
+ # usage = {}
729
+ # for row in data:
730
+ # email = row.get("email", "").strip().lower()
731
+ # count = int(row.get("usage_count", 0))
732
+ # if email:
733
+ # usage[email] = count
734
+ # return usage
735
+ # except Exception as e:
736
+ # print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
737
+ # return {}
738
+ # def load_user_usage():
739
+ # try:
740
+ # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
741
+ # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
742
+
743
+ # found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
744
+ # if not found:
745
+ # return {} # not found, start fresh
746
+
747
+ # #file_id = found[0]["id"]
748
+ # file_id = found
749
+ # content = pipeline.download_drive_file_content(file_id)
750
+ # return json.loads(content.strip()) if content.strip() else {}
751
+
752
+ # except Exception as e:
753
+ # print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
754
+ # return {}
755
+ def load_user_usage():
756
+ try:
757
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
758
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
759
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
760
+ client = gspread.authorize(creds)
761
+
762
+ sheet = client.open("user_usage_log").sheet1
763
+ data = sheet.get_all_values()
764
+ print("data: ", data)
765
+ print("πŸ§ͺ Raw header row from sheet:", data[0])
766
+ print("πŸ§ͺ Character codes in each header:")
767
+ for h in data[0]:
768
+ print([ord(c) for c in h])
769
+
770
+ if not data or len(data) < 2:
771
+ print("⚠️ Sheet is empty or missing rows.")
772
+ return {}, {}
773
+
774
+ headers = [h.strip().lower() for h in data[0]]
775
+ if "email" not in headers or "usage_count" not in headers:
776
+ print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
777
+ return {}, {}
778
+
779
+ permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
780
+ df = pd.DataFrame(data[1:], columns=headers)
781
+
782
+ usage = {}
783
+ permitted = {}
784
+ for _, row in df.iterrows():
785
+ email = row.get("email", "").strip().lower()
786
+ try:
787
+ #count = int(row.get("usage_count", 0))
788
+ try:
789
+ count = int(float(row.get("usage_count", 0)))
790
+ except Exception:
791
+ print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
792
+ count = 0
793
+
794
+ if email:
795
+ usage[email] = count
796
+ if permitted_index is not None:
797
+ try:
798
+ permitted_count = int(float(row.get("permitted_samples", 50)))
799
+ permitted[email] = permitted_count
800
+ except:
801
+ permitted[email] = 50
802
+
803
+ except ValueError:
804
+ print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
805
+ return usage, permitted
806
+
807
+ except Exception as e:
808
+ print(f"❌ Error in load_user_usage: {e}")
809
+ return {}, {}
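+
+ # Usage sketch (illustrative email): load_user_usage() always returns two dicts, e.g.
+ #   usage, permitted = load_user_usage()
+ #   usage.get("someone@example.com", 0)      -> samples already processed
+ #   permitted.get("someone@example.com", 50) -> per-user cap, defaulting to 50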
810
+
811
+
812
+
813
+ # def save_user_usage(usage):
814
+ # with open(USER_USAGE_TRACK_FILE, "w") as f:
815
+ # json.dump(usage, f, indent=2)
816
+
817
+ # def save_user_usage(usage_dict):
818
+ # try:
819
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
820
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
821
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
822
+ # client = gspread.authorize(creds)
823
+
824
+ # sheet = client.open("user_usage_log").sheet1
825
+ # sheet.clear() # clear old contents first
826
+
827
+ # # Write header + rows
828
+ # rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
829
+ # sheet.update(rows)
830
+ # except Exception as e:
831
+ # print(f"❌ Failed to save user usage to Google Sheets: {e}")
832
+ # def save_user_usage(usage_dict):
833
+ # try:
834
+ # parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
835
+ # iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
836
+
837
+ # import tempfile
838
+ # tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
839
+ # print("πŸ’Ύ Saving this usage dict:", usage_dict)
840
+ # with open(tmp_path, "w") as f:
841
+ # json.dump(usage_dict, f, indent=2)
842
+
843
+ # pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
844
+
845
+ # except Exception as e:
846
+ # print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
847
+ # def save_user_usage(usage_dict):
848
+ # try:
849
+ # creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
850
+ # scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
851
+ # creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
852
+ # client = gspread.authorize(creds)
853
+
854
+ # spreadsheet = client.open("user_usage_log")
855
+ # sheet = spreadsheet.sheet1
856
+
857
+ # # Step 1: Convert new usage to DataFrame
858
+ # df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
859
+ # df_new["email"] = df_new["email"].str.strip().str.lower()
860
+
861
+ # # Step 2: Load existing data
862
+ # existing_data = sheet.get_all_values()
863
+ # print("πŸ§ͺ Sheet existing_data:", existing_data)
864
+
865
+ # # Try to load old data
866
+ # if existing_data and len(existing_data[0]) >= 1:
867
+ # df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
868
+
869
+ # # Fix missing columns
870
+ # if "email" not in df_old.columns:
871
+ # df_old["email"] = ""
872
+ # if "usage_count" not in df_old.columns:
873
+ # df_old["usage_count"] = 0
874
+
875
+ # df_old["email"] = df_old["email"].str.strip().str.lower()
876
+ # df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
877
+ # else:
878
+ # df_old = pd.DataFrame(columns=["email", "usage_count"])
879
+
880
+ # # Step 3: Merge
881
+ # df_combined = pd.concat([df_old, df_new], ignore_index=True)
882
+ # df_combined = df_combined.groupby("email", as_index=False).sum()
883
+
884
+ # # Step 4: Write back
885
+ # sheet.clear()
886
+ # sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
887
+ # print("βœ… Saved user usage to user_usage_log sheet.")
888
+
889
+ # except Exception as e:
890
+ # print(f"❌ Failed to save user usage to Google Sheets: {e}")
891
+ def save_user_usage(usage_dict):
892
+ try:
893
+ creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
894
+ scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
895
+ creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
896
+ client = gspread.authorize(creds)
897
+
898
+ spreadsheet = client.open("user_usage_log")
899
+ sheet = spreadsheet.sheet1
900
+
901
+ # Build new df
902
+ df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
903
+ df_new["email"] = df_new["email"].str.strip().str.lower()
904
+ df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
905
+
906
+ # Read existing data
907
+ existing_data = sheet.get_all_values()
908
+ if existing_data and len(existing_data[0]) >= 2:
909
+ df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
910
+ df_old["email"] = df_old["email"].str.strip().str.lower()
911
+ df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
912
+ else:
913
+ df_old = pd.DataFrame(columns=["email", "usage_count"])
914
+
915
+ # βœ… Overwrite specific emails only
916
+ df_old = df_old.set_index("email")
917
+ for email, count in usage_dict.items():
918
+ email = email.strip().lower()
919
+ df_old.loc[email, "usage_count"] = count
920
+ df_old = df_old.reset_index()
921
+
922
+ # Save
923
+ sheet.clear()
924
+ sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
925
+ print("βœ… Saved user usage to user_usage_log sheet.")
926
+
927
+ except Exception as e:
928
+ print(f"❌ Failed to save user usage to Google Sheets: {e}")
929
+
930
+
931
+
932
+
933
+ # def increment_usage(user_id, num_samples=1):
934
+ # usage = load_user_usage()
935
+ # if user_id not in usage:
936
+ # usage[user_id] = 0
937
+ # usage[user_id] += num_samples
938
+ # save_user_usage(usage)
939
+ # return usage[user_id]
940
+ # def increment_usage(email: str, count: int):
941
+ # usage = load_user_usage()
942
+ # email_key = email.strip().lower()
943
+ # usage[email_key] = usage.get(email_key, 0) + count
944
+ # save_user_usage(usage)
945
+ # return usage[email_key]
946
+ def increment_usage(email: str, count: int = 1):
947
+ usage, permitted = load_user_usage()
948
+ email_key = email.strip().lower()
949
+ #usage[email_key] = usage.get(email_key, 0) + count
950
+ current = usage.get(email_key, 0)
951
+ new_value = current + count
952
+ max_allowed = permitted.get(email_key) or 50
953
+ usage[email_key] = max(current, new_value) # βœ… Prevent overwrite with lower
954
+ print(f"πŸ§ͺ increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
955
+ print("max allow is: ", max_allowed)
956
+ save_user_usage(usage)
957
+ return usage[email_key], max_allowed
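+
+ # Usage sketch (illustrative):
+ #   used, max_allowed = increment_usage("someone@example.com", 3)
+ # "used" is the new running total for that (lower-cased) key; "max_allowed" falls back to 50
+ # when no permitted_samples entry exists in the user_usage_log sheet.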
958
+
959
+
960
+ # run the batch
961
+ def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
962
+ stop_flag=None, output_file_path=None,
963
+ limited_acc=50, yield_callback=None):
964
+ if user_email:
965
+ limited_acc += 10
966
+ accessions, error = extract_accessions_from_input(file, raw_text)
967
+ if error:
968
+ #return [], "", "", f"Error: {error}"
969
+ return [], f"Error: {error}", 0, "", ""
970
+ if resume_file:
971
+ accessions = get_incomplete_accessions(resume_file)
972
+ tmp_dir = tempfile.mkdtemp()
973
+ if not output_file_path:
974
+ if resume_file:
975
+ output_file_path = os.path.join(tmp_dir, resume_file)
976
+ else:
977
+ output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
978
+
979
+ all_rows = []
980
+ # all_summaries = []
981
+ # all_flags = []
982
+ progress_lines = []
983
+ warning = ""
984
+ if len(accessions) > limited_acc:
985
+ accessions = accessions[:limited_acc]
986
+ warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
987
+ for i, acc in enumerate(accessions):
988
+ if stop_flag and stop_flag.value:
989
+ line = f"πŸ›‘ Stopped at {acc} ({i+1}/{len(accessions)})"
990
+ progress_lines.append(line)
991
+ if yield_callback:
992
+ yield_callback(line)
993
+ print("πŸ›‘ User requested stop.")
994
+ break
995
+ print(f"[{i+1}/{len(accessions)}] Processing {acc}")
996
+ try:
997
+ # rows, summary, label, explain = summarize_results(acc)
998
+ rows = summarize_results(acc)
999
+ all_rows.extend(rows)
1000
+ # all_summaries.append(f"**{acc}**\n{summary}")
1001
+ # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
1002
+ #save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
1003
+ save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
1004
+ line = f"βœ… Processed {acc} ({i+1}/{len(accessions)})"
1005
+ progress_lines.append(line)
1006
+ if yield_callback:
1007
+ yield_callback(f"βœ… Processed {acc} ({i+1}/{len(accessions)})")
1008
+ except Exception as e:
1009
+ print(f"❌ Failed to process {acc}: {e}")
1010
+ continue
1011
+ #all_summaries.append(f"**{acc}**: Failed - {e}")
1012
+ #progress_lines.append(f"βœ… Processed {acc} ({i+1}/{len(accessions)})")
1013
+ limited_acc -= 1
1014
+ """for row in all_rows:
1015
+ source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
1016
+
1017
+ if source_column.startswith("http"): # Check if the source is a URL
1018
+ # Wrap it with HTML anchor tags to make it clickable
1019
+ row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
1020
+ if not warning:
1021
+ warning = f"You only have {limited_acc} left"
1022
+ if user_email.strip():
1023
+ user_hash = hash_user_id(user_email)
1024
+ total_queries, max_allowed = increment_usage(user_hash, len(all_rows))
1025
+ else:
1026
+ total_queries = 0
1027
+ yield_callback("βœ… Finished!")
1028
+
1029
+ # summary_text = "\n\n---\n\n".join(all_summaries)
1030
+ # flag_text = "\n\n---\n\n".join(all_flags)
1031
+ #return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
1032
+ #return all_rows, gr.update(visible=True), gr.update(visible=False)
1033
  return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning
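
+ # Illustrative caller sketch (assumed accession IDs and callback, not defined here): the UI
+ # layer could drive a batch run roughly as
+ #   rows, out_path, used, progress, warning = summarize_batch(
+ #       raw_text="KU131308 MF362736", user_email="someone@example.com",
+ #       stop_flag=stop_flag, yield_callback=lambda msg: print(msg))
+ # where stop_flag is any object exposing a boolean .value attribute, checked between accessions.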