jebin511 committed on
Commit
af1e5ec
·
verified ·
1 Parent(s): 69bf1ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -47
app.py CHANGED
@@ -1,81 +1,55 @@
 
1
  import gradio as gr
2
  import joblib
3
  import numpy as np
4
  from collections import Counter
5
  from typing import List
6
- import os
7
-
8
- # --- Helper Functions ---
9
- BASES = ['A', 'T', 'C', 'G']
10
 
 
11
  def kmer_counts(seq: str, k=3):
12
  seq = seq.strip().upper()
13
  counts = Counter()
14
  if len(seq) < k:
15
  return counts
16
- for i in range(len(seq) - k + 1):
17
  counts[seq[i:i+k]] += 1
18
  return counts
19
 
20
  def vectorize_single(seq: str, vocab: List[str], k=3):
21
- X = np.zeros((1, len(vocab)), dtype=float)
22
  c = kmer_counts(seq, k)
23
- for j, kmer in enumerate(vocab):
24
- X[0, j] = c.get(kmer, 0)
25
- return X
26
-
27
- # --- Load Model ---
28
- MODEL_PATH = "mutation_model.joblib"
29
-
30
- if not os.path.exists(MODEL_PATH):
31
- raise FileNotFoundError(
32
- f"⚠️ Model file '{MODEL_PATH}' not found. "
33
- "Please upload 'mutation_model.joblib' along with this app."
34
- )
35
 
36
- model, vocab = joblib.load(MODEL_PATH)
 
37
 
38
- # --- Prediction Logic ---
39
  def predict_sequence(sequence: str):
40
  if not sequence or len(sequence.strip()) < 3:
41
- return {"error": "Please enter a valid DNA sequence (≥3 bases)."}
42
-
43
  X = vectorize_single(sequence, vocab=vocab, k=3)
44
  pred = model.predict(X)[0]
45
  prob = float(model.predict_proba(X).max()) if hasattr(model, "predict_proba") else None
46
-
47
  return {
48
  "sequence": sequence,
49
  "mutation_detected": bool(pred),
50
- "confidence": round(prob, 3) if prob else "N/A"
51
  }
52
 
53
- # --- Gradio Interface ---
54
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
55
- gr.Markdown(
56
- """
57
- <h1 style="text-align:center;">🧬 DNA Mutation Analyzer</h1>
58
- <p style="text-align:center;">
59
- Upload or paste a DNA sequence to check for possible mutations using a Random Forest ML model.
60
- </p>
61
- """
62
- )
63
-
64
- with gr.Row():
65
- seq_input = gr.Textbox(
66
- label="DNA Sequence",
67
- placeholder="Enter sequence like ATGCGTACGTTAGC...",
68
- lines=2,
69
- )
70
- analyze_btn = gr.Button("🔍 Analyze Sequence")
71
- result = gr.JSON(label="Analysis Result")
72
-
73
- analyze_btn.click(fn=predict_sequence, inputs=seq_input, outputs=result)
74
 
75
- # --- API Endpoint for Programmatic Access ---
 
76
  def api_predict(payload: dict):
77
  seq = payload.get("sequence", "")
78
  return predict_sequence(seq)
79
 
80
  if __name__ == "__main__":
81
- demo.launch()
 
1
+ # app.py
2
  import gradio as gr
3
  import joblib
4
  import numpy as np
5
  from collections import Counter
6
  from typing import List
 
 
 
 
7
 
8
+ # helper: k-mer extraction / vectorize (k=3)
9
  def kmer_counts(seq: str, k=3):
10
  seq = seq.strip().upper()
11
  counts = Counter()
12
  if len(seq) < k:
13
  return counts
14
+ for i in range(len(seq)-k+1):
15
  counts[seq[i:i+k]] += 1
16
  return counts
17
 
18
  def vectorize_single(seq: str, vocab: List[str], k=3):
19
+ x = np.zeros((1, len(vocab)), dtype=float)
20
  c = kmer_counts(seq, k)
21
+ for j,kmer in enumerate(vocab):
22
+ x[0,j] = c.get(kmer, 0)
23
+ return x
 
 
 
 
 
 
 
 
 
24
 
25
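+ # Load the (model, vocab) pair from mutation_model.joblib, which must be uploaded together with this app (joblib.load unpickles the file, so it must come from a trusted source).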
26
+ model, vocab = joblib.load("mutation_model.joblib")
27
 
 
28
  def predict_sequence(sequence: str):
29
  if not sequence or len(sequence.strip()) < 3:
30
+ return {"error":"sequence too short"}
 
31
  X = vectorize_single(sequence, vocab=vocab, k=3)
32
  pred = model.predict(X)[0]
33
  prob = float(model.predict_proba(X).max()) if hasattr(model, "predict_proba") else None
 
34
  return {
35
  "sequence": sequence,
36
  "mutation_detected": bool(pred),
37
+ "confidence": prob
38
  }
39
 
40
+ # Gradio UI
41
with gr.Blocks() as demo:
    # Page title, rendered as a Markdown heading.
    gr.Markdown("# DNA Mutation Detector (Quick Space)")
    seq_in = gr.Textbox(
        label="DNA sequence",
        placeholder="ATGCGTACGTTAGC...",
    )
    btn = gr.Button("Analyze")
    out = gr.JSON()
    # Clicking the button passes the text box content to predict_sequence and
    # shows the returned dict in the JSON panel.
    btn.click(
        fn=predict_sequence,
        inputs=seq_in,
        outputs=out,
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
+ # Expose a simple inference API endpoint (Gradio provides /api/predict automatically)
49
+ # but we also provide a plain-Python wrapper that takes a dict payload, for convenience:
50
  def api_predict(payload: dict):
51
  seq = payload.get("sequence", "")
52
  return predict_sequence(seq)
53
 
54
  if __name__ == "__main__":
55
+ demo.launch() change