import gradio as gr
from collections import Counter
import csv
import os
from functools import lru_cache
#import app
from mtdna_classifier import classify_sample_location
import data_preprocess, model, pipeline
import subprocess
import json
import pandas as pd
import io
import re
import tempfile
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from io import StringIO
import hashlib
import threading

# @lru_cache(maxsize=3600)
# def classify_sample_location_cached(accession):
#     return classify_sample_location(accession)

#@lru_cache(maxsize=3600)
def pipeline_classify_sample_location_cached(accession, stop_flag=None, save_df=None):
    print("inside pipeline_classify_sample_location_cached, and [accession] is ", [accession])
    print("len of save df: ", len(save_df))
    return pipeline.pipeline_with_gemini([accession], stop_flag=stop_flag, save_df=save_df)
# Count and suggest final location
# def compute_final_suggested_location(rows):
#     candidates = [
#         row.get("Predicted Location", "").strip()
#         for row in rows
#         if row.get("Predicted Location", "").strip().lower() not in ["", "sample id not found", "unknown"]
#     ] + [
#         row.get("Inferred Region", "").strip()
#         for row in rows
#         if row.get("Inferred Region", "").strip().lower() not in ["", "sample id not found", "unknown"]
#     ]
#     if not candidates:
#         return Counter(), ("Unknown", 0)
#     # Step 1: Combine into one string and split using regex to handle commas, line breaks, etc.
#     tokens = []
#     for item in candidates:
#         # Split by comma, whitespace, and newlines
#         parts = re.split(r'[\s,]+', item)
#         tokens.extend(parts)
#     # Step 2: Clean and normalize tokens
#     tokens = [word.strip() for word in tokens if word.strip().isalpha()]  # Keep only alphabetic tokens
#     # Step 3: Count
#     counts = Counter(tokens)
#     # Step 4: Get most common
#     top_location, count = counts.most_common(1)[0]
#     return counts, (top_location, count)
# Store feedback (with required fields)
def store_feedback_to_google_sheets(accession, answer1, answer2, contact=""):
    if not answer1.strip() or not answer2.strip():
        return "⚠️ Please answer both questions before submitting."
    try:
        # ✅ Step: Load credentials from Hugging Face secret
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        # Connect to Google Sheet
        client = gspread.authorize(creds)
        sheet = client.open("feedback_mtdna").sheet1  # make sure sheet name matches
        # Append feedback
        sheet.append_row([accession, answer1, answer2, contact])
        return "✅ Feedback submitted. Thank you!"
    except Exception as e:
        return f"❌ Error submitting feedback: {e}"
import re

ACCESSION_REGEX = re.compile(r'^[A-Z]{1,4}_?\d{6}(\.\d+)?$')

def is_valid_accession(acc):
    return bool(ACCESSION_REGEX.match(acc))
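# Illustrative examples (not used by the app; the IDs below are documentation only)
# of what ACCESSION_REGEX accepts, assuming GenBank/RefSeq-style accessions:
#   is_valid_accession("KU131308")     -> True   (2 letters + 6 digits)
#   is_valid_accession("NC_012920.1")  -> True   (optional underscore and version suffix)
#   is_valid_accession("ku131308")     -> False  (lowercase letters are not matched)
#   is_valid_accession("KU1313")       -> False  (fewer than 6 digits)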
# helper function to extract accessions
def extract_accessions_from_input(file=None, raw_text=""):
    print(f"RAW TEXT RECEIVED: {raw_text}")
    accessions, invalid_accessions = [], []
    seen = set()
    if file:
        try:
            if file.name.endswith(".csv"):
                df = pd.read_csv(file)
            elif file.name.endswith(".xlsx"):
                df = pd.read_excel(file)
            else:
                return [], [], "Unsupported file format. Please upload CSV or Excel."
            for acc in df.iloc[:, 0].dropna().astype(str).str.strip():
                if acc not in seen:
                    if is_valid_accession(acc):
                        accessions.append(acc)
                        seen.add(acc)
                    else:
                        invalid_accessions.append(acc)
        except Exception as e:
            return [], [], f"Failed to read file: {e}"
    if raw_text:
        try:
            text_ids = [s.strip() for s in re.split(r"[\n,;\t]", raw_text) if s.strip()]
            for acc in text_ids:
                if acc not in seen:
                    if is_valid_accession(acc):
                        accessions.append(acc)
                        seen.add(acc)
                    else:
                        invalid_accessions.append(acc)
        except Exception as e:
            return [], [], f"Failed to parse text input: {e}"
    return list(accessions), list(invalid_accessions), None
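# Illustrative call (the IDs are made up): valid and invalid accessions mixed in raw text.
#   extract_accessions_from_input(raw_text="KU131308, MF362736\nnot_an_id")
#   -> (["KU131308", "MF362736"], ["not_an_id"], None)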
# ✅ Add a new helper to backend: `filter_unprocessed_accessions()`
def get_incomplete_accessions(file_path):
    df = pd.read_excel(file_path)
    incomplete_accessions = []
    for _, row in df.iterrows():
        sample_id = str(row.get("Sample ID", "")).strip()
        # Skip if no sample ID
        if not sample_id:
            continue
        # Drop the Sample ID and check if the rest is empty
        other_cols = row.drop(labels=["Sample ID"], errors="ignore")
        if other_cols.isna().all() or (other_cols.astype(str).str.strip() == "").all():
            # Extract the accession number from the sample ID using regex
            match = re.search(r"\b[A-Z]{2,4}\d{4,}", sample_id)
            if match:
                incomplete_accessions.append(match.group(0))
    print(len(incomplete_accessions))
    return incomplete_accessions
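# Illustrative behaviour (the file name and row contents are hypothetical): given a
# partially filled "batch_output_live.xlsx" in which the row labelled
# "KU131308(Isolate: BRU18)" has a Sample ID but every other column empty, the
# function returns ["KU131308"] so that only unfinished accessions are re-run.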
# GOOGLE_SHEET_NAME = "known_samples"
# USAGE_DRIVE_FILENAME = "user_usage_log.json"

def summarize_results(accession, stop_flag=None):
    # Early bail
    if stop_flag is not None and stop_flag.value:
        print(f"🛑 Skipping {accession} before starting.")
        return []
    # try cache first
    cached = check_known_output(accession)
    if cached:
        print(f"✅ Using cached result for {accession}")
        return [[
            cached["Sample ID"] or "unknown",
            cached["Predicted Country"] or "unknown",
            cached["Country Explanation"] or "unknown",
            cached["Predicted Sample Type"] or "unknown",
            cached["Sample Type Explanation"] or "unknown",
            cached["Sources"] or "No Links",
            cached["Time cost"]
        ]]
    # only run when nothing is in the cache
    try:
        print("try gemini pipeline: ", accession)
        # ✅ Load credentials from Hugging Face secret
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("known_samples")
        sheet = spreadsheet.sheet1
        data = sheet.get_all_values()
        if not data:
            print("⚠️ Google Sheet 'known_samples' is empty.")
            return []  # keep the return type consistent so callers can extend() the result
        save_df = pd.DataFrame(data[1:], columns=data[0])
        print("before pipeline, len of save df: ", len(save_df))
        outputs = pipeline_classify_sample_location_cached(accession, stop_flag, save_df)
        if stop_flag is not None and stop_flag.value:
            print(f"🛑 Skipped {accession} mid-pipeline.")
            return []
        # outputs = {'KU131308': {'isolate': 'BRU18',
        #     'country': {'brunei': ['ncbi',
        #         'rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
        #     'sample_type': {'modern':
        #         ['rag_llm-The text mentions "BRU18 Brunei Borneo" in a table listing various samples, and it is not described as ancient or archaeological.']},
        #     'query_cost': 9.754999999999999e-05,
        #     'time_cost': '24.776 seconds',
        #     'source': ['https://doi.org/10.1007/s00439-015-1620-z',
        #         'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM1_ESM.pdf',
        #         'https://static-content.springer.com/esm/art%3A10.1007%2Fs00439-015-1620-z/MediaObjects/439_2015_1620_MOESM2_ESM.xls']}}
    except Exception as e:
        return []  # , f"Error: {e}", f"Error: {e}", f"Error: {e}"
    if accession not in outputs:
        print("no accession in output ", accession)
        return []  # , "Accession not found in results.", "Accession not found in results.", "Accession not found in results."
    row_score = []
    rows = []
    save_rows = []
    for key in outputs:
        pred_country, pred_sample, country_explanation, sample_explanation = "unknown", "unknown", "unknown", "unknown"
        for section, results in outputs[key].items():
            if section == "country" or section == "sample_type":
                pred_output = []  # "\n".join(list(results.keys()))
                output_explanation = ""
                for result, content in results.items():
                    if len(result) == 0: result = "unknown"
                    if len(content) == 0: output_explanation = "unknown"
                    else:
                        output_explanation += 'Method: ' + "\nMethod: ".join(content) + "\n"
                    pred_output.append(result)
                pred_output = "\n".join(pred_output)
                if section == "country":
                    pred_country, country_explanation = pred_output, output_explanation
                elif section == "sample_type":
                    pred_sample, sample_explanation = pred_output, output_explanation
        if outputs[key]["isolate"].lower() != "unknown":
            label = key + "(Isolate: " + outputs[key]["isolate"] + ")"
        else: label = key
        if len(outputs[key]["source"]) == 0: outputs[key]["source"] = ["No Links"]
        row = {
            "Sample ID": label or "unknown",
            "Predicted Country": pred_country or "unknown",
            "Country Explanation": country_explanation or "unknown",
            "Predicted Sample Type": pred_sample or "unknown",
            "Sample Type Explanation": sample_explanation or "unknown",
            "Sources": "\n".join(outputs[key]["source"]) or "No Links",
            "Time cost": outputs[key]["time_cost"]
        }
        # row_score.append(row)
        rows.append(list(row.values()))
        save_row = {
            "Sample ID": label or "unknown",
            "Predicted Country": pred_country or "unknown",
            "Country Explanation": country_explanation or "unknown",
            "Predicted Sample Type": pred_sample or "unknown",
            "Sample Type Explanation": sample_explanation or "unknown",
            "Sources": "\n".join(outputs[key]["source"]) or "No Links",
            "Query_cost": outputs[key]["query_cost"] or "",
            "Time cost": outputs[key]["time_cost"] or "",
            "file_chunk": outputs[key]["file_chunk"] or "",
            "file_all_output": outputs[key]["file_all_output"] or ""
        }
        # row_score.append(row)
        save_rows.append(list(save_row.values()))
    # location_counts, (final_location, count) = compute_final_suggested_location(row_score)
    # summary_lines = [f"### 🧭 Location Summary:\n"]
    # summary_lines += [f"- **{loc}**: {cnt} times" for loc, cnt in location_counts.items()]
    # summary_lines.append(f"\n**Final Suggested Location:** 🗺️ **{final_location}** (mentioned {count} times)")
    # summary = "\n".join(summary_lines)
    # save the new running sample to the known output store
    # try:
    #     df_new = pd.DataFrame(save_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Query_cost", "Time cost"])
    #     if os.path.exists(KNOWN_OUTPUT_PATH):
    #         df_old = pd.read_excel(KNOWN_OUTPUT_PATH)
    #         df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
    #     else:
    #         df_combined = df_new
    #     df_combined.to_excel(KNOWN_OUTPUT_PATH, index=False)
    # except Exception as e:
    #     print(f"⚠️ Failed to save known output: {e}")
    # try:
    #     df_new = pd.DataFrame(save_rows, columns=[
    #         "Sample ID", "Predicted Country", "Country Explanation",
    #         "Predicted Sample Type", "Sample Type Explanation",
    #         "Sources", "Query_cost", "Time cost"
    #     ])
    #     # ✅ Google Sheets API setup
    #     creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
    #     scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
    #     creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
    #     client = gspread.authorize(creds)
    #     # ✅ Open the known_samples sheet
    #     spreadsheet = client.open("known_samples")  # Replace with your sheet name
    #     sheet = spreadsheet.sheet1
    #     # ✅ Read old data
    #     existing_data = sheet.get_all_values()
    #     if existing_data:
    #         df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
    #     else:
    #         df_old = pd.DataFrame(columns=df_new.columns)
    #     # ✅ Combine and remove duplicates
    #     df_combined = pd.concat([df_old, df_new], ignore_index=True).drop_duplicates(subset="Sample ID")
    #     # ✅ Clear and write back
    #     sheet.clear()
    #     sheet.update([df_combined.columns.values.tolist()] + df_combined.values.tolist())
    # except Exception as e:
    #     print(f"⚠️ Failed to save known output to Google Sheets: {e}")
    try:
        # Prepare as DataFrame
        df_new = pd.DataFrame(save_rows, columns=[
            "Sample ID", "Predicted Country", "Country Explanation",
            "Predicted Sample Type", "Sample Type Explanation",
            "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
        ])
        # ✅ Setup Google Sheets
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("known_samples")
        sheet = spreadsheet.sheet1
        # ✅ Read existing data
        existing_data = sheet.get_all_values()
        if existing_data:
            df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
        else:
            df_old = pd.DataFrame(columns=[
                "Sample ID", "Actual_country", "Actual_sample_type", "Country Explanation",
                "Match_country", "Match_sample_type", "Predicted Country", "Predicted Sample Type",
                "Query_cost", "Sample Type Explanation", "Sources", "Time cost", "file_chunk", "file_all_output"
            ])
        # ✅ Index by Sample ID
        df_old.set_index("Sample ID", inplace=True)
        df_new.set_index("Sample ID", inplace=True)
        # ✅ Update only matching fields
        update_columns = [
            "Predicted Country", "Predicted Sample Type", "Country Explanation",
            "Sample Type Explanation", "Sources", "Query_cost", "Time cost", "file_chunk", "file_all_output"
        ]
        for idx, row in df_new.iterrows():
            if idx not in df_old.index:
                df_old.loc[idx] = ""  # new row, fill empty first
            for col in update_columns:
                if pd.notna(row[col]) and row[col] != "":
                    df_old.at[idx, col] = row[col]
        # ✅ Reset and write back
        df_old.reset_index(inplace=True)
        sheet.clear()
        sheet.update([df_old.columns.values.tolist()] + df_old.values.tolist())
        print("✅ Match results saved to known_samples.")
    except Exception as e:
        print(f"❌ Failed to update known_samples: {e}")
    return rows  # , summary, labelAncient_Modern, explain_label
# save the batch output in an Excel file
# def save_to_excel(all_rows, summary_text, flag_text, filename):
#     with pd.ExcelWriter(filename) as writer:
#         # Save table
#         df_new = pd.DataFrame(all_rows, columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
#         df.to_excel(writer, sheet_name="Detailed Results", index=False)
#     try:
#         df_old = pd.read_excel(filename)
#     except:
#         df_old = pd.DataFrame([[]], columns=["Sample ID", "Predicted Country", "Country Explanation", "Predicted Sample Type", "Sample Type Explanation", "Sources", "Time cost"])
#     df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
#     # if os.path.exists(filename):
#     #     df_old = pd.read_excel(filename)
#     #     df_combined = pd.concat([df_old, df_new]).drop_duplicates(subset="Sample ID")
#     # else:
#     #     df_combined = df_new
#     df_combined.to_excel(filename, index=False)
#     # # Save summary
#     # summary_df = pd.DataFrame({"Summary": [summary_text]})
#     # summary_df.to_excel(writer, sheet_name="Summary", index=False)
#     # # Save flag
#     # flag_df = pd.DataFrame({"Flag": [flag_text]})
#     # flag_df.to_excel(writer, sheet_name="Ancient_Modern_Flag", index=False)
# def save_to_excel(all_rows, summary_text, flag_text, filename):
#     df_new = pd.DataFrame(all_rows, columns=[
#         "Sample ID", "Predicted Country", "Country Explanation",
#         "Predicted Sample Type", "Sample Type Explanation",
#         "Sources", "Time cost"
#     ])
#     try:
#         if os.path.exists(filename):
#             df_old = pd.read_excel(filename)
#         else:
#             df_old = pd.DataFrame(columns=df_new.columns)
#     except Exception as e:
#         print(f"⚠️ Warning reading old Excel file: {e}")
#         df_old = pd.DataFrame(columns=df_new.columns)
#     # df_combined = pd.concat([df_new, df_old], ignore_index=True).drop_duplicates(subset="Sample ID", keep="first")
#     df_old.set_index("Sample ID", inplace=True)
#     df_new.set_index("Sample ID", inplace=True)
#     df_old.update(df_new)  # <-- update matching rows in df_old with df_new content
#     df_combined = df_old.reset_index()
#     try:
#         df_combined.to_excel(filename, index=False)
#     except Exception as e:
#         print(f"❌ Failed to write Excel file {filename}: {e}")
def save_to_excel(all_rows, summary_text, flag_text, filename, is_resume=False):
    df_new = pd.DataFrame(all_rows, columns=[
        "Sample ID", "Predicted Country", "Country Explanation",
        "Predicted Sample Type", "Sample Type Explanation",
        "Sources", "Time cost"
    ])
    if is_resume and os.path.exists(filename):
        try:
            df_old = pd.read_excel(filename)
        except Exception as e:
            print(f"⚠️ Warning reading old Excel file: {e}")
            df_old = pd.DataFrame(columns=df_new.columns)
        # Set index and update existing rows
        df_old.set_index("Sample ID", inplace=True)
        df_new.set_index("Sample ID", inplace=True)
        df_old.update(df_new)
        df_combined = df_old.reset_index()
    else:
        # If not resuming or the file doesn't exist, just use the new rows
        df_combined = df_new
    try:
        df_combined.to_excel(filename, index=False)
    except Exception as e:
        print(f"❌ Failed to write Excel file {filename}: {e}")
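# Illustrative call (the file name is hypothetical): when resuming, rows already present
# in the workbook are updated in place by "Sample ID"; otherwise the file is overwritten.
#   save_to_excel(rows, "", "", "batch_output_live.xlsx", is_resume=True)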
# save the batch output in a JSON file
def save_to_json(all_rows, summary_text, flag_text, filename):
    output_dict = {
        "Detailed_Results": all_rows  # <-- make sure this is a plain list, not a DataFrame
        # "Summary_Text": summary_text,
        # "Ancient_Modern_Flag": flag_text
    }
    # If all_rows is a DataFrame, convert it
    if isinstance(all_rows, pd.DataFrame):
        output_dict["Detailed_Results"] = all_rows.to_dict(orient="records")
    with open(filename, "w") as external_file:
        json.dump(output_dict, external_file, indent=2)
# save the batch output in a text file
def save_to_txt(all_rows, summary_text, flag_text, filename):
    if isinstance(all_rows, pd.DataFrame):
        detailed_results = all_rows.to_dict(orient="records")
    else:
        detailed_results = list(all_rows)  # assumes an iterable of dict-like rows
    output = ""
    # output += ",".join(list(detailed_results[0].keys())) + "\n\n"
    output += ",".join([str(k) for k in detailed_results[0].keys()]) + "\n\n"
    for r in detailed_results:
        output += ",".join([str(v) for v in r.values()]) + "\n\n"
    with open(filename, "w") as f:
        f.write("=== Detailed Results ===\n")
        f.write(output + "\n")
        # f.write("\n=== Summary ===\n")
        # f.write(summary_text + "\n")
        # f.write("\n=== Ancient/Modern Flag ===\n")
        # f.write(flag_text + "\n")
def save_batch_output(all_rows, output_type, summary_text=None, flag_text=None):
    tmp_dir = tempfile.mkdtemp()
    # html_table = all_rows.value  # assuming this is stored somewhere
    # Parse back to DataFrame
    # all_rows = pd.read_html(all_rows)[0]  # [0] because read_html returns a list
    all_rows = pd.read_html(StringIO(all_rows))[0]
    print(all_rows)
    if output_type == "Excel":
        file_path = f"{tmp_dir}/batch_output.xlsx"
        save_to_excel(all_rows, summary_text, flag_text, file_path)
    elif output_type == "JSON":
        file_path = f"{tmp_dir}/batch_output.json"
        save_to_json(all_rows, summary_text, flag_text, file_path)
        print("Done with JSON")
    elif output_type == "TXT":
        file_path = f"{tmp_dir}/batch_output.txt"
        save_to_txt(all_rows, summary_text, flag_text, file_path)
    else:
        return gr.update(visible=False)  # invalid option
    return gr.update(value=file_path, visible=True)
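# Hedged sketch (component names below are illustrative, not this app's actual layout)
# of how save_batch_output could be wired to a download button in a Gradio Blocks UI:
#   download_btn.click(
#       fn=save_batch_output,
#       inputs=[results_html, output_type_dropdown],  # HTML table string + "Excel"/"JSON"/"TXT"
#       outputs=[download_file],                      # gr.File made visible with the saved path
#   )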
# save cost by checking the known outputs
# def check_known_output(accession):
#     if not os.path.exists(KNOWN_OUTPUT_PATH):
#         return None
#     try:
#         df = pd.read_excel(KNOWN_OUTPUT_PATH)
#         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
#         if match:
#             accession = match.group(0)
#         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
#         if not matched.empty:
#             return matched.iloc[0].to_dict()  # Return the cached row
#     except Exception as e:
#         print(f"⚠️ Failed to load known samples: {e}")
#     return None
# def check_known_output(accession):
#     try:
#         # ✅ Load credentials from Hugging Face secret
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         # ✅ Open the known_samples sheet
#         spreadsheet = client.open("known_samples")  # Replace with your sheet name
#         sheet = spreadsheet.sheet1
#         # ✅ Read all rows
#         data = sheet.get_all_values()
#         if not data:
#             return None
#         df = pd.DataFrame(data[1:], columns=data[0])  # Skip header row
#         # ✅ Normalize accession pattern
#         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
#         if match:
#             accession = match.group(0)
#         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
#         if not matched.empty:
#             return matched.iloc[0].to_dict()
#     except Exception as e:
#         print(f"⚠️ Failed to load known samples from Google Sheets: {e}")
#     return None
# def check_known_output(accession):
#     print("inside check known output function")
#     try:
#         # ✅ Load credentials from Hugging Face secret
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         spreadsheet = client.open("known_samples")
#         sheet = spreadsheet.sheet1
#         data = sheet.get_all_values()
#         if not data:
#             print("⚠️ Google Sheet 'known_samples' is empty.")
#             return None
#         df = pd.DataFrame(data[1:], columns=data[0])
#         if "Sample ID" not in df.columns:
#             print("❌ Column 'Sample ID' not found in Google Sheet.")
#             return None
#         match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
#         if match:
#             accession = match.group(0)
#         matched = df[df["Sample ID"].str.contains(accession, case=False, na=False)]
#         if not matched.empty:
#             # return matched.iloc[0].to_dict()
#             row = matched.iloc[0]
#             country = row.get("Predicted Country", "").strip().lower()
#             sample_type = row.get("Predicted Sample Type", "").strip().lower()
#             if country and country != "unknown" and sample_type and sample_type != "unknown":
#                 return row.to_dict()
#             else:
#                 print(f"⚠️ Accession {accession} found but country/sample_type is unknown or empty.")
#                 return None
#         else:
#             print(f"🔍 Accession {accession} not found in known_samples.")
#             return None
#     except Exception as e:
#         import traceback
#         print("❌ Exception occurred during check_known_output:")
#         traceback.print_exc()
#         return None
import os
import re
import json
import time
import gspread
import pandas as pd
from oauth2client.service_account import ServiceAccountCredentials
from gspread.exceptions import APIError

# --- Global cache ---
_known_samples_cache = None

def load_known_samples():
    """Load the Google Sheet 'known_samples' into a Pandas DataFrame and cache it."""
    global _known_samples_cache
    try:
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = [
            'https://spreadsheets.google.com/feeds',
            'https://www.googleapis.com/auth/drive'
        ]
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        sheet = client.open("known_samples").sheet1
        data = sheet.get_all_values()
        if not data:
            print("⚠️ Google Sheet 'known_samples' is empty.")
            _known_samples_cache = pd.DataFrame()
        else:
            _known_samples_cache = pd.DataFrame(data[1:], columns=data[0])
            print(f"✅ Cached {_known_samples_cache.shape[0]} rows from known_samples")
    except APIError as e:
        print(f"❌ APIError while loading known_samples: {e}")
        _known_samples_cache = pd.DataFrame()
    except Exception as e:
        import traceback
        print("❌ Exception occurred while loading known_samples:")
        traceback.print_exc()
        _known_samples_cache = pd.DataFrame()
def check_known_output(accession):
    """Check if an accession exists in the cached 'known_samples' sheet."""
    global _known_samples_cache
    print("inside check known output function")
    try:
        # Load cache if not already loaded
        if _known_samples_cache is None:
            load_known_samples()
        if _known_samples_cache.empty:
            print("⚠️ No cached data available.")
            return None
        # Extract proper accession format (e.g. AB12345)
        match = re.search(r"\b[A-Z]{2,4}\d{4,}", accession)
        if match:
            accession = match.group(0)
        matched = _known_samples_cache[
            _known_samples_cache["Sample ID"].str.contains(accession, case=False, na=False)
        ]
        if not matched.empty:
            row = matched.iloc[0]
            country = row.get("Predicted Country", "").strip().lower()
            sample_type = row.get("Predicted Sample Type", "").strip().lower()
            if country and country != "unknown" and sample_type and sample_type != "unknown":
                print(f"🎯 Found {accession} in cache")
                return row.to_dict()
            else:
                print(f"⚠️ Accession {accession} found but country/sample_type unknown or empty.")
                return None
        else:
            print(f"🔍 Accession {accession} not found in cache.")
            return None
    except Exception as e:
        import traceback
        print("❌ Exception occurred during check_known_output:")
        traceback.print_exc()
        return None
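# Illustrative usage (the accessions are hypothetical): the first lookup warms the
# module-level cache; later lookups reuse it without hitting the Sheets API again.
#   check_known_output("KU131308")   # loads 'known_samples' once, then searches it
#   check_known_output("KU131309")   # served from _known_samples_cache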
def hash_user_id(user_input):
    return hashlib.sha256(user_input.encode()).hexdigest()
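# Illustrative example (the email is hypothetical): user identifiers are stored only as
# SHA-256 hex digests, never as plain emails.
#   hash_user_id("user@example.com")  # -> 64-character hex string, stable across runs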
# ✅ Load and save usage count
# def load_user_usage():
#     if not os.path.exists(USER_USAGE_TRACK_FILE):
#         return {}
#     try:
#         with open(USER_USAGE_TRACK_FILE, "r") as f:
#             content = f.read().strip()
#             if not content:
#                 return {}  # file is empty
#             return json.loads(content)
#     except (json.JSONDecodeError, ValueError):
#         print("⚠️ Warning: user_usage.json is corrupted or invalid. Resetting.")
#         return {}  # fallback to empty dict
# def load_user_usage():
#     try:
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         sheet = client.open("user_usage_log").sheet1
#         data = sheet.get_all_records()  # Assumes columns: email, usage_count
#         usage = {}
#         for row in data:
#             email = row.get("email", "").strip().lower()
#             count = int(row.get("usage_count", 0))
#             if email:
#                 usage[email] = count
#         return usage
#     except Exception as e:
#         print(f"⚠️ Failed to load user usage from Google Sheets: {e}")
#         return {}
# def load_user_usage():
#     try:
#         parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
#         iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
#         found = pipeline.find_drive_file("user_usage_log.json", parent_id=iterate3_id)
#         if not found:
#             return {}  # not found, start fresh
#         # file_id = found[0]["id"]
#         file_id = found
#         content = pipeline.download_drive_file_content(file_id)
#         return json.loads(content.strip()) if content.strip() else {}
#     except Exception as e:
#         print(f"⚠️ Failed to load user_usage_log.json from Google Drive: {e}")
#         return {}
def load_user_usage():
    try:
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        sheet = client.open("user_usage_log").sheet1
        data = sheet.get_all_values()
        print("data: ", data)
        if not data or len(data) < 2:
            print("⚠️ Sheet is empty or missing rows.")
            return {}, {}
        print("🧪 Raw header row from sheet:", data[0])
        print("🧪 Character codes in each header:")
        for h in data[0]:
            print([ord(c) for c in h])
        headers = [h.strip().lower() for h in data[0]]
        if "email" not in headers or "usage_count" not in headers:
            print("❌ Header format incorrect. Must have 'email' and 'usage_count'.")
            return {}, {}
        permitted_index = headers.index("permitted_samples") if "permitted_samples" in headers else None
        df = pd.DataFrame(data[1:], columns=headers)
        usage = {}
        permitted = {}
        for _, row in df.iterrows():
            email = row.get("email", "").strip().lower()
            try:
                # count = int(row.get("usage_count", 0))
                try:
                    count = int(float(row.get("usage_count", 0)))
                except Exception:
                    print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
                    count = 0
                if email:
                    usage[email] = count
                    if permitted_index is not None:
                        try:
                            permitted_count = int(float(row.get("permitted_samples", 50)))
                            permitted[email] = permitted_count
                        except:
                            permitted[email] = 50
            except ValueError:
                print(f"⚠️ Invalid usage_count for {email}: {row.get('usage_count')}")
        return usage, permitted
    except Exception as e:
        print(f"❌ Error in load_user_usage: {e}")
        return {}, {}
# def save_user_usage(usage):
#     with open(USER_USAGE_TRACK_FILE, "w") as f:
#         json.dump(usage, f, indent=2)
# def save_user_usage(usage_dict):
#     try:
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         sheet = client.open("user_usage_log").sheet1
#         sheet.clear()  # clear old contents first
#         # Write header + rows
#         rows = [["email", "usage_count"]] + [[email, count] for email, count in usage_dict.items()]
#         sheet.update(rows)
#     except Exception as e:
#         print(f"❌ Failed to save user usage to Google Sheets: {e}")
# def save_user_usage(usage_dict):
#     try:
#         parent_id = pipeline.get_or_create_drive_folder("mtDNA-Location-Classifier")
#         iterate3_id = pipeline.get_or_create_drive_folder("iterate3", parent_id=parent_id)
#         import tempfile
#         tmp_path = os.path.join(tempfile.gettempdir(), "user_usage_log.json")
#         print("💾 Saving this usage dict:", usage_dict)
#         with open(tmp_path, "w") as f:
#             json.dump(usage_dict, f, indent=2)
#         pipeline.upload_file_to_drive(tmp_path, "user_usage_log.json", iterate3_id)
#     except Exception as e:
#         print(f"❌ Failed to save user_usage_log.json to Google Drive: {e}")
# def save_user_usage(usage_dict):
#     try:
#         creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
#         scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
#         creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
#         client = gspread.authorize(creds)
#         spreadsheet = client.open("user_usage_log")
#         sheet = spreadsheet.sheet1
#         # Step 1: Convert new usage to DataFrame
#         df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
#         df_new["email"] = df_new["email"].str.strip().str.lower()
#         # Step 2: Load existing data
#         existing_data = sheet.get_all_values()
#         print("🧪 Sheet existing_data:", existing_data)
#         # Try to load old data
#         if existing_data and len(existing_data[0]) >= 1:
#             df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
#             # Fix missing columns
#             if "email" not in df_old.columns:
#                 df_old["email"] = ""
#             if "usage_count" not in df_old.columns:
#                 df_old["usage_count"] = 0
#             df_old["email"] = df_old["email"].str.strip().str.lower()
#             df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
#         else:
#             df_old = pd.DataFrame(columns=["email", "usage_count"])
#         # Step 3: Merge
#         df_combined = pd.concat([df_old, df_new], ignore_index=True)
#         df_combined = df_combined.groupby("email", as_index=False).sum()
#         # Step 4: Write back
#         sheet.clear()
#         sheet.update([df_combined.columns.tolist()] + df_combined.astype(str).values.tolist())
#         print("✅ Saved user usage to user_usage_log sheet.")
#     except Exception as e:
#         print(f"❌ Failed to save user usage to Google Sheets: {e}")
def save_user_usage(usage_dict):
    try:
        creds_dict = json.loads(os.environ["GCP_CREDS_JSON"])
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
        client = gspread.authorize(creds)
        spreadsheet = client.open("user_usage_log")
        sheet = spreadsheet.sheet1
        # Build new df
        df_new = pd.DataFrame(list(usage_dict.items()), columns=["email", "usage_count"])
        df_new["email"] = df_new["email"].str.strip().str.lower()
        df_new["usage_count"] = pd.to_numeric(df_new["usage_count"], errors="coerce").fillna(0).astype(int)
        # Read existing data
        existing_data = sheet.get_all_values()
        if existing_data and len(existing_data[0]) >= 2:
            df_old = pd.DataFrame(existing_data[1:], columns=existing_data[0])
            df_old["email"] = df_old["email"].str.strip().str.lower()
            df_old["usage_count"] = pd.to_numeric(df_old["usage_count"], errors="coerce").fillna(0).astype(int)
        else:
            df_old = pd.DataFrame(columns=["email", "usage_count"])
        # ✅ Overwrite specific emails only
        df_old = df_old.set_index("email")
        for email, count in usage_dict.items():
            email = email.strip().lower()
            df_old.loc[email, "usage_count"] = count
        df_old = df_old.reset_index()
        # Save
        sheet.clear()
        sheet.update([df_old.columns.tolist()] + df_old.astype(str).values.tolist())
        print("✅ Saved user usage to user_usage_log sheet.")
    except Exception as e:
        print(f"❌ Failed to save user usage to Google Sheets: {e}")
# def increment_usage(user_id, num_samples=1):
#     usage = load_user_usage()
#     if user_id not in usage:
#         usage[user_id] = 0
#     usage[user_id] += num_samples
#     save_user_usage(usage)
#     return usage[user_id]
# def increment_usage(email: str, count: int):
#     usage = load_user_usage()
#     email_key = email.strip().lower()
#     usage[email_key] = usage.get(email_key, 0) + count
#     save_user_usage(usage)
#     return usage[email_key]
def increment_usage(email: str, count: int = 1):
    usage, permitted = load_user_usage()
    email_key = email.strip().lower()
    # usage[email_key] = usage.get(email_key, 0) + count
    current = usage.get(email_key, 0)
    new_value = current + count
    max_allowed = permitted.get(email_key) or 50
    usage[email_key] = max(current, new_value)  # ✅ Prevent overwrite with a lower value
    print(f"🧪 increment_usage saving: {email_key=} {current=} + {count=} => {usage[email_key]=}")
    print("max allow is: ", max_allowed)
    save_user_usage(usage)
    return usage[email_key], max_allowed
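# Illustrative usage (the email is hypothetical): the caller receives both the updated
# total and the per-user cap from the 'permitted_samples' column (defaulting to 50).
#   used, max_allowed = increment_usage("user@example.com", 3)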
# run the batch
def summarize_batch(file=None, raw_text="", resume_file=None, user_email="",
                    stop_flag=None, output_file_path=None,
                    limited_acc=50, yield_callback=None):
    if user_email:
        limited_acc += 10
    accessions, invalid_accessions, error = extract_accessions_from_input(file, raw_text)
    if error:
        # return [], "", "", f"Error: {error}"
        return [], f"Error: {error}", 0, "", ""
    if resume_file:
        accessions = get_incomplete_accessions(resume_file)
    tmp_dir = tempfile.mkdtemp()
    if not output_file_path:
        if resume_file:
            output_file_path = os.path.join(tmp_dir, resume_file)
        else:
            output_file_path = os.path.join(tmp_dir, "batch_output_live.xlsx")
    all_rows = []
    # all_summaries = []
    # all_flags = []
    progress_lines = []
    warning = ""
    if len(accessions) > limited_acc:
        accessions = accessions[:limited_acc]
        warning = f"You submitted more than {limited_acc} accessions; only the first {limited_acc} will be processed."
    for i, acc in enumerate(accessions):
        if stop_flag and stop_flag.value:
            line = f"🛑 Stopped at {acc} ({i+1}/{len(accessions)})"
            progress_lines.append(line)
            if yield_callback:
                yield_callback(line)
            print("🛑 User requested stop.")
            break
        print(f"[{i+1}/{len(accessions)}] Processing {acc}")
        try:
            # rows, summary, label, explain = summarize_results(acc)
            rows = summarize_results(acc)
            all_rows.extend(rows)
            # all_summaries.append(f"**{acc}**\n{summary}")
            # all_flags.append(f"**{acc}**\n### 🏺 Ancient/Modern Flag\n**{label}**\n\n_Explanation:_ {explain}")
            # save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path)
            save_to_excel(all_rows, summary_text="", flag_text="", filename=output_file_path, is_resume=bool(resume_file))
            line = f"✅ Processed {acc} ({i+1}/{len(accessions)})"
            progress_lines.append(line)
            if yield_callback:
                yield_callback(f"✅ Processed {acc} ({i+1}/{len(accessions)})")
        except Exception as e:
            print(f"❌ Failed to process {acc}: {e}")
            continue
            # all_summaries.append(f"**{acc}**: Failed - {e}")
            # progress_lines.append(f"❌ Processed {acc} ({i+1}/{len(accessions)})")
        limited_acc -= 1
    """for row in all_rows:
        source_column = row[2]  # Assuming the "Source" is in the 3rd column (index 2)
        if source_column.startswith("http"):  # Check if the source is a URL
            # Wrap it with HTML anchor tags to make it clickable
            row[2] = f'<a href="{source_column}" target="_blank" style="color: blue; text-decoration: underline;">{source_column}</a>'"""
    if not warning:
        warning = f"You only have {limited_acc} accessions left."
    if user_email.strip():
        user_hash = hash_user_id(user_email)
        total_queries, _ = increment_usage(user_hash, len(all_rows))
    else:
        total_queries = 0
    if yield_callback:
        yield_callback("✅ Finished!")
    # summary_text = "\n\n---\n\n".join(all_summaries)
    # flag_text = "\n\n---\n\n".join(all_flags)
    # return all_rows, summary_text, flag_text, gr.update(visible=True), gr.update(visible=False)
    # return all_rows, gr.update(visible=True), gr.update(visible=False)
    return all_rows, output_file_path, total_queries, "\n".join(progress_lines), warning