Spaces:
Sleeping
Sleeping
devjas1
commited on
Commit
·
182c9ce
1
Parent(s):
b1b7e3c
(FIX): Streamline remove redundant scripts.preprocess_dataset import; enhance resampling logic with diagnosis
Browse files- Refactors spectrum resampling + improve diagnostics
- Introduces robust checks for strictly increasing sequences in resampling results, logging ambigous cases
- Adds session state persistence for both raw and resampled data, and enriches diagnostics with detailed statistics about the resampled data
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
from models.resnet_cnn import ResNet1D
|
2 |
from models.figure2_cnn import Figure2CNN
|
3 |
-
import logging
|
4 |
import hashlib
|
5 |
import gc
|
6 |
import time
|
@@ -22,14 +21,7 @@ if utils_path.is_dir() and str(utils_path) not in sys.path:
|
|
22 |
matplotlib.use("Agg") # ensure headless rendering in Spaces
|
23 |
|
24 |
# Import local modules
|
25 |
-
|
26 |
-
try:
|
27 |
-
from scripts.preprocess_dataset import resample_spectrum
|
28 |
-
except (ImportError, ModuleNotFoundError):
|
29 |
-
try:
|
30 |
-
from utils.preprocessing import resample_spectrum
|
31 |
-
except (ImportError, ModuleNotFoundError):
|
32 |
-
raise ImportError("Could not import 'resample_spectrum' from either 'scripts.preprocess_dataset' or 'utils.preprocessing'. Please ensure the function exists in one of these modules.")
|
33 |
|
34 |
KEEP_KEYS = {
|
35 |
# === global UI context we want to keep after "Reset" ===
|
@@ -129,7 +121,7 @@ def label_file(filename: str) -> int:
|
|
129 |
def load_state_dict(_mtime, model_path):
|
130 |
"""Load state dict with mtime in cache key to detect file changes"""
|
131 |
try:
|
132 |
-
return torch.load(model_path, map_location="cpu")
|
133 |
except (FileNotFoundError, RuntimeError) as e:
|
134 |
st.warning(f"Error loading state dict: {e}")
|
135 |
return None
|
@@ -235,11 +227,11 @@ def parse_spectrum_data(raw_text):
|
|
235 |
return x, y
|
236 |
|
237 |
|
238 |
-
def create_spectrum_plot(x_raw, y_raw, y_resampled):
|
239 |
"""Create spectrum visualization plot"""
|
240 |
fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
|
241 |
|
242 |
-
# Raw spectrum
|
243 |
ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
|
244 |
ax[0].set_title("Raw Input Spectrum")
|
245 |
ax[0].set_xlabel("Wavenumber (cm⁻¹)")
|
@@ -247,19 +239,16 @@ def create_spectrum_plot(x_raw, y_raw, y_resampled):
|
|
247 |
ax[0].grid(True, alpha=0.3)
|
248 |
ax[0].legend()
|
249 |
|
250 |
-
# Resampled spectrum
|
251 |
-
x_resampled =
|
252 |
-
ax[1].
|
253 |
-
color="steelblue", linewidth=1)
|
254 |
-
ax[1].set_title(f"Resampled ({TARGET_LEN} points)")
|
255 |
ax[1].set_xlabel("Wavenumber (cm⁻¹)")
|
256 |
ax[1].set_ylabel("Intensity")
|
257 |
ax[1].grid(True, alpha=0.3)
|
258 |
ax[1].legend()
|
259 |
|
260 |
plt.tight_layout()
|
261 |
-
|
262 |
-
# Convert to image
|
263 |
buf = io.BytesIO()
|
264 |
plt.savefig(buf, format='png', bbox_inches='tight', dpi=100)
|
265 |
buf.seek(0)
|
@@ -546,7 +535,30 @@ def main():
|
|
546 |
|
547 |
# Resample
|
548 |
with st.spinner("Resampling spectrum..."):
|
549 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
550 |
|
551 |
# Persist results (drives right column)
|
552 |
st.session_state["x_raw"] = x_raw
|
@@ -571,6 +583,7 @@ def main():
|
|
571 |
# Get data from session state
|
572 |
x_raw = st.session_state.get('x_raw')
|
573 |
y_raw = st.session_state.get('y_raw')
|
|
|
574 |
y_resampled = st.session_state.get('y_resampled')
|
575 |
filename = st.session_state.get('filename', 'Unknown')
|
576 |
|
@@ -578,8 +591,7 @@ def main():
|
|
578 |
|
579 |
# Create and display plot
|
580 |
try:
|
581 |
-
spectrum_plot = create_spectrum_plot(
|
582 |
-
x_raw, y_raw, y_resampled)
|
583 |
st.image(
|
584 |
spectrum_plot, caption="Spectrum Preprocessing Results", use_container_width=True)
|
585 |
except (ValueError, RuntimeError, TypeError) as e:
|
@@ -706,6 +718,33 @@ def main():
|
|
706 |
st.text_area("Logs", "\n".join(
|
707 |
st.session_state.get("log_messages", [])), height=200)
|
708 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
709 |
with tab3:
|
710 |
st.markdown("""
|
711 |
**🔍 Analysis Process**
|
|
|
1 |
from models.resnet_cnn import ResNet1D
|
2 |
from models.figure2_cnn import Figure2CNN
|
|
|
3 |
import hashlib
|
4 |
import gc
|
5 |
import time
|
|
|
21 |
matplotlib.use("Agg") # ensure headless rendering in Spaces
|
22 |
|
23 |
# Import local modules
|
24 |
+
from utils.preprocessing import resample_spectrum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
KEEP_KEYS = {
|
27 |
# === global UI context we want to keep after "Reset" ===
|
|
|
121 |
def load_state_dict(_mtime, model_path):
|
122 |
"""Load state dict with mtime in cache key to detect file changes"""
|
123 |
try:
|
124 |
+
return torch.load(model_path, map_location="cpu", weights_only=True)
|
125 |
except (FileNotFoundError, RuntimeError) as e:
|
126 |
st.warning(f"Error loading state dict: {e}")
|
127 |
return None
|
|
|
227 |
return x, y
|
228 |
|
229 |
|
230 |
+
def create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled):
|
231 |
"""Create spectrum visualization plot"""
|
232 |
fig, ax = plt.subplots(1, 2, figsize=(13, 5), dpi=100)
|
233 |
|
234 |
+
# == Raw spectrum ==
|
235 |
ax[0].plot(x_raw, y_raw, label="Raw", color="dimgray", linewidth=1)
|
236 |
ax[0].set_title("Raw Input Spectrum")
|
237 |
ax[0].set_xlabel("Wavenumber (cm⁻¹)")
|
|
|
239 |
ax[0].grid(True, alpha=0.3)
|
240 |
ax[0].legend()
|
241 |
|
242 |
+
# == Resampled spectrum ==
|
243 |
+
ax[1].plot(x_resampled, y_resampled, label="Resampled", color="steelblue", linewidth=1)
|
244 |
+
ax[1].set_title(f"Resampled ({len(y_resampled)} points)")
|
|
|
|
|
245 |
ax[1].set_xlabel("Wavenumber (cm⁻¹)")
|
246 |
ax[1].set_ylabel("Intensity")
|
247 |
ax[1].grid(True, alpha=0.3)
|
248 |
ax[1].legend()
|
249 |
|
250 |
plt.tight_layout()
|
251 |
+
# == Convert to image ==
|
|
|
252 |
buf = io.BytesIO()
|
253 |
plt.savefig(buf, format='png', bbox_inches='tight', dpi=100)
|
254 |
buf.seek(0)
|
|
|
535 |
|
536 |
# Resample
|
537 |
with st.spinner("Resampling spectrum..."):
|
538 |
+
# ===Resample Unpack===
|
539 |
+
r1, r2 = resample_spectrum(x_raw, y_raw, TARGET_LEN)
|
540 |
+
|
541 |
+
def _is_strictly_increasing(a):
|
542 |
+
try:
|
543 |
+
a = np.asarray(a)
|
544 |
+
return a.ndim == 1 and a.size >= 2 and np.all(np.diff(a) > 0)
|
545 |
+
except Exception:
|
546 |
+
return False
|
547 |
+
|
548 |
+
if _is_strictly_increasing(r1) and not _is_strictly_increasing(r2):
|
549 |
+
x_resampled, y_resampled = np.asarray(r1), np.asarray(r2)
|
550 |
+
elif _is_strictly_increasing(r2) and not _is_strictly_increasing(r1):
|
551 |
+
x_resampled, y_resampled = np.asarray(r2), np.asarray(r1)
|
552 |
+
else:
|
553 |
+
# == Ambigous; assume (x, y) and log
|
554 |
+
x_resampled, y_resampled = np.asarray(r1), np.asarray(r2)
|
555 |
+
log_message("Resample outputs ambigous; assumed (x, y).")
|
556 |
+
|
557 |
+
# ===Persists for plotting + inference===
|
558 |
+
st.session_state["x_raw"] = x_raw
|
559 |
+
st.session_state["y_raw"] = y_raw
|
560 |
+
st.session_state["x_resampled"] = x_resampled # ←-- NEW
|
561 |
+
st.session_state["y_resampled"] = y_resampled
|
562 |
|
563 |
# Persist results (drives right column)
|
564 |
st.session_state["x_raw"] = x_raw
|
|
|
583 |
# Get data from session state
|
584 |
x_raw = st.session_state.get('x_raw')
|
585 |
y_raw = st.session_state.get('y_raw')
|
586 |
+
x_resampled = st.session_state.get('x_resampled') # ← NEW
|
587 |
y_resampled = st.session_state.get('y_resampled')
|
588 |
filename = st.session_state.get('filename', 'Unknown')
|
589 |
|
|
|
591 |
|
592 |
# Create and display plot
|
593 |
try:
|
594 |
+
spectrum_plot = create_spectrum_plot(x_raw, y_raw, x_resampled, y_resampled)
|
|
|
595 |
st.image(
|
596 |
spectrum_plot, caption="Spectrum Preprocessing Results", use_container_width=True)
|
597 |
except (ValueError, RuntimeError, TypeError) as e:
|
|
|
718 |
st.text_area("Logs", "\n".join(
|
719 |
st.session_state.get("log_messages", [])), height=200)
|
720 |
|
721 |
+
try:
|
722 |
+
resampler_mod = getattr(resample_spectrum, "__module__", "unknown")
|
723 |
+
resampler_doc = getattr(resample_spectrum, "__doc__", None)
|
724 |
+
resampler_doc = resampler_doc.splitlines()[0] if isinstance(resampler_doc, str) and resampler_doc else "no doc"
|
725 |
+
|
726 |
+
y_rs = st.session_state.get("y_resampled", None)
|
727 |
+
diag = {}
|
728 |
+
if y_rs is not None:
|
729 |
+
arr = np.asarray(y_rs)
|
730 |
+
diag = {
|
731 |
+
"y_resampled_len": int(arr.size),
|
732 |
+
"y_resampled_min": float(np.min(arr)) if arr.size else None,
|
733 |
+
"y_resampled_max": float(np.max(arr)) if arr.size else None,
|
734 |
+
"y_resampled_ptp": float(np.ptp(arr)) if arr.size else None,
|
735 |
+
"y_resampled_unique": int(np.unique(arr).size) if arr.size else None,
|
736 |
+
"y_resampled_all_equal": bool(np.ptp(arr) == 0.0) if arr.size else None,
|
737 |
+
}
|
738 |
+
|
739 |
+
st.markdown("**Resampler Info")
|
740 |
+
st.json({
|
741 |
+
"module": resampler_mod,
|
742 |
+
"doc": resampler_doc,
|
743 |
+
**({"y_resampled_stats": diag} if diag else {})
|
744 |
+
})
|
745 |
+
except Exception as _e:
|
746 |
+
st.warning(f"Diagnostics skipped: {_e}")
|
747 |
+
|
748 |
with tab3:
|
749 |
st.markdown("""
|
750 |
**🔍 Analysis Process**
|