Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import joblib | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM | |
| from Bio import SeqIO | |
| import io | |
| from sklearn.metrics import silhouette_score, silhouette_samples | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import os | |
# --- Configuration --------------------------------------------------------
MODEL_NAME = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
HDBSCAN_MODEL_PATH = "hdbscan_model.pkl"
MAX_LENGTH = 20  # tokenizer max_length; longer inputs are truncated
PLOTS_DIR = "plots"

os.makedirs(PLOTS_DIR, exist_ok=True)

# --- Transformer (loaded once at import time, inference mode) --------------
print("Loading Transformer...")
# NOTE(review): trust_remote_code=True executes Python shipped with the model
# repository; pin a revision if the Space must be reproducible/safe.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()
print("Transformer loaded.")

# --- Pre-fitted clusterer -------------------------------------------------
print("Loading HDBSCAN...")
# NOTE(review): joblib.load unpickles arbitrary objects; only ship trusted files.
clusterer = joblib.load(HDBSCAN_MODEL_PATH)
print("HDBSCAN loaded.")
def seq_to_kmers(seq, k=6):
    """Split a sequence into overlapping k-mers separated by single spaces.

    The sequence is upper-cased first, then every window of length ``k``
    (advancing one character at a time) is emitted. For a positive ``k``,
    a sequence shorter than ``k`` yields the empty string.

    >>> seq_to_kmers("acgtacg")
    'ACGTAC CGTACG'
    """
    upper = seq.upper()
    last_start = len(upper) - k
    windows = (upper[start:start + k] for start in range(last_start + 1))
    return " ".join(windows)
def _parse_fasta(fasta_bytes):
    """Decode uploaded FASTA bytes and return parallel lists ``(ids, sequences)``."""
    # Undecodable bytes are dropped rather than failing the whole upload.
    fasta_str = fasta_bytes.decode("utf-8", errors="ignore")
    ids, sequences = [], []
    for record in SeqIO.parse(io.StringIO(fasta_str), "fasta"):
        ids.append(record.id)
        sequences.append(str(record.seq))
    return ids, sequences


def _masked_mean(last_hidden, attention_mask):
    """Average hidden states over real tokens only, ignoring padding.

    Args:
        last_hidden: Float tensor (batch, seq_len, dim) of token embeddings.
        attention_mask: Tensor (batch, seq_len) with 1 for real tokens and
            0 for padding.

    Returns:
        Tensor (batch, dim). A row with no real tokens comes back as zeros.
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)  # avoid 0/0 on an all-padding row
    return summed / counts


def _embed_sequences(sequences):
    """Embed each sequence with the transformer; returns an (n_sequences, dim) array."""
    batch_kmers = [seq_to_kmers(s) for s in sequences]
    # truncation=True: only the first MAX_LENGTH tokens of each sequence
    # contribute to its embedding.
    inputs = tokenizer(
        batch_kmers, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    if "attention_mask" not in inputs:
        # Derive the mask from pad tokens if the tokenizer did not return one.
        inputs["attention_mask"] = (inputs["input_ids"] != tokenizer.pad_token_id).long()
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        # Pool over real tokens only: a plain .mean(dim=1) also averages the
        # padding positions, so a sequence's embedding would change with the
        # length of the longest sequence it happened to be uploaded with.
        pooled = _masked_mean(outputs.hidden_states[-1], inputs["attention_mask"])
    return pooled.cpu().numpy()


def _membership_strengths(fitted, labels):
    """Per-point cluster-membership strength from a fitted clusterer.

    Uses ``fitted.probabilities_`` when present and aligned with ``labels``
    (both ``hdbscan.HDBSCAN`` and ``sklearn.cluster.HDBSCAN`` set it after
    fitting); otherwise falls back to 1.0 for clustered points and 0.0 for
    noise (label -1).
    """
    probabilities = getattr(fitted, "probabilities_", None)
    if probabilities is not None and len(probabilities) == len(labels):
        return np.asarray(probabilities, dtype=float)
    return np.where(np.asarray(labels) == -1, 0.0, 1.0)


def _overall_silhouette(embeddings, labels):
    """Silhouette score over non-noise points, or 0.0 where it is undefined."""
    labels = np.asarray(labels)
    valid = labels != -1
    n_clusters = np.unique(labels[valid]).size
    # silhouette_score needs 2 <= n_clusters <= n_valid - 1; outside that it raises.
    if not 1 < n_clusters < int(valid.sum()):
        return 0.0
    return float(silhouette_score(embeddings[valid], labels[valid]))


def _response(silhouette, results):
    """Build the ``(report, scatter_path, heatmap_path)`` tuple returned to the UI.

    The images are whatever is already saved under PLOTS_DIR; they are not
    regenerated from this upload. A missing file is returned as None so the
    UI is never handed a path that does not exist.
    """
    scatter, heatmap = (
        path if os.path.isfile(path) else None
        for path in (
            os.path.join(PLOTS_DIR, "scatter.png"),
            os.path.join(PLOTS_DIR, "heatmap.png"),
        )
    )
    report = {"overall_silhouette": round(float(silhouette), 3), "results": results}
    return report, scatter, heatmap


def analyze_fasta(fasta_bytes):
    """Cluster the sequences of an uploaded FASTA file.

    Args:
        fasta_bytes: Raw bytes of the uploaded file (``gr.File(type="binary")``),
            or None when nothing was uploaded.

    Returns:
        A 3-tuple ``(report, scatter_path, heatmap_path)``. ``report`` holds
        ``"overall_silhouette"`` and a ``"results"`` list with one entry per
        sequence (``id``, ``cluster``, ``confidence`` and, for noise points,
        a ``note``). The paths point at pre-rendered images under PLOTS_DIR,
        or are None when the file is missing.

    Any exception is logged and reported as a single fallback result instead
    of being raised, so the UI always receives a value.
    """
    import traceback  # local: keeps this block self-contained

    try:
        ids, sequences = _parse_fasta(fasta_bytes) if fasta_bytes else ([], [])
        if not sequences:
            return _response(
                0, [{"id": "N/A", "cluster": -1, "confidence": 0, "note": "No sequences found"}]
            )

        embeddings = _embed_sequences(sequences)
        # NOTE(review): fit_predict re-fits the loaded clusterer on this upload
        # alone, discarding the clusters it was saved with, so labels are not
        # comparable across uploads (very small uploads may fail to fit). If the
        # pickle is an `hdbscan.HDBSCAN` built with prediction_data=True,
        # `hdbscan.approximate_predict(clusterer, embeddings)` would assign points
        # to the saved clusters instead. This also mutates the shared object.
        labels = np.asarray(clusterer.fit_predict(embeddings))
        strengths = _membership_strengths(clusterer, labels)
        silhouette_avg = _overall_silhouette(embeddings, labels)

        results = []
        for seq_id, label, strength in zip(ids, labels, strengths):
            result = {
                "id": seq_id,
                "cluster": int(label),
                "confidence": round(float(strength), 3),
            }
            if label == -1:
                result["note"] = "Potential novel/unknown sequence"
            results.append(result)
        return _response(silhouette_avg, results)
    except Exception as e:
        # UI boundary: log the full traceback, then degrade to a fallback row.
        traceback.print_exc()
        return _response(
            0, [{"id": "N/A", "cluster": -1, "confidence": 0, "note": f"Fallback: {str(e)}"}]
        )
# Gradio UI: one FASTA upload in; JSON report plus scatter and heatmap images out.
demo = gr.Interface(
    fn=analyze_fasta,
    # analyze_fasta parses the bytes as FASTA regardless of the file name,
    # so accept the common FASTA extensions, not only ".fasta".
    inputs=gr.File(file_types=[".fasta", ".fa", ".fna", ".fas"], type="binary"),
    outputs=[
        gr.JSON(label="Clustering results"),
        gr.Image(label="Scatter plot"),
        gr.Image(label="Heatmap"),
    ],
    title="DNA Clustering Analyzer",
    # The original text showed a mis-encoded character ("β") in place of the arrow.
    description="Upload a FASTA file → Get clustering results + scatter plot + heatmap.",
)

if __name__ == "__main__":
    demo.launch()