devjas1 committed on
Commit
182c9ce
·
1 Parent(s): b1b7e3c

(FIX): Streamline imports by removing the redundant scripts.preprocess_dataset import; enhance resampling logic with diagnostics

Browse files

- Refactors spectrum resampling and improves diagnostics
- Introduces robust checks for strictly increasing sequences in resampling results, logging ambiguous cases
- Adds session state persistence for both raw and resampled data, and enriches diagnostics with detailed statistics about the resampled data

Files changed (1) hide show
  1. app.py +61 -22
app.py CHANGED
@@ -1,6 +1,5 @@
1
  from models.resnet_cnn import ResNet1D
2
  from models.figure2_cnn import Figure2CNN
3
- import logging
4
  import hashlib
5
  import gc
6
  import time
@@ -22,14 +21,7 @@ if utils_path.is_dir() and str(utils_path) not in sys.path:
22
  matplotlib.use("Agg") # ensure headless rendering in Spaces
23
 
24
  # Import local modules
25
- # Prefer canonical script; fallback to local utils for HF hard-copy scenario
26
- try:
27
- from scripts.preprocess_dataset import resample_spectrum
28
- except (ImportError, ModuleNotFoundError):
29
- try:
30
- from utils.preprocessing import resample_spectrum
31
- except (ImportError, ModuleNotFoundError):
32
- raise ImportError("Could not import 'resample_spectrum' from either 'scripts.preprocess_dataset' or 'utils.preprocessing'. Please ensure the function exists in one of these modules.")
33
 
34
  KEEP_KEYS = {
35
  # === global UI context we want to keep after "Reset" ===
@@ -129,7 +121,7 @@ def label_file(filename: str) -> int:
129
  def load_state_dict(_mtime, model_path):
130
  """Load state dict with mtime in cache key to detect file changes"""
131
  try:
132
- return torch.load(model_path, map_location="cpu")
133
  except (FileNotFoundError, RuntimeError) as e:
134
  st.warning(f"Error loading state dict: {e}")
135
  return None
@@ -235,11 +227,11 @@ def parse_spectrum_data(raw_text):
235
  return x, y
236
 
237
 
238
- def create_spectrum_plot(x_raw, y_raw, y_resampled):
239
  """Create spectrum visualization plot"""
240
  fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
241
 
242
- # Raw spectrum
243
  ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
244
  ax[0].set_title("Raw Input Spectrum")
245
  ax[0].set_xlabel("Wavenumber (cm⁻¹)")
@@ -247,19 +239,16 @@ def create_spectrum_plot(x_raw, y_raw, y_resampled):
247
  ax[0].grid(True, alpha=0.3)
248
  ax[0].legend()
249
 
250
- # Resampled spectrum
251
- x_resampled = np.linspace(min(x_raw), max(x_raw), TARGET_LEN)
252
- ax[1].plot(x_resampled, y_resampled, label="Resampled",
253
- color="steelblue", linewidth=1)
254
- ax[1].set_title(f"Resampled ({TARGET_LEN} points)")
255
  ax[1].set_xlabel("Wavenumber (cm⁻¹)")
256
  ax[1].set_ylabel("Intensity")
257
  ax[1].grid(True, alpha=0.3)
258
  ax[1].legend()
259
 
260
  plt.tight_layout()
261
-
262
- # Convert to image
263
  buf = io.BytesIO()
264
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=100)
265
  buf.seek(0)
@@ -546,7 +535,30 @@ def main():
546
 
547
  # Resample
548
  with st.spinner("Resampling spectrum..."):
549
- _, y_resampled = resample_spectrum(x_raw, y_raw, TARGET_LEN)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
 
551
  # Persist results (drives right column)
552
  st.session_state["x_raw"] = x_raw
@@ -571,6 +583,7 @@ def main():
571
  # Get data from session state
572
  x_raw = st.session_state.get('x_raw')
573
  y_raw = st.session_state.get('y_raw')
 
574
  y_resampled = st.session_state.get('y_resampled')
575
  filename = st.session_state.get('filename', 'Unknown')
576
 
@@ -578,8 +591,7 @@ def main():
578
 
579
  # Create and display plot
580
  try:
581
- spectrum_plot = create_spectrum_plot(
582
- x_raw, y_raw, y_resampled)
583
  st.image(
584
  spectrum_plot, caption="Spectrum Preprocessing Results", use_container_width=True)
585
  except (ValueError, RuntimeError, TypeError) as e:
@@ -706,6 +718,33 @@ def main():
706
  st.text_area("Logs", "\n".join(
707
  st.session_state.get("log_messages", [])), height=200)
708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
709
  with tab3:
710
  st.markdown("""
711
  **🔍 Analysis Process**
 
1
  from models.resnet_cnn import ResNet1D
2
  from models.figure2_cnn import Figure2CNN
 
3
  import hashlib
4
  import gc
5
  import time
 
21
  matplotlib.use("Agg") # ensure headless rendering in Spaces
22
 
23
  # Import local modules
24
+ from utils.preprocessing import resample_spectrum
 
 
 
 
 
 
 
25
 
26
  KEEP_KEYS = {
27
  # === global UI context we want to keep after "Reset" ===
 
121
  def load_state_dict(_mtime, model_path):
122
  """Load state dict with mtime in cache key to detect file changes"""
123
  try:
124
+ return torch.load(model_path, map_location="cpu", weights_only=True)
125
  except (FileNotFoundError, RuntimeError) as e:
126
  st.warning(f"Error loading state dict: {e}")
127
  return None
 
227
  return x, y
228
 
229
 
230
+ def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled):
231
  """Create spectrum visualization plot"""
232
  fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
233
 
234
+ # == Raw spectrum ==
235
  ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
236
  ax[0].set_title("Raw Input Spectrum")
237
  ax[0].set_xlabel("Wavenumber (cm⁻¹)")
 
239
  ax[0].grid(True, alpha=0.3)
240
  ax[0].legend()
241
 
242
+ # == Resampled spectrum ==
243
+ ax[1].plot(x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1)
244
+ ax[1].set_title(f"Resampled ({len(y_resampled)} points)")
 
 
245
  ax[1].set_xlabel("Wavenumber (cm⁻¹)")
246
  ax[1].set_ylabel("Intensity")
247
  ax[1].grid(True, alpha=0.3)
248
  ax[1].legend()
249
 
250
  plt.tight_layout()
251
+ # == Convert to image ==
 
252
  buf = io.BytesIO()
253
  plt.savefig(buf, format='png', bbox_inches='tight', dpi=100)
254
  buf.seek(0)
 
535
 
536
  # Resample
537
  with st.spinner("Resampling spectrum..."):
538
+ # ===Resample Unpack===
539
+ r1, r2 = resample_spectrum(x_raw, y_raw, TARGET_LEN)
540
+
541
+ def _is_strictly_increasing(a):
542
+ try:
543
+ a = np.asarray(a)
544
+ return a.ndim == 1 and a.size >= 2 and np.all(np.diff(a) > 0)
545
+ except Exception:
546
+ return False
547
+
548
+ if _is_strictly_increasing(r1) and not _is_strictly_increasing(r2):
549
+ x_resampled, y_resampled = np.asarray(r1), np.asarray(r2)
550
+ elif _is_strictly_increasing(r2) and not _is_strictly_increasing(r1):
551
+ x_resampled, y_resampled = np.asarray(r2), np.asarray(r1)
552
+ else:
553
+ # == Ambigous; assume (x, y) and log
554
+ x_resampled, y_resampled = np.asarray(r1), np.asarray(r2)
555
+ log_message("Resample outputs ambigous; assumed (x, y).")
556
+
557
+ # ===Persists for plotting + inference===
558
+ st.session_state["x_raw"] = x_raw
559
+ st.session_state["y_raw"] = y_raw
560
+ st.session_state["x_resampled"] = x_resampled # ←-- NEW
561
+ st.session_state["y_resampled"] = y_resampled
562
 
563
  # Persist results (drives right column)
564
  st.session_state["x_raw"] = x_raw
 
583
  # Get data from session state
584
  x_raw = st.session_state.get('x_raw')
585
  y_raw = st.session_state.get('y_raw')
586
+ x_resampled = st.session_state.get('x_resampled') # ← NEW
587
  y_resampled = st.session_state.get('y_resampled')
588
  filename = st.session_state.get('filename', 'Unknown')
589
 
 
591
 
592
  # Create and display plot
593
  try:
594
+ spectrum_plot = create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled)
 
595
  st.image(
596
  spectrum_plot, caption="Spectrum Preprocessing Results", use_container_width=True)
597
  except (ValueError, RuntimeError, TypeError) as e:
 
718
  st.text_area("Logs", "\n".join(
719
  st.session_state.get("log_messages", [])), height=200)
720
 
721
+ try:
722
+ resampler_mod = getattr(resample_spectrum, "__module__", "unknown")
723
+ resampler_doc = getattr(resample_spectrum, "__doc__", None)
724
+ resampler_doc = resampler_doc.splitlines()[0] if isinstance(resampler_doc, str) and resampler_doc else "no doc"
725
+
726
+ y_rs = st.session_state.get("y_resampled", None)
727
+ diag = {}
728
+ if y_rs is not None:
729
+ arr = np.asarray(y_rs)
730
+ diag = {
731
+ "y_resampled_len": int(arr.size),
732
+ "y_resampled_min": float(np.min(arr)) if arr.size else None,
733
+ "y_resampled_max": float(np.max(arr)) if arr.size else None,
734
+ "y_resampled_ptp": float(np.ptp(arr)) if arr.size else None,
735
+ "y_resampled_unique": int(np.unique(arr).size) if arr.size else None,
736
+ "y_resampled_all_equal": bool(np.ptp(arr) == 0.0) if arr.size else None,
737
+ }
738
+
739
+ st.markdown("**Resampler Info")
740
+ st.json({
741
+ "module": resampler_mod,
742
+ "doc": resampler_doc,
743
+ **({"y_resampled_stats": diag} if diag else {})
744
+ })
745
+ except Exception as _e:
746
+ st.warning(f"Diagnostics skipped: {_e}")
747
+
748
  with tab3:
749
  st.markdown("""
750
  **🔍 Analysis Process**