import gradio as gr
import torch
import joblib
import numpy as np
from transformers import AutoTokenizer, AutoModelForMaskedLM
from Bio import SeqIO
import io
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
import seaborn as sns
import os
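# Configuration: embedding model, path to the pickled HDBSCAN clusterer,
# token cap per sequence, and the directory holding the app's plot images.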
MODEL_NAME = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"
HDBSCAN_MODEL_PATH = "hdbscan_model.pkl"
MAX_LENGTH = 20
PLOTS_DIR = "plots"
os.makedirs(PLOTS_DIR, exist_ok=True)
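# Note: the app returns pre-generated images (plots/scatter.png, plots/heatmap.png);
# this script never redraws them, so those files must already exist alongside the app.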
print("Loading Transformer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
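# Move the model to GPU when available and switch to inference mode (disables dropout).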
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
print("Transformer loaded.")
print("Loading HDBSCAN...")
clusterer = joblib.load(HDBSCAN_MODEL_PATH)
print("HDBSCAN loaded.")
def seq_to_kmers(seq, k=6):
seq = seq.upper()
return " ".join([seq[i:i+k] for i in range(len(seq)-k+1)])
def analyze_fasta(fasta_bytes):
try:
        # ✅ Decode bytes -> string -> StringIO (text mode)
fasta_str = fasta_bytes.decode("utf-8", errors="ignore")
fasta_io = io.StringIO(fasta_str)
sequences = []
ids = []
for record in SeqIO.parse(fasta_io, "fasta"):
ids.append(record.id)
sequences.append(str(record.seq))
if not sequences:
return {
"overall_silhouette": 0,
"results": [{"id": "N/A", "cluster": -1, "confidence": 0, "note": "No sequences found"}]
}, "plots/scatter.png", "plots/heatmap.png"
        # ✅ Embed each sequence with the transformer, then cluster the embeddings
batch_kmers = [seq_to_kmers(s) for s in sequences]
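        # Tokenize the k-mer strings as one padded batch; inputs longer than MAX_LENGTH tokens are truncated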
inputs = tokenizer(
batch_kmers, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH
)
inputs = {k: v.to(device) for k, v in inputs.items()}
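        # Forward pass without gradients; mean-pool the last hidden layer into one fixed-size embedding per sequence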
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True)
last_hidden = outputs.hidden_states[-1]
mean_embeddings = last_hidden.mean(dim=1).cpu().numpy()
        # fit_predict re-fits the loaded HDBSCAN clusterer on this batch of embeddings; -1 marks noise points
        labels = clusterer.fit_predict(mean_embeddings)
        # Placeholder confidence: 1.0 for clustered points, 0.0 for noise
        strengths = [1.0 if l != -1 else 0.0 for l in labels]
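        # Hedged alternative, not used here: if the pickled clusterer was fit with
        # prediction_data=True and the `hdbscan` package is importable, approximate_predict
        # assigns new points to the existing clusters without re-fitting and returns real
        # membership strengths:
        #   import hdbscan
        #   labels, strengths = hdbscan.approximate_predict(clusterer, mean_embeddings)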
        valid_mask = np.array(labels) != -1
        # Silhouette score is only defined when at least two clusters remain after dropping noise
        silhouette_avg = 0.0
        if np.unique(np.array(labels)[valid_mask]).shape[0] > 1:
            silhouette_avg = silhouette_score(mean_embeddings[valid_mask], np.array(labels)[valid_mask])
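        # One result row per input sequence; noise points (cluster -1) are flagged as potentially novel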
results = []
for i, seq_id in enumerate(ids):
result = {
"id": seq_id,
"cluster": int(labels[i]),
"confidence": round(float(strengths[i]), 3),
}
if labels[i] == -1:
result["note"] = "Potential novel/unknown sequence"
results.append(result)
return (
{"overall_silhouette": round(float(silhouette_avg), 3), "results": results},
"plots/scatter.png", # βœ… use existing saved scatter
"plots/heatmap.png" # βœ… use existing saved heatmap
)
except Exception as e:
return {
"overall_silhouette": 0,
"results": [{"id": "N/A", "cluster": -1, "confidence": 0, "note": f"Fallback: {str(e)}"}],
}, "plots/scatter.png", "plots/heatmap.png"
# Gradio UI
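# One binary file input -> JSON clustering results plus the two saved plot images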
demo = gr.Interface(
fn=analyze_fasta,
inputs=gr.File(file_types=[".fasta"], type="binary"),
outputs=[gr.JSON(), gr.Image(), gr.Image()],
title="DNA Clustering Analyzer",
description="Upload a FASTA file β†’ Get clustering results + scatter plot + heatmap."
)
if __name__ == "__main__":
demo.launch()