import gradio as gr
from collections import Counter
import csv
import os
from functools import lru_cache
#import app
from mtdna_classifier import classify_sample_location
import data_preprocess, model, pipeline
import subprocess
import json
import pandas as pd
import io
import re
import tempfile
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from io import StringIO
import hashlib
import threading
# @lru_cache(maxsize=3600)
# def classify_sample_location_cached(accession):
# return classify_sample_location(accession)
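# Thin wrapper around pipeline.pipeline_with_gemini for a single accession.
# Illustrative call (accession taken from the example output further below;
# save_df is the cached "known_samples" table loaded elsewhere in this module):
#   outputs = pipeline_classify_sample_location_cached("KU131308", stop_flag=None, save_df=save_df)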
#@lru_cache(maxsize=3600)
def pipeline_classify_sample_location_cached(accession,stop_flag=None, save_df=None):
print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
print("len of save df: ", len(save_df))
return pipeline.pipeline_with_gemini([accession],stop_flag=stop_flag, save_df=save_df)
# Count and suggest final location
# def compute_final_suggested_location(rows):
# candidates = [
# row.get("Predicted Location", "").strip()
# for row in rows
# if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
# ] + [
# row.get("Inferred Region", "").strip()
# for row in rows
# if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
# ]
# if not candidates:
# return Counter(), ("Unknown", 0)
# # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
# tokens = []
# for item in candidates:
# # Split by comma, whitespace, and newlines
# parts = re.split(r'[\s,]+', item)
# tokens.extend(parts)
# # Step 2: Clean and normalize tokens
# tokens = [word.strip() for word in tokens if word.strip().isalpha()] # Keep only alphabetic tokens
# # Step 3: Count
# counts = Counter(tokens)
# # Step 4: Get most common
# top_location, count = counts.most_common(1)[0]
# return counts, (top_location, count)
# Store feedback (with required fields)
def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
if not answer1.strip() or not answer2.strip():
return "⚠️ Please answer both questions before submitting."
try:
        # ✅ Step: Load credentials from Hugging Face secret
creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
# Connect to Google Sheet
client = gspread.authorize(creds)
sheet = client.open("feedback_mtdna").sheet1 # make sure sheet name matches
# Append feedback
sheet.append_row([accession, answer1, answer2, contact])
return "βœ… Feedback submitted. Thank you!"
except Exception as e:
return f"❌ Error submitting feedback: {e}"
# helper function to extract accessions
def extract_accessions_from_input(file=None, raw_text=""):
print(f"RAW TEXT RECEIVED: {raw_text}")
accessions = []
seen = set()
if file:
try:
if file.name.endswith(".csv"):
df = pd.read_csv(file)
elif file.name.endswith(".xlsx"):
df = pd.read_excel(file)
else:
return [], "Unsupported file format. Please upload CSV or Excel."
for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
if acc not in seen:
accessions.append(acc)
seen.add(acc)
except Exception as e:
return [], f"Failed to read file: {e}"
if raw_text:
text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
for acc in text_ids:
if acc not in seen:
accessions.append(acc)
seen.add(acc)
return list(accessions), None
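# Illustrative call (the second and third IDs are made-up, accession-style strings):
#   extract_accessions_from_input(raw_text="KU131308, MT576652\nAB123456")
#   -> (["KU131308", "MT576652", "AB123456"], None)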
# ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
def get_incomplete_accessions(file_path):
df = pd.read_excel(file_path)
incomplete_accessions = []
for _, row in df.iterrows():
sample_id = str(row.get("Sample ID", "")).strip()
# Skip if no sample ID
if not sample_id:
continue
# Drop the Sample ID and check if the rest is empty
other_cols = row.drop(labels=["Sample ID"], errors="ignore")
if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
# Extract the accession number from the sample ID using regex
match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
if match:
incomplete_accessions.append(match.group(0))
    print("Number of incomplete accessions:", len(incomplete_accessions))
return incomplete_accessions
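# Example: a resume file whose "Sample ID" cell reads "KU131308(Isolate: BRU18)" with
# every other column empty yields ["KU131308"], since only the accession part is kept.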
# GOOGLE_SHEET_NAME = "known_samples"
# USAGE_DRIVE_FILENAME = "user_usage_log.json"
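# summarize_results: classify a single accession end-to-end. It first checks the
# "known_samples" Google Sheet cache; on a miss it runs the Gemini pipeline and writes
# the new result back to that sheet. Each returned row is
# [Sample ID, Predicted Country, Country Explanation, Predicted Sample Type,
#  Sample Type Explanation, Sources, Time cost].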
def summarize_results(accession, stop_flag=None):
# Early bail
if stop_flag is not None and stop_flag.value:
print(f"πŸ›‘ Skipping {accession} before starting.")
return []
# try cache first
cached = check_known_output(accession)
if cached:
print(f"βœ… Using cached result for {accession}")
return [[
cached["Sample ID"] or "unknown",
cached["Predicted Country"] or "unknown",
cached["Country Explanation"] or "unknown",
cached["Predicted Sample Type"] or "unknown",
cached["Sample Type Explanation"] or "unknown",
cached["Sources"] or "No Links",
cached["Time cost"]
]]
# only run when nothing in the cache
try:
print("try gemini pipeline: ",accession)
        # ✅ Load credentials from Hugging Face secret
creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
client = gspread.authorize(creds)
spreadsheet = client.open("known_samples")
sheet = spreadsheet.sheet1
data = sheet.get_all_values()
if not data:
print("⚠️ Google Sheet 'known_samples' is empty.")
            return []
save_df = pd.DataFrame(data[1:], columns=data[0])
print("before pipeline, len of save df: ", len(save_df))
outputs = pipeline_classify_sample_location_cached(accession, stop_flag, save_df)
if stop_flag is not None and stop_flag.value:
print(f"πŸ›‘ Skipped {accession} mid-pipeline.")
return []
# outputs = {'KU131308': {'isolate':'BRU18',
# 'country': {'brunei': ['ncbi',
# 'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
# 'sample_type': {'modern':
# ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
# 'query_cost': 9.754999999999999e-05,
# 'time_cost': '24.776 seconds',
# 'source': ['https://doi.org/10.1007/s00439-015-1620-z',
# 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
# 'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
except Exception as e:
return []#, f"Error: {e}", f"Error: {e}", f"Error: {e}"
if accession not in outputs:
print("no accession in output ", accession)
return []#, "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
row_score = []
rows = []
save_rows = []
for key in outputs:
pred_country, pred_sample, country_explanation, sample_explanation = "unknown","unknown","unknown","unknown"
for section, results in outputs[key].items():
if section == "country" or section =="sample_type":
pred_output = []#"\n".join(list(results.keys()))
output_explanation = ""
for result, content in results.items():
if len(result) == 0: result = "unknown"
if len(content) == 0: output_explanation = "unknown"
else:
output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
pred_output.append(result)
pred_output = "\n".join(pred_output)
if section == "country":
pred_country, country_explanation = pred_output, output_explanation
elif section == "sample_type":
pred_sample, sample_explanation = pred_output, output_explanation
if outputs[key]["isolate"].lower()!="unknown":
label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
else: label = key
if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
row = {
"Sample ID": label or "unknown",
"Predicted Country": pred_country or "unknown",
"Country Explanation": country_explanation or "unknown",
"Predicted Sample Type":pred_sample or "unknown",
"Sample Type Explanation":sample_explanation or "unknown",
"Sources": "\n".join(outputs[key]["source"]) or "No Links",
"Time cost": outputs[key]["time_cost"]
}
#row_score.append(row)
rows.append(list(row.values()))
save_row = {
"Sample ID": label or "unknown",
"Predicted Country": pred_country or "unknown",
"Country Explanation": country_explanation or "unknown",
"Predicted Sample Type":pred_sample or "unknown",
"Sample Type Explanation":sample_explanation or "unknown",
"Sources": "\n".join(outputs[key]["source"]) or "No Links",
"Query_cost": outputs[key]["query_cost"] or "",
"Time cost": outputs[key]["time_cost"] or "",
"file_chunk":outputs[key]["file_chunk"] or "",
"file_all_output":outputs[key]["file_all_output"] or ""
}
#row_score.append(row)
save_rows.append(list(save_row.values()))
# #location_counts, (final_location, count) = compute_final_suggested_location(row_score)
# summary_lines = [f"### 🧭 Location Summary:\n"]
# summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
    # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
# summary = "\n".join(summary_lines)
# save the new running sample to known excel file
# try:
# df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost","Time cost"])
# if os.path.exists(KNOWN_OUTPUT_PATH):
# df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
# df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
# else:
# df_combined = df_new
# df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
# except Exception as e:
# print(f"⚠️ Failed to save known output: {e}")
# try:
# df_new = pd.DataFrame(save_rows, columns=[
# "Sample ID", "Predicted Country", "Country Explanation",
# "Predicted Sample Type", "Sample Type Explanation",
# "Sources", "Query_cost", "Time cost"
# ])
    # # ✅ Google Sheets API setup
# creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
# scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
# creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
# client = gspread.authorize(creds)
    # # ✅ Open the known_samples sheet
# spreadsheet = client.open("known_samples") # Replace with your sheet name
# sheet = spreadsheet.sheet1
    # # ✅ Read old data
# existing_data = sheet.get_all_values()
# if existing_data:
# df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
# else:
# df_old = pd.DataFrame(columns=df_new.columns)
    # # ✅ Combine and remove duplicates
# df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
    # # ✅ Clear and write back
# sheet.clear()
# sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
# except Exception as e:
# print(f"⚠️ Failed to save known output to Google Sheets: {e}")
try:
# Prepare as DataFrame
df_new = pd.DataFrame(save_rows, columns=[
"Sample ID", "Predicted Country", "Country Explanation",
"Predicted Sample Type", "Sample Type Explanation",
"Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
])
        # ✅ Setup Google Sheets
creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
client = gspread.authorize(creds)
spreadsheet = client.open("known_samples")
sheet = spreadsheet.sheet1
        # ✅ Read existing data
existing_data = sheet.get_all_values()
if existing_data:
df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
else:
df_old = pd.DataFrame(columns=[
"Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
"Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
"Query_cost", "Sample Type Explanation", "Sources", "Time cost", "file_chunk", "file_all_output"
])
        # ✅ Index by Sample ID
df_old.set_index("Sample ID", inplace=True)
df_new.set_index("Sample ID", inplace=True)
        # ✅ Update only matching fields
update_columns = [
"Predicted Country", "Predicted Sample Type", "Country Explanation",
"Sample Type Explanation", "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
]
for idx, row in df_new.iterrows():
if idx not in df_old.index:
df_old.loc[idx] = "" # new row, fill empty first
for col in update_columns:
if pd.notna(row[col]) and row[col] != "":
df_old.at[idx, col] = row[col]
        # ✅ Reset and write back
df_old.reset_index(inplace=True)
sheet.clear()
sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
print("βœ… Match results saved to known_samples.")
except Exception as e:
print(f"❌ Failed to update known_samples: {e}")
return rows#, summary, labelAncient_Modern, explain_label
# save the batch input in excel file
# def save_to_excel(all_rows, summary_text, flag_text, filename):
# with pd.ExcelWriter(filename) as writer:
# # Save table
# df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
# df.to_excel(writer, sheet_name="Detailed Results", index=False)
# try:
# df_old = pd.read_excel(filename)
# except:
# df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
# df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
# # if os.path.exists(filename):
# # df_old = pd.read_excel(filename)
# # df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
# # else:
# # df_combined = df_new
# df_combined.to_excel(filename, index=False)
# # # Save summary
# # summary_df = pd.DataFrame({"Summary": [summary_text]})
# # summary_df.to_excel(writer, sheet_name="Summary", index=False)
# # # Save flag
# # flag_df = pd.DataFrame({"Flag": [flag_text]})
# # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
# def save_to_excel(all_rows, summary_text, flag_text, filename):
# df_new = pd.DataFrame(all_rows, columns=[
# "Sample ID", "Predicted Country", "Country Explanation",
# "Predicted Sample Type", "Sample Type Explanation",
# "Sources", "Time cost"
# ])
# try:
# if os.path.exists(filename):
# df_old = pd.read_excel(filename)
# else:
# df_old = pd.DataFrame(columns=df_new.columns)
# except Exception as e:
# print(f"⚠️ Warning reading old Excel file: {e}")
# df_old = pd.DataFrame(columns=df_new.columns)
# #df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
# df_old.set_index("Sample ID", inplace=True)
# df_new.set_index("Sample ID", inplace=True)
# df_old.update(df_new) # <-- update matching rows in df_old with df_new content
# df_combined = df_old.reset_index()
# try:
# df_combined.to_excel(filename, index=False)
# except Exception as e:
# print(f"❌ Failed to write Excel file {filename}: {e}")
def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
df_new = pd.DataFrame(all_rows, columns=[
"Sample ID", "Predicted Country", "Country Explanation",
"Predicted Sample Type", "Sample Type Explanation",
"Sources", "Time cost"
])
if is_resume and os.path.exists(filename):
try:
df_old = pd.read_excel(filename)
except Exception as e:
print(f"⚠️ Warning reading old Excel file: {e}")
df_old = pd.DataFrame(columns=df_new.columns)
# Set index and update existing rows
df_old.set_index("Sample ID", inplace=True)
df_new.set_index("Sample ID", inplace=True)
df_old.update(df_new)
df_combined = df_old.reset_index()
else:
# If not resuming or file doesn't exist, just use new rows
df_combined = df_new
try:
df_combined.to_excel(filename, index=False)
except Exception as e:
print(f"❌ Failed to write Excel file {filename}: {e}")
# save the batch input in JSON file
def save_to_json(all_rows, summary_text, flag_text, filename):
output_dict = {
"Detailed_Results": all_rows#, # <-- make sure this is a plain list, not a DataFrame
# "Summary_Text": summary_text,
# "Ancient_Modern_Flag": flag_text
}
# If all_rows is a DataFrame, convert it
if isinstance(all_rows, pd.DataFrame):
output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
with open(filename, "w") as external_file:
json.dump(output_dict, external_file, indent=2)
# save the batch input in Text file
def save_to_txt(all_rows, summary_text, flag_text, filename):
if isinstance(all_rows, pd.DataFrame):
detailed_results = all_rows.to_dict(orient="records")
output = ""
#output += ",".join(list(detailed_results[0].keys())) + "\n\n"
output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
for r in detailed_results:
output += ",".join([str(v) for v in r.values()]) + "\n\n"
with open(filename, "w") as f:
f.write("=== Detailed Results ===\n")
f.write(output + "\n")
# f.write("\n=== Summary ===\n")
# f.write(summary_text + "\n")
# f.write("\n=== Ancient/Modern Flag ===\n")
# f.write(flag_text + "\n")
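# Dispatch the rendered results table to the chosen download format. The caller passes
# the table as an HTML string, so it is parsed back into a DataFrame with pandas.read_html
# before being written to Excel, JSON, or TXT in a temporary directory.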
def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
tmp_dir = tempfile.mkdtemp()
#html_table = all_rows.value # assuming this is stored somewhere
# Parse back to DataFrame
#all_rows = pd.read_html(all_rows)[0] # [0] because read_html returns a list
all_rows = pd.read_html(StringIO(all_rows))[0]
print(all_rows)
if output_type == "Excel":
file_path = f"{tmp_dir}/batch_output.xlsx"
save_to_excel(all_rows, summary_text, flag_text, file_path)
elif output_type == "JSON":
file_path = f"{tmp_dir}/batch_output.json"
save_to_json(all_rows, summary_text, flag_text, file_path)
print("Done with JSON")
elif output_type == "TXT":
file_path = f"{tmp_dir}/batch_output.txt"
save_to_txt(all_rows, summary_text, flag_text, file_path)
else:
return gr.update(visible=False) # invalid option
return gr.update(value=file_path, visible=True)
# save cost by checking the known outputs
# def check_known_output(accession):
# if not os.path.exists(KNOWN_OUTPUT_PATH):
# return None
# try:
# df = pd.read_excel(KNOWN_OUTPUT_PATH)
# match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
# if match:
# accession = match.group(0)
# matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
# if not matched.empty:
# return matched.iloc[0].to_dict() # Return the cached row
# except Exception as e:
# print(f"⚠️ Failed to load known samples: {e}")
# return None
# def check_known_output(accession):
# try:
# # ✅ Load credentials from Hugging Face secret
# creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
# scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
# creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
# client = gspread.authorize(creds)
# # ✅ Open the known_samples sheet
# spreadsheet = client.open("known_samples") # Replace with your sheet name
# sheet = spreadsheet.sheet1
# # ✅ Read all rows
# data = sheet.get_all_values()
# if not data:
# return None
# df = pd.DataFrame(data[1:], columns=data[0]) # Skip header row
# # ✅ Normalize accession pattern
# match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
# if match:
# accession = match.group(0)
# matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
# if not matched.empty:
# return matched.iloc[0].to_dict()
# except Exception as e:
# print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
# return None
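# Active cache lookup against the "known_samples" Google Sheet. Returns the cached row
# as a dict only when both Predicted Country and Predicted Sample Type are present and
# not "unknown"; otherwise returns None so the accession is re-run through the pipeline.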
def check_known_output(accession):
print("inside check known output function")
try:
        # ✅ Load credentials from Hugging Face secret
creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
client = gspread.authorize(creds)
spreadsheet = client.open("known_samples")
sheet = spreadsheet.sheet1
data = sheet.get_all_values()
if not data:
print("⚠️ Google Sheet 'known_samples' is empty.")
return None
df = pd.DataFrame(data[1:], columns=data[0])
if "Sample ID" not in df.columns:
print("❌ Column 'Sample ID' not found in Google Sheet.")
return None
match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
if match:
accession = match.group(0)
matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
if not matched.empty:
#return matched.iloc[0].to_dict()
row = matched.iloc[0]
country = row.get("Predicted Country", "").strip().lower()
sample_type = row.get("Predicted Sample Type", "").strip().lower()
if country and country != "unknown" and sample_type and sample_type != "unknown":
return row.to_dict()
else:
print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
return None
else:
print(f"πŸ” Accession {accession} not found in known_samples.")
return None
except Exception as e:
import traceback
print("❌ Exception occurred during check_known_output:")
traceback.print_exc()
return None
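# Anonymize user identifiers before logging usage counts.
# Example (illustrative address): hash_user_id("user@example.com") returns a
# 64-character SHA-256 hex digest.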
def hash_user_id(user_input):
return hashlib.sha256(user_input.encode()).hexdigest()
# ✅ Load and save usage count
# def load_user_usage():
# if not os.path.exists(USER_USAGE_TRACK_FILE):
# return {}
# try:
# with open(USER_USAGE_TRACK_FILE, "r") as f:
# content = f.read().strip()
# if not content:
# return {} # file is empty
# return json.loads(content)
# except (json.JSONDecodeError, ValueError):
# print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
# return {} # fallback to empty dict
# def load_user_usage():
# try:
# creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
# scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
# creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
# client = gspread.authorize(creds)
# sheet = client.open("user_usage_log").sheet1
# data = sheet.get_all_records() # Assumes columns: email, usage_count
# usage = {}
# for row in data:
# email = row.get("email", "").strip().lower()
# count = int(row.get("usage_count", 0))
# if email:
# usage[email] = count
# return usage
# except Exception as e:
# print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
# return {}
# def load_user_usage():
# try:
# parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
# iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
# found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
# if not found:
# return {} # not found, start fresh
# #file_id = found[0]["id"]
# file_id = found
# content = pipeline.download_drive_file_content(file_id)
# return json.loads(content.strip()) if content.strip() else {}
# except Exception as e:
# print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
# return {}
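# Active usage loader. Reads the "user_usage_log" sheet (columns: email, usage_count,
# and optionally permitted_samples) and returns two dicts:
# {email: usage_count} and {email: permitted_samples}.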
def load_user_usage():
try:
creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
client = gspread.authorize(creds)
sheet = client.open("user_usage_log").sheet1
        data = sheet.get_all_values()
        if not data or len(data) < 2:
            print("⚠️ Sheet is empty or missing rows.")
            return {}, {}
        print("data: ", data)
        print("🧪 Raw header row from sheet:", data[0])
        print("🧪 Character codes in each header:")
        for h in data[0]:
            print([ord(c) for c in h])
        headers = [h.strip().lower() for h in data[0]]
        if "email" not in headers or "usage_count" not in headers:
            print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
            return {}, {}
permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
df = pd.DataFrame(data[1:], columns=headers)
usage = {}
permitted = {}
        for _, row in df.iterrows():
            email = row.get("email", "").strip().lower()
            if not email:
                continue
            try:
                count = int(float(row.get("usage_count", 0)))
            except Exception:
                print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
                count = 0
            usage[email] = count
            if permitted_index is not None:
                try:
                    permitted[email] = int(float(row.get("permitted_samples", 50)))
                except Exception:
                    permitted[email] = 50
return usage, permitted
except Exception as e:
print(f"❌ Error in load_user_usage: {e}")
return {}, {}
# def save_user_usage(usage):
# with open(USER_USAGE_TRACK_FILE, "w") as f:
# json.dump(usage, f, indent=2)
# def save_user_usage(usage_dict):
# try:
# creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
# scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
# creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
# client = gspread.authorize(creds)
# sheet = client.open("user_usage_log").sheet1
# sheet.clear() # clear old contents first
# # Write header + rows
# rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
# sheet.update(rows)
# except Exception as e:
# print(f"❌ Failed to save user usage to Google Sheets: {e}")
# def save_user_usage(usage_dict):
# try:
# parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
# iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
# import tempfile
# tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
# print("πŸ’Ύ Saving this usage dict:", usage_dict)
# with open(tmp_path, "w") as f:
# json.dump(usage_dict, f, indent=2)
# pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
# except Exception as e:
# print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
# def save_user_usage(usage_dict):
# try:
# creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
# scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
# creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
# client = gspread.authorize(creds)
# spreadsheet = client.open("user_usage_log")
# sheet = spreadsheet.sheet1
# # Step 1: Convert new usage to DataFrame
# df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
# df_new["email"] = df_new["email"].str.strip().str.lower()
# # Step 2: Load existing data
# existing_data = sheet.get_all_values()
# print("πŸ§ͺ Sheet existing_data:", existing_data)
# # Try to load old data
# if existing_data and len(existing_data[0]) >= 1:
# df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
# # Fix missing columns
# if "email" not in df_old.columns:
# df_old["email"] = ""
# if "usage_count" not in df_old.columns:
# df_old["usage_count"] = 0
# df_old["email"] = df_old["email"].str.strip().str.lower()
# df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
# else:
# df_old = pd.DataFrame(columns=["email", "usage_count"])
# # Step 3: Merge
# df_combined = pd.concat([df_old, df_new], ignore_index=True)
# df_combined = df_combined.groupby("email", as_index=False).sum()
# # Step 4: Write back
# sheet.clear()
# sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
# print("βœ… Saved user usage to user_usage_log sheet.")
# except Exception as e:
# print(f"❌ Failed to save user usage to Google Sheets: {e}")
def save_user_usage(usage_dict):
try:
creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
client = gspread.authorize(creds)
spreadsheet = client.open("user_usage_log")
sheet = spreadsheet.sheet1
# Build new df
df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
df_new["email"] = df_new["email"].str.strip().str.lower()
df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
# Read existing data
existing_data = sheet.get_all_values()
if existing_data and len(existing_data[0]) >= 2:
df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
df_old["email"] = df_old["email"].str.strip().str.lower()
df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
else:
df_old = pd.DataFrame(columns=["email", "usage_count"])
        # ✅ Overwrite specific emails only
df_old = df_old.set_index("email")
for email, count in usage_dict.items():
email = email.strip().lower()
df_old.loc[email, "usage_count"] = count
df_old = df_old.reset_index()
# Save
sheet.clear()
sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
print("βœ… Saved user usage to user_usage_log sheet.")
except Exception as e:
print(f"❌ Failed to save user usage to Google Sheets: {e}")
# def increment_usage(user_id, num_samples=1):
# usage = load_user_usage()
# if user_id not in usage:
# usage[user_id] = 0
# usage[user_id] += num_samples
# save_user_usage(usage)
# return usage[user_id]
# def increment_usage(email: str, count: int):
# usage = load_user_usage()
# email_key = email.strip().lower()
# usage[email_key] = usage.get(email_key, 0) + count
# save_user_usage(usage)
# return usage[email_key]
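# Active usage incrementer. Returns (updated_usage_count, max_allowed), where max_allowed
# falls back to 50 when no permitted_samples value exists for the email.
# Illustrative call: used, limit = increment_usage("user@example.com", count=3)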
def increment_usage(email: str, count: int = 1):
usage, permitted = load_user_usage()
email_key = email.strip().lower()
#usage[email_key] = usage.get(email_key, 0) + count
current = usage.get(email_key, 0)
new_value = current + count
max_allowed = permitted.get(email_key) or 50
    usage[email_key] = max(current, new_value)  # ✅ Prevent overwrite with lower
    print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
print("max allow is: ", max_allowed)
save_user_usage(usage)
return usage[email_key], max_allowed
# run the batch
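# summarize_batch drives a whole batch: it expands the upload/raw text into accession IDs,
# optionally resumes from a partially filled Excel file, processes each accession with
# summarize_results, streams progress through yield_callback, and returns
# (all_rows, output_file_path, total_queries, progress_text, warning).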
def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
stop_flag=None, output_file_path=None,
limited_acc=50, yield_callback=None):
if user_email:
limited_acc += 10
accessions, error = extract_accessions_from_input(file, raw_text)
if error:
#return [], "", "", f"Error: {error}"
return [], f"Error: {error}", 0, "", ""
if resume_file:
accessions = get_incomplete_accessions(resume_file)
tmp_dir = tempfile.mkdtemp()
if not output_file_path:
if resume_file:
output_file_path = os.path.join(tmp_dir, resume_file)
else:
output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
all_rows = []
# all_summaries = []
# all_flags = []
progress_lines = []
warning = ""
if len(accessions) > limited_acc:
accessions = accessions[:limited_acc]
warning = f"Your number of accessions is more than the {limited_acc}, only handle first {limited_acc} accessions"
for i, acc in enumerate(accessions):
if stop_flag and stop_flag.value:
line = f"πŸ›‘ Stopped at {acc} ({i+1}/{len(accessions)})"
progress_lines.append(line)
if yield_callback:
yield_callback(line)
print("πŸ›‘ User requested stop.")
break
print(f"[{i+1}/{len(accessions)}] Processing {acc}")
try:
# rows, summary, label, explain = summarize_results(acc)
rows = summarize_results(acc)
all_rows.extend(rows)
# all_summaries.append(f"**{acc}**\n{summary}")
# all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
#save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
line = f"βœ… Processed {acc} ({i+1}/{len(accessions)})"
progress_lines.append(line)
if yield_callback:
yield_callback(f"βœ… Processed {acc} ({i+1}/{len(accessions)})")
except Exception as e:
print(f"❌ Failed to process {acc}: {e}")
continue
#all_summaries.append(f"**{acc}**: Failed - {e}")
#progress_lines.append(f"βœ… Processed {acc} ({i+1}/{len(accessions)})")
limited_acc -= 1
"""for row in all_rows:
source_column = row[2] # Assuming the "Source" is in the 3rd column (index 2)
if source_column.startswith("http"): # Check if the source is a URL
# Wrap it with HTML anchor tags to make it clickable
row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
    if not warning:
        warning = f"You have {limited_acc} accessions left in your quota"
    if user_email.strip():
        user_hash = hash_user_id(user_email)
        total_queries, _ = increment_usage(user_hash, len(all_rows))
    else:
        total_queries = 0
    if yield_callback:
        yield_callback("✅ Finished!")
# summary_text = "\n\n---\n\n".join(all_summaries)
# flag_text = "\n\n---\n\n".join(all_flags)
#return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
#return all_rows, gr.update(visible=True), gr.update(visible=False)
return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning