import gradio as gr
from collections import Counter
import csv
import os
from functools import lru_cache
#import app
from mtdna_classifier import classify_sample_location
import data_preprocess, model, pipeline
import subprocess
import json
import pandas as pd
import io
import re
import tempfile
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from io import StringIO
import hashlib
import threading
# @lru_cache(maxsize=3600)
# def classify_sample_location_cached(accession):
#     return classify_sample_location(accession)
#@lru_cache(maxsize=3600)
def pipeline_classify_sample_location_cached(accession, stop_flag=None, save_df=None):
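    """Run the Gemini-based pipeline for a single accession.

    Wraps pipeline.pipeline_with_gemini() with the accession in a one-element
    list, passing through the shared stop flag and the cached known_samples
    DataFrame."""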
print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession]) | |
print("len of save df: ", len(save_df)) | |
return pipeline.pipeline_with_gemini([accession],stop_flag=stop_flag, save_df=save_df) | |
# Count and suggest final location
# def compute_final_suggested_location(rows):
#     candidates = [
#         row.get("Predicted Location", "").strip()
#         for row in rows
#         if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
#     ] + [
#         row.get("Inferred Region", "").strip()
#         for row in rows
#         if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
#     ]
#     if not candidates:
#         return Counter(), ("Unknown", 0)
#     # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
#     tokens = []
#     for item in candidates:
#         # Split by comma, whitespace, and newlines
#         parts = re.split(r'[\s,]+', item)
#         tokens.extend(parts)
#     # Step 2: Clean and normalize tokens
#     tokens = [word.strip() for word in tokens if word.strip().isalpha()]  # Keep only alphabetic tokens
#     # Step 3: Count
#     counts = Counter(tokens)
#     # Step 4: Get most common
#     top_location, count = counts.most_common(1)[0]
#     return counts, (top_location, count)
# Store feedback (with required fields)
def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
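    """Append a feedback row (accession, both answers, optional contact) to the
    'feedback_mtdna' Google Sheet, authenticating with the GCP_CREDS_JSON secret.
    Returns a user-facing status message."""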
    if not answer1.strip() or not answer2.strip():
        return "⚠️ Please answer both questions before submitting."
    try:
        # ✅ Step: Load credentials from Hugging Face secret
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        # Connect to Google Sheet
        client = gspread.authorize(creds)
        sheet = client.open("feedback_mtdna").sheet1  # make sure sheet name matches
        # Append feedback
        sheet.append_row([accession, answer1, answer2, contact])
        return "✅ Feedback submitted. Thank you!"
    except Exception as e:
        return f"❌ Error submitting feedback: {e}"
# helper function to extract accessions
def extract_accessions_from_input(file=None, raw_text=""):
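    """Collect unique accession IDs from an uploaded CSV/XLSX file (first column)
    and/or a raw text box (split on newlines, commas, semicolons, tabs).
    Returns (accessions, error_message); error_message is None on success."""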
print(f"RAW TEXT RECEIVED: {raw_text}") | |
accessions = [] | |
seen = set() | |
if file: | |
try: | |
if file.name.endswith(".csv"): | |
df = pd.read_csv(file) | |
elif file.name.endswith(".xlsx"): | |
df = pd.read_excel(file) | |
else: | |
return [], "Unsupported file format. Please upload CSV or Excel." | |
for acc in df.iloc[:, 0].dropna().astype(str).str.strip(): | |
if acc not in seen: | |
accessions.append(acc) | |
seen.add(acc) | |
except Exception as e: | |
return [], f"Failed to read file: {e}" | |
if raw_text: | |
text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()] | |
for acc in text_ids: | |
if acc not in seen: | |
accessions.append(acc) | |
seen.add(acc) | |
return list(accessions), None | |
# β Add a new helper to backend: `filter_unprocessed_accessions()` | |
def get_incomplete_accessions(file_path):
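    """Scan a previously saved batch Excel file and return the accession numbers
    of rows whose result columns are still empty, so a resumed run only
    reprocesses samples without output."""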
    df = pd.read_excel(file_path)
    incomplete_accessions = []
    for _, row in df.iterrows():
        sample_id = str(row.get("Sample ID", "")).strip()
        # Skip if no sample ID
        if not sample_id:
            continue
        # Drop the Sample ID and check if the rest is empty
        other_cols = row.drop(labels=["Sample ID"], errors="ignore")
        if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
            # Extract the accession number from the sample ID using regex
            match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
            if match:
                incomplete_accessions.append(match.group(0))
    print(len(incomplete_accessions))
    return incomplete_accessions
# GOOGLE_SHEET_NAME = "known_samples"
# USAGE_DRIVE_FILENAME = "user_usage_log.json"
def summarize_results(accession, stop_flag=None):
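    """Classify one accession and return a list of result rows.

    Order of operations: honor the stop flag, return the cached row from the
    'known_samples' Google Sheet if a usable one exists, otherwise run the
    Gemini pipeline, build display rows, and write the new results back to
    'known_samples'. Returns [] on stop, error, or missing accession."""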
    # Early bail
    if stop_flag is not None and stop_flag.value:
        print(f"🛑 Skipping {accession} before starting.")
        return []
    # try cache first
    cached = check_known_output(accession)
    if cached:
        print(f"✅ Using cached result for {accession}")
        return [[
            cached["Sample ID"] or "unknown",
            cached["Predicted Country"] or "unknown",
            cached["Country Explanation"] or "unknown",
            cached["Predicted Sample Type"] or "unknown",
            cached["Sample Type Explanation"] or "unknown",
            cached["Sources"] or "No Links",
            cached["Time cost"]
        ]]
    # only run when nothing is in the cache
    try:
        print("try gemini pipeline: ", accession)
        # ✅ Load credentials from Hugging Face secret
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("known_samples")
        sheet = spreadsheet.sheet1
        data = sheet.get_all_values()
        if not data:
            print("⚠️ Google Sheet 'known_samples' is empty.")
            return []
        save_df = pd.DataFrame(data[1:], columns=data[0])
        print("before pipeline, len of save df: ", len(save_df))
        outputs = pipeline_classify_sample_location_cached(accession, stop_flag, save_df)
        if stop_flag is not None and stop_flag.value:
            print(f"🛑 Skipped {accession} mid-pipeline.")
            return []
        # outputs = {'KU131308': {'isolate': 'BRU18',
        #     'country': {'brunei': ['ncbi',
        #         'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
        #     'sample_type': {'modern':
        #         ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
        #     'query_cost': 9.754999999999999e-05,
        #     'time_cost': '24.776 seconds',
        #     'source': ['https://doi.org/10.1007/s00439-015-1620-z',
        #         'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
        #         'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
    except Exception as e:
        return []  # , f"Error: {e}", f"Error: {e}", f"Error: {e}"
    if accession not in outputs:
        print("no accession in output ", accession)
        return []  # , "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
    row_score = []
    rows = []
    save_rows = []
    for key in outputs:
        pred_country, pred_sample, country_explanation, sample_explanation = "unknown", "unknown", "unknown", "unknown"
        for section, results in outputs[key].items():
            if section == "country" or section == "sample_type":
                pred_output = []  # "\n".join(list(results.keys()))
                output_explanation = ""
                for result, content in results.items():
                    if len(result) == 0: result = "unknown"
                    if len(content) == 0: output_explanation = "unknown"
                    else:
                        output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
                    pred_output.append(result)
                pred_output = "\n".join(pred_output)
                if section == "country":
                    pred_country, country_explanation = pred_output, output_explanation
                elif section == "sample_type":
                    pred_sample, sample_explanation = pred_output, output_explanation
        if outputs[key]["isolate"].lower() != "unknown":
            label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
        else: label = key
        if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
        row = {
            "Sample ID": label or "unknown",
            "Predicted Country": pred_country or "unknown",
            "Country Explanation": country_explanation or "unknown",
            "Predicted Sample Type": pred_sample or "unknown",
            "Sample Type Explanation": sample_explanation or "unknown",
            "Sources": "\n".join(outputs[key]["source"]) or "No Links",
            "Time cost": outputs[key]["time_cost"]
        }
        # row_score.append(row)
        rows.append(list(row.values()))
        save_row = {
            "Sample ID": label or "unknown",
            "Predicted Country": pred_country or "unknown",
            "Country Explanation": country_explanation or "unknown",
            "Predicted Sample Type": pred_sample or "unknown",
            "Sample Type Explanation": sample_explanation or "unknown",
            "Sources": "\n".join(outputs[key]["source"]) or "No Links",
            "Query_cost": outputs[key]["query_cost"] or "",
            "Time cost": outputs[key]["time_cost"] or "",
            "file_chunk": outputs[key]["file_chunk"] or "",
            "file_all_output": outputs[key]["file_all_output"] or ""
        }
        # row_score.append(row)
        save_rows.append(list(save_row.values()))
    # # location_counts, (final_location, count) = compute_final_suggested_location(row_score)
    # summary_lines = [f"### 🧭 Location Summary:\n"]
    # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
    # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
    # summary = "\n".join(summary_lines)
    # save the new running sample to known excel file
    # try:
    #     df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost", "Time cost"])
    #     if os.path.exists(KNOWN_OUTPUT_PATH):
    #         df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
    #         df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
    #     else:
    #         df_combined = df_new
    #     df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
    # except Exception as e:
    #     print(f"⚠️ Failed to save known output: {e}")
    # try:
    #     df_new = pd.DataFrame(save_rows, columns=[
    #         "Sample ID", "Predicted Country", "Country Explanation",
    #         "Predicted Sample Type", "Sample Type Explanation",
    #         "Sources", "Query_cost", "Time cost"
    #     ])
    #     # ✅ Google Sheets API setup
    #     creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
    #     scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
    #     creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
    #     client = gspread.authorize(creds)
    #     # ✅ Open the known_samples sheet
    #     spreadsheet = client.open("known_samples")  # Replace with your sheet name
    #     sheet = spreadsheet.sheet1
    #     # ✅ Read old data
    #     existing_data = sheet.get_all_values()
    #     if existing_data:
    #         df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
    #     else:
    #         df_old = pd.DataFrame(columns=df_new.columns)
    #     # ✅ Combine and remove duplicates
    #     df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
    #     # ✅ Clear and write back
    #     sheet.clear()
    #     sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
    # except Exception as e:
    #     print(f"⚠️ Failed to save known output to Google Sheets: {e}")
    try:
        # Prepare as DataFrame
        df_new = pd.DataFrame(save_rows, columns=[
            "Sample ID", "Predicted Country", "Country Explanation",
            "Predicted Sample Type", "Sample Type Explanation",
            "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
        ])
        # ✅ Setup Google Sheets
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("known_samples")
        sheet = spreadsheet.sheet1
        # ✅ Read existing data
        existing_data = sheet.get_all_values()
        if existing_data:
            df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
        else:
            df_old = pd.DataFrame(columns=[
                "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
                "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
                "Query_cost", "Sample Type Explanation", "Sources", "Time cost", "file_chunk", "file_all_output"
            ])
        # ✅ Index by Sample ID
        df_old.set_index("Sample ID", inplace=True)
        df_new.set_index("Sample ID", inplace=True)
        # ✅ Update only matching fields
        update_columns = [
            "Predicted Country", "Predicted Sample Type", "Country Explanation",
            "Sample Type Explanation", "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
        ]
        for idx, row in df_new.iterrows():
            if idx not in df_old.index:
                df_old.loc[idx] = ""  # new row, fill empty first
            for col in update_columns:
                if pd.notna(row[col]) and row[col] != "":
                    df_old.at[idx, col] = row[col]
        # ✅ Reset and write back
        df_old.reset_index(inplace=True)
        sheet.clear()
        sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
        print("✅ Match results saved to known_samples.")
    except Exception as e:
        print(f"❌ Failed to update known_samples: {e}")
    return rows  # , summary, labelAncient_Modern, explain_label
# save the batch input in excel file
# def save_to_excel(all_rows, summary_text, flag_text, filename):
#     with pd.ExcelWriter(filename) as writer:
#         # Save table
#         df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
#         df.to_excel(writer, sheet_name="Detailed Results", index=False)
#     try:
#         df_old = pd.read_excel(filename)
#     except:
#         df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
#     df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
#     # if os.path.exists(filename):
#     #     df_old = pd.read_excel(filename)
#     #     df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
#     # else:
#     #     df_combined = df_new
#     df_combined.to_excel(filename, index=False)
#     # # Save summary
#     # summary_df = pd.DataFrame({"Summary": [summary_text]})
#     # summary_df.to_excel(writer, sheet_name="Summary", index=False)
#     # # Save flag
#     # flag_df = pd.DataFrame({"Flag": [flag_text]})
#     # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
# def save_to_excel(all_rows, summary_text, flag_text, filename):
#     df_new = pd.DataFrame(all_rows, columns=[
#         "Sample ID", "Predicted Country", "Country Explanation",
#         "Predicted Sample Type", "Sample Type Explanation",
#         "Sources", "Time cost"
#     ])
#     try:
#         if os.path.exists(filename):
#             df_old = pd.read_excel(filename)
#         else:
#             df_old = pd.DataFrame(columns=df_new.columns)
#     except Exception as e:
#         print(f"⚠️ Warning reading old Excel file: {e}")
#         df_old = pd.DataFrame(columns=df_new.columns)
#     # df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
#     df_old.set_index("Sample ID", inplace=True)
#     df_new.set_index("Sample ID", inplace=True)
#     df_old.update(df_new)  # <-- update matching rows in df_old with df_new content
#     df_combined = df_old.reset_index()
#     try:
#         df_combined.to_excel(filename, index=False)
#     except Exception as e:
#         print(f"❌ Failed to write Excel file {filename}: {e}")
def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
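    """Write the batch result rows to an Excel file. When is_resume is True and
    the file already exists, matching Sample ID rows are updated in place;
    otherwise the file is overwritten with the new rows. summary_text and
    flag_text are currently unused."""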
    df_new = pd.DataFrame(all_rows, columns=[
        "Sample ID", "Predicted Country", "Country Explanation",
        "Predicted Sample Type", "Sample Type Explanation",
        "Sources", "Time cost"
    ])
    if is_resume and os.path.exists(filename):
        try:
            df_old = pd.read_excel(filename)
        except Exception as e:
            print(f"⚠️ Warning reading old Excel file: {e}")
            df_old = pd.DataFrame(columns=df_new.columns)
        # Set index and update existing rows
        df_old.set_index("Sample ID", inplace=True)
        df_new.set_index("Sample ID", inplace=True)
        df_old.update(df_new)
        df_combined = df_old.reset_index()
    else:
        # If not resuming or file doesn't exist, just use new rows
        df_combined = df_new
    try:
        df_combined.to_excel(filename, index=False)
    except Exception as e:
        print(f"❌ Failed to write Excel file {filename}: {e}")
# save the batch input in JSON file
def save_to_json(all_rows, summary_text, flag_text, filename):
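    """Dump the result rows to a JSON file under the 'Detailed_Results' key,
    converting a DataFrame to a list of records if needed. summary_text and
    flag_text are currently unused."""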
    output_dict = {
        "Detailed_Results": all_rows  # <-- make sure this is a plain list, not a DataFrame
        # "Summary_Text": summary_text,
        # "Ancient_Modern_Flag": flag_text
    }
    # If all_rows is a DataFrame, convert it
    if isinstance(all_rows, pd.DataFrame):
        output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
    with open(filename, "w") as external_file:
        json.dump(output_dict, external_file, indent=2)
# save the batch input in Text file
def save_to_txt(all_rows, summary_text, flag_text, filename):
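    """Write the result rows to a plain-text file: a comma-joined header line,
    then one comma-joined line per row, each separated by a blank line."""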
    if isinstance(all_rows, pd.DataFrame):
        detailed_results = all_rows.to_dict(orient="records")
    else:
        detailed_results = all_rows
    output = ""
    # output += ",".join(list(detailed_results[0].keys())) + "\n\n"
    output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
    for r in detailed_results:
        output += ",".join([str(v) for v in r.values()]) + "\n\n"
    with open(filename, "w") as f:
        f.write("=== Detailed Results ===\n")
        f.write(output + "\n")
        # f.write("\n=== Summary ===\n")
        # f.write(summary_text + "\n")
        # f.write("\n=== Ancient/Modern Flag ===\n")
        # f.write(flag_text + "\n")
def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
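    """Convert the HTML results table shown in the UI back into a DataFrame and
    save it to a temporary file in the requested format (Excel, JSON, or TXT).
    Returns a gr.update() pointing at the file, or a hidden update for an
    invalid output_type."""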
    tmp_dir = tempfile.mkdtemp()
    # html_table = all_rows.value  # assuming this is stored somewhere
    # Parse back to DataFrame
    # all_rows = pd.read_html(all_rows)[0]  # [0] because read_html returns a list
    all_rows = pd.read_html(StringIO(all_rows))[0]
    print(all_rows)
    if output_type == "Excel":
        file_path = f"{tmp_dir}/batch_output.xlsx"
        save_to_excel(all_rows, summary_text, flag_text, file_path)
    elif output_type == "JSON":
        file_path = f"{tmp_dir}/batch_output.json"
        save_to_json(all_rows, summary_text, flag_text, file_path)
        print("Done with JSON")
    elif output_type == "TXT":
        file_path = f"{tmp_dir}/batch_output.txt"
        save_to_txt(all_rows, summary_text, flag_text, file_path)
    else:
        return gr.update(visible=False)  # invalid option
    return gr.update(value=file_path, visible=True)
# save cost by checking the known outputs
# def check_known_output(accession):
#     if not os.path.exists(KNOWN_OUTPUT_PATH):
#         return None
#     try:
#         df = pd.read_excel(KNOWN_OUTPUT_PATH)
#         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
#         if match:
#             accession = match.group(0)
#         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
#         if not matched.empty:
#             return matched.iloc[0].to_dict()  # Return the cached row
#     except Exception as e:
#         print(f"⚠️ Failed to load known samples: {e}")
#     return None
# def check_known_output(accession):
#     try:
#         # ✅ Load credentials from Hugging Face secret
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         # ✅ Open the known_samples sheet
#         spreadsheet = client.open("known_samples")  # Replace with your sheet name
#         sheet = spreadsheet.sheet1
#         # ✅ Read all rows
#         data = sheet.get_all_values()
#         if not data:
#             return None
#         df = pd.DataFrame(data[1:], columns=data[0])  # Skip header row
#         # ✅ Normalize accession pattern
#         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
#         if match:
#             accession = match.group(0)
#         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
#         if not matched.empty:
#             return matched.iloc[0].to_dict()
#     except Exception as e:
#         print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
#     return None
def check_known_output(accession):
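    """Look up an accession in the 'known_samples' Google Sheet and return the
    cached row as a dict when it has a usable Predicted Country and Predicted
    Sample Type; return None when the sheet is empty, the accession is missing,
    the cached values are unknown, or any error occurs."""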
print("inside check known output function") | |
try: | |
# β Load credentials from Hugging Face secret | |
creds_dict = json.loads(os.environ["GCP_CREDS_JSON"]) | |
scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive'] | |
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope) | |
client = gspread.authorize(creds) | |
spreadsheet = client.open("known_samples") | |
sheet = spreadsheet.sheet1 | |
data = sheet.get_all_values() | |
if not data: | |
print("β οΈ Google Sheet 'known_samples' is empty.") | |
return None | |
df = pd.DataFrame(data[1:], columns=data[0]) | |
if "Sample ID" not in df.columns: | |
print("β Column 'Sample ID' not found in Google Sheet.") | |
return None | |
match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession) | |
if match: | |
accession = match.group(0) | |
matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)] | |
if not matched.empty: | |
#return matched.iloc[0].to_dict() | |
row = matched.iloc[0] | |
country = row.get("Predicted Country", "").strip().lower() | |
sample_type = row.get("Predicted Sample Type", "").strip().lower() | |
if country and country != "unknown" and sample_type and sample_type != "unknown": | |
return row.to_dict() | |
else: | |
print(f"β οΈ Accession {accession} found but country/sample_type is unknown or empty.") | |
return None | |
else: | |
print(f"π Accession {accession} not found in known_samples.") | |
return None | |
except Exception as e: | |
import traceback | |
print("β Exception occurred during check_known_output:") | |
traceback.print_exc() | |
return None | |
def hash_user_id(user_input):
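    """Return the SHA-256 hex digest of the user identifier (e.g. an email),
    so usage can be tracked without storing the raw address."""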
    return hashlib.sha256(user_input.encode()).hexdigest()
# ✅ Load and save usage count
# def load_user_usage():
#     if not os.path.exists(USER_USAGE_TRACK_FILE):
#         return {}
#     try:
#         with open(USER_USAGE_TRACK_FILE, "r") as f:
#             content = f.read().strip()
#             if not content:
#                 return {}  # file is empty
#             return json.loads(content)
#     except (json.JSONDecodeError, ValueError):
#         print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
#         return {}  # fallback to empty dict
# def load_user_usage():
#     try:
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         sheet = client.open("user_usage_log").sheet1
#         data = sheet.get_all_records()  # Assumes columns: email, usage_count
#         usage = {}
#         for row in data:
#             email = row.get("email", "").strip().lower()
#             count = int(row.get("usage_count", 0))
#             if email:
#                 usage[email] = count
#         return usage
#     except Exception as e:
#         print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
#         return {}
# def load_user_usage():
#     try:
#         parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
#         iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
#         found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
#         if not found:
#             return {}  # not found, start fresh
#         # file_id = found[0]["id"]
#         file_id = found
#         content = pipeline.download_drive_file_content(file_id)
#         return json.loads(content.strip()) if content.strip() else {}
#     except Exception as e:
#         print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
#         return {}
def load_user_usage():
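    """Read the 'user_usage_log' Google Sheet and return two dicts keyed by
    lower-cased email: usage (current usage_count) and permitted (per-user
    allowance from the optional 'permitted_samples' column, defaulting to 50).
    Returns ({}, {}) when the sheet is empty, mis-formatted, or unreadable."""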
    try:
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        sheet = client.open("user_usage_log").sheet1
        data = sheet.get_all_values()
        print("data: ", data)
        if not data or len(data) < 2:
            print("⚠️ Sheet is empty or missing rows.")
            return {}, {}
        print("🧪 Raw header row from sheet:", data[0])
        print("🧪 Character codes in each header:")
        for h in data[0]:
            print([ord(c) for c in h])
        headers = [h.strip().lower() for h in data[0]]
        if "email" not in headers or "usage_count" not in headers:
            print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
            return {}, {}
        permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
        df = pd.DataFrame(data[1:], columns=headers)
        usage = {}
        permitted = {}
        for _, row in df.iterrows():
            email = row.get("email", "").strip().lower()
            try:
                # count = int(row.get("usage_count", 0))
                try:
                    count = int(float(row.get("usage_count", 0)))
                except Exception:
                    print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
                    count = 0
                if email:
                    usage[email] = count
                    if permitted_index is not None:
                        try:
                            permitted_count = int(float(row.get("permitted_samples", 50)))
                            permitted[email] = permitted_count
                        except:
                            permitted[email] = 50
            except ValueError:
                print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
        return usage, permitted
    except Exception as e:
        print(f"❌ Error in load_user_usage: {e}")
        return {}, {}
# def save_user_usage(usage):
#     with open(USER_USAGE_TRACK_FILE, "w") as f:
#         json.dump(usage, f, indent=2)
# def save_user_usage(usage_dict):
#     try:
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         sheet = client.open("user_usage_log").sheet1
#         sheet.clear()  # clear old contents first
#         # Write header + rows
#         rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
#         sheet.update(rows)
#     except Exception as e:
#         print(f"❌ Failed to save user usage to Google Sheets: {e}")
# def save_user_usage(usage_dict):
#     try:
#         parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
#         iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
#         import tempfile
#         tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
#         print("💾 Saving this usage dict:", usage_dict)
#         with open(tmp_path, "w") as f:
#             json.dump(usage_dict, f, indent=2)
#         pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
#     except Exception as e:
#         print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
# def save_user_usage(usage_dict):
#     try:
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         spreadsheet = client.open("user_usage_log")
#         sheet = spreadsheet.sheet1
#         # Step 1: Convert new usage to DataFrame
#         df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
#         df_new["email"] = df_new["email"].str.strip().str.lower()
#         # Step 2: Load existing data
#         existing_data = sheet.get_all_values()
#         print("🧪 Sheet existing_data:", existing_data)
#         # Try to load old data
#         if existing_data and len(existing_data[0]) >= 1:
#             df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
#             # Fix missing columns
#             if "email" not in df_old.columns:
#                 df_old["email"] = ""
#             if "usage_count" not in df_old.columns:
#                 df_old["usage_count"] = 0
#             df_old["email"] = df_old["email"].str.strip().str.lower()
#             df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
#         else:
#             df_old = pd.DataFrame(columns=["email", "usage_count"])
#         # Step 3: Merge
#         df_combined = pd.concat([df_old, df_new], ignore_index=True)
#         df_combined = df_combined.groupby("email", as_index=False).sum()
#         # Step 4: Write back
#         sheet.clear()
#         sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
#         print("✅ Saved user usage to user_usage_log sheet.")
#     except Exception as e:
#         print(f"❌ Failed to save user usage to Google Sheets: {e}")
def save_user_usage(usage_dict):
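    """Persist usage counts to the 'user_usage_log' Google Sheet. Existing rows
    are kept; only the emails present in usage_dict have their usage_count
    overwritten (new emails are appended), then the whole sheet is rewritten."""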
    try:
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("user_usage_log")
        sheet = spreadsheet.sheet1
        # Build new df
        df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
        df_new["email"] = df_new["email"].str.strip().str.lower()
        df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
        # Read existing data
        existing_data = sheet.get_all_values()
        if existing_data and len(existing_data[0]) >= 2:
            df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
            df_old["email"] = df_old["email"].str.strip().str.lower()
            df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
        else:
            df_old = pd.DataFrame(columns=["email", "usage_count"])
        # ✅ Overwrite specific emails only
        df_old = df_old.set_index("email")
        for email, count in usage_dict.items():
            email = email.strip().lower()
            df_old.loc[email, "usage_count"] = count
        df_old = df_old.reset_index()
        # Save
        sheet.clear()
        sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
        print("✅ Saved user usage to user_usage_log sheet.")
    except Exception as e:
        print(f"❌ Failed to save user usage to Google Sheets: {e}")
# def increment_usage(user_id, num_samples=1):
#     usage = load_user_usage()
#     if user_id not in usage:
#         usage[user_id] = 0
#     usage[user_id] += num_samples
#     save_user_usage(usage)
#     return usage[user_id]
# def increment_usage(email: str, count: int):
#     usage = load_user_usage()
#     email_key = email.strip().lower()
#     usage[email_key] = usage.get(email_key, 0) + count
#     save_user_usage(usage)
#     return usage[email_key]
def increment_usage(email: str, count: int = 1):
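    """Add count to the stored usage for this email key (never lowering an
    existing value), persist it, and return (new_usage_count, max_allowed).
    max_allowed falls back to 50 when no per-user allowance is set."""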
    usage, permitted = load_user_usage()
    email_key = email.strip().lower()
    # usage[email_key] = usage.get(email_key, 0) + count
    current = usage.get(email_key, 0)
    new_value = current + count
    max_allowed = permitted.get(email_key) or 50
    usage[email_key] = max(current, new_value)  # ✅ Prevent overwrite with a lower value
    print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
    print("max allow is: ", max_allowed)
    save_user_usage(usage)
    return usage[email_key], max_allowed
# run the batch
def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
                    stop_flag=None, output_file_path=None,
                    limited_acc=50, yield_callback=None):
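    """Process a batch of accessions from a file and/or raw text.

    Resolves the accession list (optionally resuming from a partial output
    file), caps it at the per-user limit, classifies each accession with
    summarize_results(), saves progress to an Excel file after every sample,
    and reports progress through yield_callback. Returns
    (all_rows, output_file_path, total_queries, progress_text, warning)."""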
    if user_email:
        limited_acc += 10
    accessions, error = extract_accessions_from_input(file, raw_text)
    if error:
        # return [], "", "", f"Error: {error}"
        return [], f"Error: {error}", 0, "", ""
    if resume_file:
        accessions = get_incomplete_accessions(resume_file)
    tmp_dir = tempfile.mkdtemp()
    if not output_file_path:
        if resume_file:
            output_file_path = os.path.join(tmp_dir, resume_file)
        else:
            output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
    all_rows = []
    # all_summaries = []
    # all_flags = []
    progress_lines = []
    warning = ""
    if len(accessions) > limited_acc:
        accessions = accessions[:limited_acc]
        warning = f"Your number of accessions exceeds {limited_acc}; only the first {limited_acc} accessions will be handled."
    for i, acc in enumerate(accessions):
        if stop_flag and stop_flag.value:
            line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
            progress_lines.append(line)
            if yield_callback:
                yield_callback(line)
            print("🛑 User requested stop.")
            break
        print(f"[{i+1}/{len(accessions)}] Processing {acc}")
        try:
            # rows, summary, label, explain = summarize_results(acc)
            rows = summarize_results(acc)
            all_rows.extend(rows)
            # all_summaries.append(f"**{acc}**\n{summary}")
            # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
            # save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
            save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
            line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
            progress_lines.append(line)
            if yield_callback:
                yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
        except Exception as e:
            print(f"❌ Failed to process {acc}: {e}")
            continue
            # all_summaries.append(f"**{acc}**: Failed - {e}")
            # progress_lines.append(f"❌ Processed {acc} ({i+1}/{len(accessions)})")
        limited_acc -= 1
    """for row in all_rows:
        source_column = row[2]  # Assuming the "Source" is in the 3rd column (index 2)
        if source_column.startswith("http"):  # Check if the source is a URL
            # Wrap it with HTML anchor tags to make it clickable
            row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
    if not warning:
        warning = f"You only have {limited_acc} accessions left"
    if user_email.strip():
        user_hash = hash_user_id(user_email)
        total_queries, _ = increment_usage(user_hash, len(all_rows))
    else:
        total_queries = 0
    if yield_callback:
        yield_callback("✅ Finished!")
    # summary_text = "\n\n---\n\n".join(all_summaries)
    # flag_text = "\n\n---\n\n".join(all_flags)
    # return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
    # return all_rows, gr.update(visible=True), gr.update(visible=False)
    return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning