genomenet commited on
Commit
f44b2b9
·
1 Parent(s): 31ba0eb

Improve Space default prediction responsiveness

Browse files
Files changed (4) hide show
  1. README.md +2 -2
  2. REPORT_SUMMARY.md +12 -12
  3. app.py +22 -8
  4. inference/model_loader.py +1 -1
README.md CHANGED
@@ -48,12 +48,12 @@ git clone https://huggingface.co/spaces/genomenet/crispr-array-detection
48
  ### Space settings (HuggingFace web UI)
49
 
50
  - SDK: Docker
51
- - Hardware: T4 GPU
52
  - Visibility: Public
53
 
54
  ### Model weights
55
 
56
- Hosted at: https://huggingface.co/pmuench3/crispr-bert-model
57
 
58
  Downloaded automatically via `huggingface_hub` at startup.
59
 
 
48
  ### Space settings (HuggingFace web UI)
49
 
50
  - SDK: Docker
51
+ - Hardware: CPU Basic works for the default demo; T4 GPU is recommended for long sequences or low stride values
52
  - Visibility: Public
53
 
54
  ### Model weights
55
 
56
+ Hosted at: https://huggingface.co/genomenet/crispr-bert-model
57
 
58
  Downloaded automatically via `huggingface_hub` at startup.
59
 
REPORT_SUMMARY.md CHANGED
@@ -2,14 +2,14 @@
2
 
3
  **Date:** April 2026
4
  **Repository:** `/vol/hpcprojects/pmuench/crispr_tool/crispr-hf-space/`
5
- **HuggingFace Space:** https://huggingface.co/spaces/pmuench3/crispr-array-detection
6
- **HuggingFace Model Repository:** https://huggingface.co/pmuench3/crispr-bert-model
7
 
8
  ---
9
 
10
  ## Summary
11
 
12
- We have successfully deployed a publicly accessible web application for CRISPR array detection based on the BERT-based deep learning model developed in Ziyu Mu's Master's thesis. The application is hosted on HuggingFace Spaces with GPU acceleration (T4) and provides both interactive visualization and programmatic access to the model's predictions.
13
 
14
  ---
15
 
@@ -21,7 +21,7 @@ We have successfully deployed a publicly accessible web application for CRISPR a
21
  - **Output:** Per-position CRISPR probability scores (0-1)
22
  - **Visualization:** Interactive probability curve along sequence position
23
  - **Parameters:**
24
- - Configurable stride (50-500 bp, default 100)
25
  - Adjustable detection threshold (0.1-0.9, default 0.3)
26
  - **Region Detection:** Automatic identification and annotation of predicted CRISPR regions above threshold
27
 
@@ -101,7 +101,7 @@ plotly>=5.18.0
101
  From the original TODO:
102
 
103
  - [x] **Checkpoint beschaffen** - Model `best.h5` located and uploaded to HF Model Hub
104
- - [x] **Eigenes Repo anlegen** - Created HuggingFace Space `pmuench3/crispr-array-detection`
105
  - [x] **Code-Verständnis** - Analyzed custom layers, tokenization, sliding window logic
106
  - [x] **Model-Loader (Singleton)** - Implemented with HF Hub download
107
  - [x] **Tokenizer** - Extracted and adapted for inference
@@ -112,7 +112,7 @@ From the original TODO:
112
  - [x] **State-Dynamics Visualization** - UMAP + clustering + interactive Plotly plots
113
  - [x] **Input-Validation** - Sequence validation, FASTA header stripping
114
  - [x] **Health Endpoint equivalent** - GPU status shown in UI
115
- - [x] **Deployment** - Live on HuggingFace Spaces with T4 GPU
116
  - [x] **Acknowledgements** - Ziyu Mu, DFG SPP 2141, BMBF GenomeNet, HZI BIFO
117
 
118
  ---
@@ -121,7 +121,7 @@ From the original TODO:
121
 
122
  ### Web Interface
123
 
124
- 1. Navigate to https://huggingface.co/spaces/pmuench3/crispr-array-detection
125
  2. Paste DNA sequence (or use provided examples)
126
  3. Click "Analyze Sequence" for CRISPR detection
127
  4. Use "Embeddings" tab for State-Dynamic Plots
@@ -131,12 +131,12 @@ From the original TODO:
131
  ```python
132
  from gradio_client import Client
133
 
134
- client = Client("pmuench3/crispr-array-detection")
135
 
136
  # Predict CRISPR regions
137
  result = client.predict(
138
  sequence="ACGT...",
139
- stride=100,
140
  threshold=0.3,
141
  api_name="/predict"
142
  )
@@ -155,16 +155,16 @@ embedding = client.predict(
155
 
156
  ### Suggested Text (German)
157
 
158
- > Im Rahmen des SPP 2141 wurde ein öffentlich zugänglicher Webservice zur CRISPR-Array-Detektion entwickelt und auf HuggingFace Spaces bereitgestellt (https://huggingface.co/spaces/pmuench3/crispr-array-detection). Das System basiert auf einem BERT-basierten Deep-Learning-Modell (~430 Mio. Parameter), das auf metagenomischen und genomischen mikrobiellen Sequenzen vortrainiert und anschließend auf annotierten CRISPR-Arrays feinabgestimmt wurde.
159
  >
160
  > Der Service bietet:
161
  > - Vorhersage von CRISPR-Array-Wahrscheinlichkeiten entlang der Sequenzposition
162
  > - Extraktion von Hidden-State-Embeddings aus dem Transformer-Modell
163
  > - State-Dynamic-Plots zur Visualisierung der Einbettungstrajektorien mittels UMAP und Clustering
164
  >
165
- > Die State-Dynamic-Visualisierung ermöglicht die Identifikation wiederkehrender Strukturelemente (z.B. Repeats vs. Spacer) durch die Analyse der Aktivierungsmuster im neuronalen Netzwerk. Der Service läuft auf GPU-beschleunigter Hardware (NVIDIA T4) und ist für die wissenschaftliche Community frei zugänglich.
166
  >
167
- > **Referenz:** https://huggingface.co/spaces/pmuench3/crispr-array-detection
168
 
169
  ### Acknowledgements (for publication)
170
 
 
2
 
3
  **Date:** April 2026
4
  **Repository:** `/vol/hpcprojects/pmuench/crispr_tool/crispr-hf-space/`
5
+ **HuggingFace Space:** https://huggingface.co/spaces/genomenet/crispr-array-detection
6
+ **HuggingFace Model Repository:** https://huggingface.co/genomenet/crispr-bert-model
7
 
8
  ---
9
 
10
  ## Summary
11
 
12
+ We have successfully deployed a publicly accessible web application for CRISPR array detection based on the BERT-based deep learning model developed in Ziyu Mu's Master's thesis. The application is hosted on HuggingFace Spaces and provides both interactive visualization and programmatic access to the model's predictions.
13
 
14
  ---
15
 
 
21
  - **Output:** Per-position CRISPR probability scores (0-1)
22
  - **Visualization:** Interactive probability curve along sequence position
23
  - **Parameters:**
24
+ - Configurable stride (50-500 bp, default 500 for CPU responsiveness)
25
  - Adjustable detection threshold (0.1-0.9, default 0.3)
26
  - **Region Detection:** Automatic identification and annotation of predicted CRISPR regions above threshold
27
 
 
101
  From the original TODO:
102
 
103
  - [x] **Checkpoint beschaffen** - Model `best.h5` located and uploaded to HF Model Hub
104
+ - [x] **Eigenes Repo anlegen** - Created HuggingFace Space `genomenet/crispr-array-detection`
105
  - [x] **Code-Verständnis** - Analyzed custom layers, tokenization, sliding window logic
106
  - [x] **Model-Loader (Singleton)** - Implemented with HF Hub download
107
  - [x] **Tokenizer** - Extracted and adapted for inference
 
112
  - [x] **State-Dynamics Visualization** - UMAP + clustering + interactive Plotly plots
113
  - [x] **Input-Validation** - Sequence validation, FASTA header stripping
114
  - [x] **Health Endpoint equivalent** - GPU status shown in UI
115
+ - [x] **Deployment** - Live on HuggingFace Spaces; GPU hardware is recommended for long sequences
116
  - [x] **Acknowledgements** - Ziyu Mu, DFG SPP 2141, BMBF GenomeNet, HZI BIFO
117
 
118
  ---
 
121
 
122
  ### Web Interface
123
 
124
+ 1. Navigate to https://huggingface.co/spaces/genomenet/crispr-array-detection
125
  2. Paste DNA sequence (or use provided examples)
126
  3. Click "Analyze Sequence" for CRISPR detection
127
  4. Use "Embeddings" tab for State-Dynamic Plots
 
131
  ```python
132
  from gradio_client import Client
133
 
134
+ client = Client("genomenet/crispr-array-detection")
135
 
136
  # Predict CRISPR regions
137
  result = client.predict(
138
  sequence="ACGT...",
139
+ stride=500,
140
  threshold=0.3,
141
  api_name="/predict"
142
  )
 
155
 
156
  ### Suggested Text (German)
157
 
158
+ > Im Rahmen des SPP 2141 wurde ein öffentlich zugänglicher Webservice zur CRISPR-Array-Detektion entwickelt und auf HuggingFace Spaces bereitgestellt (https://huggingface.co/spaces/genomenet/crispr-array-detection). Das System basiert auf einem BERT-basierten Deep-Learning-Modell (~430 Mio. Parameter), das auf metagenomischen und genomischen mikrobiellen Sequenzen vortrainiert und anschließend auf annotierten CRISPR-Arrays feinabgestimmt wurde.
159
  >
160
  > Der Service bietet:
161
  > - Vorhersage von CRISPR-Array-Wahrscheinlichkeiten entlang der Sequenzposition
162
  > - Extraktion von Hidden-State-Embeddings aus dem Transformer-Modell
163
  > - State-Dynamic-Plots zur Visualisierung der Einbettungstrajektorien mittels UMAP und Clustering
164
  >
165
+ > Die State-Dynamic-Visualisierung ermöglicht die Identifikation wiederkehrender Strukturelemente (z.B. Repeats vs. Spacer) durch die Analyse der Aktivierungsmuster im neuronalen Netzwerk. GPU-Hardware wird für lange Sequenzen und hohe Auflösung empfohlen. Der Service ist für die wissenschaftliche Community frei zugänglich.
166
  >
167
+ > **Referenz:** https://huggingface.co/spaces/genomenet/crispr-array-detection
168
 
169
  ### Acknowledgements (for publication)
170
 
app.py CHANGED
@@ -4,6 +4,7 @@ CRISPR Array Detection - HuggingFace Spaces App
4
 
5
  import os
6
  import html
 
7
  import tempfile
8
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
9
  os.environ.setdefault("MPLCONFIGDIR", os.path.join(tempfile.gettempdir(), "matplotlib"))
@@ -26,10 +27,15 @@ from inference.model_loader import get_model, warmup_model, get_gpu_status
26
  from inference.tokenizer import validate_sequence, strip_fasta_header
27
  from inference.inference import detect_crispr_regions
28
 
 
 
 
29
  MAX_SEQUENCE_LENGTH = int(os.environ.get("MAX_SEQUENCE_LENGTH", "50000"))
30
  MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_BYTES", str(2 * 1024 * 1024)))
31
  MAX_SEQUENCE_VIEWER_LENGTH = int(os.environ.get("MAX_SEQUENCE_VIEWER_LENGTH", "20000"))
32
  QUEUE_MAX_SIZE = int(os.environ.get("GRADIO_QUEUE_MAX_SIZE", "8"))
 
 
33
 
34
  # Custom CSS - Minimal monochrome design with Geist fonts
35
  CUSTOM_CSS = """
@@ -915,7 +921,7 @@ def create_sequence_viewer_html(sequence, positions, probabilities, threshold=0.
915
  return ''.join(html_parts)
916
 
917
 
918
- def predict(sequence: str, stride: int = 100, threshold: float = 0.3):
919
  """Predict CRISPR array probability for each position."""
920
  import csv
921
  import time
@@ -1037,7 +1043,7 @@ def detect(sequence: str, threshold: float = 0.3, min_length: int = 160):
1037
  sequence,
1038
  threshold=threshold,
1039
  min_length=min_length,
1040
- stride=100
1041
  )
1042
 
1043
  if not regions:
@@ -1194,9 +1200,9 @@ Sliding window analysis with per-position probability scores. Export to GFF3/CSV
1194
  )
1195
  with gr.Row():
1196
  stride_input = gr.Slider(
1197
- minimum=50, maximum=500, value=100, step=50,
1198
  label="stride",
1199
- info="lower = higher resolution"
1200
  )
1201
  threshold_input = gr.Slider(
1202
  minimum=0.1, maximum=0.9, value=0.3, step=0.05,
@@ -1252,7 +1258,11 @@ Sliding window analysis with per-position probability scores. Export to GFF3/CSV
1252
  )
1253
 
1254
  def predict_and_show_downloads(*args):
1255
- results = predict(*args)
 
 
 
 
1256
  # results = (fig, summary, regions, png, pdf, csv, summary_txt, gff, seq_html)
1257
  # Return results plus visibility updates for accordions
1258
  success = results[0] is not None
@@ -1322,7 +1332,11 @@ Repeats cluster together, spacers form distinct groups.
1322
  )
1323
 
1324
  def embed_and_show_downloads(*args):
1325
- results = get_embedding(*args)
 
 
 
 
1326
  success = results[0] is not None
1327
  return results + (gr.update(visible=success),)
1328
 
@@ -1346,7 +1360,7 @@ client = Client("genomenet/crispr-array-detection")
1346
  # predict
1347
  result = client.predict(
1348
  sequence="ATGC...",
1349
- stride=100,
1350
  threshold=0.3,
1351
  api_name="/predict"
1352
  )
@@ -1384,7 +1398,7 @@ pip install -r requirements.txt && python app.py
1384
 
1385
  | param | default | range |
1386
  |-------|---------|-------|
1387
- | stride | 100 bp | 50-500 |
1388
  | threshold | 0.3 | 0.1-0.9 |
1389
 
1390
  **citation**
 
4
 
5
  import os
6
  import html
7
+ import logging
8
  import tempfile
9
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
10
  os.environ.setdefault("MPLCONFIGDIR", os.path.join(tempfile.gettempdir(), "matplotlib"))
 
27
  from inference.tokenizer import validate_sequence, strip_fasta_header
28
  from inference.inference import detect_crispr_regions
29
 
30
+ logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
31
+ logger = logging.getLogger(__name__)
32
+
33
  MAX_SEQUENCE_LENGTH = int(os.environ.get("MAX_SEQUENCE_LENGTH", "50000"))
34
  MAX_UPLOAD_BYTES = int(os.environ.get("MAX_UPLOAD_BYTES", str(2 * 1024 * 1024)))
35
  MAX_SEQUENCE_VIEWER_LENGTH = int(os.environ.get("MAX_SEQUENCE_VIEWER_LENGTH", "20000"))
36
  QUEUE_MAX_SIZE = int(os.environ.get("GRADIO_QUEUE_MAX_SIZE", "8"))
37
+ DEFAULT_STRIDE = int(os.environ.get("DEFAULT_STRIDE", "500"))
38
+ DEFAULT_THRESHOLD = float(os.environ.get("DEFAULT_THRESHOLD", "0.3"))
39
 
40
  # Custom CSS - Minimal monochrome design with Geist fonts
41
  CUSTOM_CSS = """
 
921
  return ''.join(html_parts)
922
 
923
 
924
+ def predict(sequence: str, stride: int = DEFAULT_STRIDE, threshold: float = DEFAULT_THRESHOLD):
925
  """Predict CRISPR array probability for each position."""
926
  import csv
927
  import time
 
1043
  sequence,
1044
  threshold=threshold,
1045
  min_length=min_length,
1046
+ stride=DEFAULT_STRIDE
1047
  )
1048
 
1049
  if not regions:
 
1200
  )
1201
  with gr.Row():
1202
  stride_input = gr.Slider(
1203
+ minimum=50, maximum=500, value=DEFAULT_STRIDE, step=50,
1204
  label="stride",
1205
+ info="500 = fast on CPU; lower = higher resolution but slower"
1206
  )
1207
  threshold_input = gr.Slider(
1208
  minimum=0.1, maximum=0.9, value=0.3, step=0.05,
 
1258
  )
1259
 
1260
  def predict_and_show_downloads(*args):
1261
+ try:
1262
+ results = predict(*args)
1263
+ except Exception as exc:
1264
+ logger.exception("Prediction failed")
1265
+ results = prediction_error_outputs(f"Analysis failed: {exc}")
1266
  # results = (fig, summary, regions, png, pdf, csv, summary_txt, gff, seq_html)
1267
  # Return results plus visibility updates for accordions
1268
  success = results[0] is not None
 
1332
  )
1333
 
1334
  def embed_and_show_downloads(*args):
1335
+ try:
1336
+ results = get_embedding(*args)
1337
+ except Exception as exc:
1338
+ logger.exception("Embedding failed")
1339
+ results = embedding_error_outputs(f"Embedding failed: {exc}")
1340
  success = results[0] is not None
1341
  return results + (gr.update(visible=success),)
1342
 
 
1360
  # predict
1361
  result = client.predict(
1362
  sequence="ATGC...",
1363
+ stride=500,
1364
  threshold=0.3,
1365
  api_name="/predict"
1366
  )
 
1398
 
1399
  | param | default | range |
1400
  |-------|---------|-------|
1401
+ | stride | 500 bp | 50-500 |
1402
  | threshold | 0.3 | 0.1-0.9 |
1403
 
1404
  **citation**
inference/model_loader.py CHANGED
@@ -25,7 +25,7 @@ _embedding_model: Optional[tf.keras.Model] = None
25
  _model_lock = threading.Lock()
26
 
27
  # HuggingFace model repository
28
- HF_MODEL_REPO = os.environ.get("CRISPR_HF_REPO", "pmuench3/crispr-bert-model")
29
  HF_MODEL_FILENAME = os.environ.get("CRISPR_HF_FILENAME", "best.h5")
30
 
31
  # Local model path (optional override)
 
25
  _model_lock = threading.Lock()
26
 
27
  # HuggingFace model repository
28
+ HF_MODEL_REPO = os.environ.get("CRISPR_HF_REPO", "genomenet/crispr-bert-model")
29
  HF_MODEL_FILENAME = os.environ.get("CRISPR_HF_FILENAME", "best.h5")
30
 
31
  # Local model path (optional override)