"""Gradio app that clusters DNA sequences from an uploaded FASTA file.

Each sequence is embedded with a Nucleotide Transformer model (mean of the
last hidden layer over real, non-padding tokens). The embeddings are then
labelled by the HDBSCAN object loaded from ``HDBSCAN_MODEL_PATH``. The UI
returns a JSON summary plus two pre-rendered plot images from ``PLOTS_DIR``.
"""

import io
import logging
import os

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from Bio import SeqIO
from sklearn.metrics import silhouette_samples, silhouette_score
from transformers import AutoModelForMaskedLM, AutoTokenizer

logger = logging.getLogger(__name__)

MODEL_NAME = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
HDBSCAN_MODEL_PATH = "hdbscan_model.pkl"
# NOTE(review): 20 tokens is very short. Longer inputs are truncated by the
# tokenizer before embedding. Confirm this limit is intentional.
MAX_LENGTH = 20
PLOTS_DIR = "plots"
SCATTER_PLOT_PATH = os.path.join(PLOTS_DIR, "scatter.png")
HEATMAP_PLOT_PATH = os.path.join(PLOTS_DIR, "heatmap.png")

os.makedirs(PLOTS_DIR, exist_ok=True)

print("Loading Transformer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
print("Transformer loaded.")

print("Loading HDBSCAN...")
# joblib.load unpickles arbitrary objects; only point this at a trusted file.
clusterer = joblib.load(HDBSCAN_MODEL_PATH)
print("HDBSCAN loaded.")


def seq_to_kmers(seq, k=6):
    """Return the overlapping k-mers of *seq* (upper-cased), joined by single spaces.

    A sequence shorter than *k* yields an empty string.
    """
    seq = seq.upper()
    return " ".join(seq[i:i + k] for i in range(len(seq) - k + 1))


def _plot_outputs():
    """Return the (scatter, heatmap) image paths, using None for any file that is missing.

    The plots are pre-rendered files. A missing plot is shown as an empty
    image rather than handing gr.Image a path that does not exist.
    """
    return tuple(
        path if os.path.isfile(path) else None
        for path in (SCATTER_PLOT_PATH, HEATMAP_PLOT_PATH)
    )


def _error_response(note):
    """Build the single-entry JSON payload used when no real results can be produced."""
    return {
        "overall_silhouette": 0,
        "results": [{"id": "N/A", "cluster": -1, "confidence": 0, "note": note}],
    }


def _parse_fasta(fasta_bytes):
    """Parse FASTA content into parallel lists of record ids and sequence strings.

    Accepts raw bytes or an already-decoded str. Bytes that are not valid
    UTF-8 are dropped rather than raising.
    """
    if isinstance(fasta_bytes, (bytes, bytearray)):
        text = fasta_bytes.decode("utf-8", errors="ignore")
    else:
        text = str(fasta_bytes)
    ids, sequences = [], []
    for record in SeqIO.parse(io.StringIO(text), "fasta"):
        ids.append(record.id)
        sequences.append(str(record.seq))
    return ids, sequences


def _embed(sequences):
    """Embed each sequence as the attention-masked mean of the model's last hidden layer.

    Returns a NumPy array with one row per input sequence.
    """
    batch_kmers = [seq_to_kmers(s) for s in sequences]
    # NOTE(review): sequences are fed as space-separated overlapping 6-mers.
    # Confirm this matches what the Nucleotide Transformer tokenizer expects.
    inputs = tokenizer(
        batch_kmers,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]

    mask = inputs.get("attention_mask")
    if mask is None:
        pooled = last_hidden.mean(dim=1)
    else:
        # Exclude padding positions. A plain mean over dim=1 would pull the
        # shorter sequences in a padded batch toward the padding embedding.
        mask = mask.unsqueeze(-1).to(last_hidden.dtype)
        pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    return pooled.cpu().numpy()


def _silhouette(embeddings, labels):
    """Return the mean silhouette score over non-noise points, or 0.0 when it is undefined.

    silhouette_score requires 2 <= n_clusters <= n_samples - 1. Fewer than
    two clusters, or every point in its own cluster, therefore yields 0.0.
    """
    valid = labels != -1
    valid_labels = labels[valid]
    n_clusters = np.unique(valid_labels).size
    if 1 < n_clusters < valid_labels.size:
        return float(silhouette_score(embeddings[valid], valid_labels))
    return 0.0


def analyze_fasta(fasta_bytes):
    """Cluster the sequences in an uploaded FASTA file.

    Args:
        fasta_bytes: FASTA file content as bytes (or str). None when nothing
            was uploaded.

    Returns:
        A 3-tuple ``(summary, scatter_path, heatmap_path)``.

        ``summary`` is a dict with ``overall_silhouette`` and a ``results``
        list of per-sequence dicts with keys ``id``, ``cluster``,
        ``confidence`` and an optional ``note``. A cluster of -1 marks an
        HDBSCAN noise point.

        The plot paths are None when the pre-rendered file is missing.

        Any exception is logged and converted into a single fallback result
        instead of propagating to the UI.
    """
    try:
        if fasta_bytes is None:
            return (_error_response("No file uploaded"), *_plot_outputs())

        ids, sequences = _parse_fasta(fasta_bytes)
        if not sequences:
            return (_error_response("No sequences found"), *_plot_outputs())

        embeddings = _embed(sequences)
        # NOTE(review): fit_predict re-fits the clusterer loaded from
        # HDBSCAN_MODEL_PATH on this upload alone. That discards the saved fit
        # and mutates the shared global on every request. If the pickle is an
        # `hdbscan.HDBSCAN` built with prediction_data=True, then
        # `hdbscan.approximate_predict` would label against the trained
        # clusters instead. Confirm before switching.
        labels = np.asarray(clusterer.fit_predict(embeddings))
        # Placeholder confidence: 1.0 for any clustered point, 0.0 for noise.
        strengths = np.where(labels != -1, 1.0, 0.0)
        silhouette_avg = _silhouette(embeddings, labels)

        results = []
        for seq_id, label, strength in zip(ids, labels, strengths):
            result = {
                "id": seq_id,
                "cluster": int(label),
                "confidence": round(float(strength), 3),
            }
            if label == -1:
                result["note"] = "Potential novel/unknown sequence"
            results.append(result)

        summary = {"overall_silhouette": round(silhouette_avg, 3), "results": results}
        return (summary, *_plot_outputs())
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs, but still
        # return a payload the JSON component can render.
        logger.exception("analyze_fasta failed")
        return (_error_response(f"Fallback: {str(e)}"), *_plot_outputs())


# Gradio UI
demo = gr.Interface(
    fn=analyze_fasta,
    inputs=gr.File(file_types=[".fasta", ".fa", ".fna"], type="binary"),
    outputs=[gr.JSON(), gr.Image(), gr.Image()],
    title="DNA Clustering Analyzer",
    description="Upload a FASTA file → Get clustering results + scatter plot + heatmap.",
)

if __name__ == "__main__":
    demo.launch()