jebin511 commited on
Commit
e1678bf
·
verified ·
1 Parent(s): af1e5ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -21
app.py CHANGED
@@ -1,55 +1,81 @@
1
- # app.py
2
  import gradio as gr
3
  import joblib
4
  import numpy as np
5
  from collections import Counter
6
  from typing import List
 
 
 
 
7
 
8
- # helper: k-mer extraction / vectorize (k=3)
9
  def kmer_counts(seq: str, k=3):
10
  seq = seq.strip().upper()
11
  counts = Counter()
12
  if len(seq) < k:
13
  return counts
14
- for i in range(len(seq)-k+1):
15
  counts[seq[i:i+k]] += 1
16
  return counts
17
 
18
  def vectorize_single(seq: str, vocab: List[str], k=3):
19
- x = np.zeros((1, len(vocab)), dtype=float)
20
  c = kmer_counts(seq, k)
21
- for j,kmer in enumerate(vocab):
22
- x[0,j] = c.get(kmer, 0)
23
- return x
 
 
 
 
 
 
 
 
 
24
 
25
- # load model+vocab (mutation_model.joblib must be uploaded too)
26
- model, vocab = joblib.load("mutation_model.joblib")
27
 
 
28
  def predict_sequence(sequence: str):
29
  if not sequence or len(sequence.strip()) < 3:
30
- return {"error":"sequence too short"}
 
31
  X = vectorize_single(sequence, vocab=vocab, k=3)
32
  pred = model.predict(X)[0]
33
  prob = float(model.predict_proba(X).max()) if hasattr(model, "predict_proba") else None
 
34
  return {
35
  "sequence": sequence,
36
  "mutation_detected": bool(pred),
37
- "confidence": prob
38
  }
39
 
40
- # Gradio UI
41
- with gr.Blocks() as demo:
42
- gr.Markdown("# DNA Mutation Detector (Quick Space)")
43
- seq_in = gr.Textbox(label="DNA sequence", placeholder="ATGCGTACGTTAGC...")
44
- btn = gr.Button("Analyze")
45
- out = gr.JSON()
46
- btn.click(fn=predict_sequence, inputs=seq_in, outputs=out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # Expose a simple inference API endpoint (Gradio provides /api/predict automatically)
49
- # but we also expose a programmatic function name for convenience:
50
  def api_predict(payload: dict):
51
  seq = payload.get("sequence", "")
52
  return predict_sequence(seq)
53
 
54
  if __name__ == "__main__":
55
- demo.launch() change
 
 
1
  import gradio as gr
2
  import joblib
3
  import numpy as np
4
  from collections import Counter
5
  from typing import List
6
+ import os
7
+
8
+ # --- Helper Functions ---
9
+ BASES = ['A', 'T', 'C', 'G']
10
 
 
11
  def kmer_counts(seq: str, k=3):
12
  seq = seq.strip().upper()
13
  counts = Counter()
14
  if len(seq) < k:
15
  return counts
16
+ for i in range(len(seq) - k + 1):
17
  counts[seq[i:i+k]] += 1
18
  return counts
19
 
20
  def vectorize_single(seq: str, vocab: List[str], k=3):
21
+ X = np.zeros((1, len(vocab)), dtype=float)
22
  c = kmer_counts(seq, k)
23
+ for j, kmer in enumerate(vocab):
24
+ X[0, j] = c.get(kmer, 0)
25
+ return X
26
+
27
+ # --- Load Model ---
28
+ MODEL_PATH = "mutation_model.joblib"
29
+
30
+ if not os.path.exists(MODEL_PATH):
31
+ raise FileNotFoundError(
32
+ f"⚠️ Model file '{MODEL_PATH}' not found. "
33
+ "Please upload 'mutation_model.joblib' along with this app."
34
+ )
35
 
36
+ model, vocab = joblib.load(MODEL_PATH)
 
37
 
38
+ # --- Prediction Logic ---
39
  def predict_sequence(sequence: str):
40
  if not sequence or len(sequence.strip()) < 3:
41
+ return {"error": "Please enter a valid DNA sequence (≥3 bases)."}
42
+
43
  X = vectorize_single(sequence, vocab=vocab, k=3)
44
  pred = model.predict(X)[0]
45
  prob = float(model.predict_proba(X).max()) if hasattr(model, "predict_proba") else None
46
+
47
  return {
48
  "sequence": sequence,
49
  "mutation_detected": bool(pred),
50
+ "confidence": round(prob, 3) if prob else "N/A"
51
  }
52
 
53
+ # --- Gradio Interface ---
54
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
55
+ gr.Markdown(
56
+ """
57
+ <h1 style="text-align:center;">🧬 DNA Mutation Analyzer</h1>
58
+ <p style="text-align:center;">
59
+ Upload or paste a DNA sequence to check for possible mutations using a Random Forest ML model.
60
+ </p>
61
+ """
62
+ )
63
+
64
+ with gr.Row():
65
+ seq_input = gr.Textbox(
66
+ label="DNA Sequence",
67
+ placeholder="Enter sequence like ATGCGTACGTTAGC...",
68
+ lines=2,
69
+ )
70
+ analyze_btn = gr.Button("🔍 Analyze Sequence")
71
+ result = gr.JSON(label="Analysis Result")
72
+
73
+ analyze_btn.click(fn=predict_sequence, inputs=seq_input, outputs=result)
74
 
75
+ # --- API Endpoint for Programmatic Access ---
 
76
  def api_predict(payload: dict):
77
  seq = payload.get("sequence", "")
78
  return predict_sequence(seq)
79
 
80
  if __name__ == "__main__":
81
+ demo.launch()