Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| """ | |
state.py — Streamlit session-state management and UI callbacks.
Covers:
- init_session_state()
- Speaker rename helpers (get_display_name, apply_speaker_renames_to_df)
- Category callbacks (addCategory, removeCategory, updateCategoryOptions)
- Global rename callbacks (addGlobalRename, removeGlobalRename, applyGlobalRenames)
- File-switch callback (updateMultiSelect)
- Speaker-clip helpers (store_speaker_clips, randomize_speaker_clip)
- File registration and loading (register_file, load_annotation_file,
  load_demo_single, load_demo_multi, run_analysis_loop)
- analyze() — builds and caches all DataFrames for a single file
- Display/export helpers: build_table_df(), convert_df(), printV()
| """ | |
| import copy | |
| import traceback | |
| import pandas as pd | |
| import streamlit as st | |
| import sonogram_utility as su | |
| import utils | |
| # --------------------------------------------------------------------------- | |
| # Logging | |
| # --------------------------------------------------------------------------- | |
verbosity = 4  # 0=None 1=Low 2=Medium 3=High 4=Debug


def printV(message, level):
    """Print *message* when the module-wide ``verbosity`` is at least *level*.

    Args:
        message: object handed to ``print``.
        level: minimum verbosity required for the message to appear.
    """
    should_emit = verbosity >= level
    if should_emit:
        print(message)
| # --------------------------------------------------------------------------- | |
| # Session state initialisation | |
| # --------------------------------------------------------------------------- | |
def init_session_state():
    """Populate ``st.session_state`` with every key the app relies on.

    Only keys that are not already present are written, so calling this on
    every rerun never clobbers existing state. The defaults mapping is rebuilt
    on each call, so mutable defaults are fresh objects rather than shared.
    """
    defaults = dict(
        results={},            # {filename: (annotations, totalSeconds)}
        speakerRenames={},     # {filename: {speaker: name}}
        summaries={},          # {filename: {df2, df3, ...}}
        categories=[],
        categorySelect={},     # {filename: [[], [], ...]}
        removeCategory=None,
        resetResult=False,
        unusedSpeakers={},     # {filename: [speaker, ...]}
        file_names=[],
        valid_files=[],
        file_paths={},         # {filename: path}
        showSummary="No",
        speakerClips={},       # {filename: {speaker: wav_bytes}}
        speakerSegments={},    # {filename: {speaker: [(start,end), ...]}}
        speakerWaveforms={},   # {filename: (waveform_tensor, sample_rate)}
        globalRenames=[],      # [{"name": str, "speakers": ["file: SPEAKER_##", ...]}]
        analyzeAllToggle=False,
    )
    missing = [key for key in defaults if key not in st.session_state]
    for key in missing:
        st.session_state[key] = defaults[key]
| # --------------------------------------------------------------------------- | |
| # Display-name helpers | |
| # --------------------------------------------------------------------------- | |
def get_display_name(speaker, fileName):
    """Map a raw speaker label to its display name for *fileName*.

    Falls back to the raw label itself when no rename is recorded for that
    file/speaker pair.
    """
    file_renames = st.session_state.speakerRenames.get(fileName, {})
    return file_renames.get(speaker, speaker)
def apply_speaker_renames_to_df(df, fileName, column="task"):
    """Return a copy of *df* whose *column* holds display names instead of raw labels.

    Each value is passed through ``get_display_name(value, fileName)``. When
    *column* is not present, *df* itself is returned unchanged (no copy).
    """
    if column in df.columns:
        renamed = df.copy()
        renamed[column] = renamed[column].apply(
            lambda label: get_display_name(label, fileName)
        )
        return renamed
    return df
def convert_df(df):
    """Serialise *df* to UTF-8 encoded CSV bytes, omitting the index column."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
| # --------------------------------------------------------------------------- | |
| # Category callbacks | |
| # --------------------------------------------------------------------------- | |
def addCategory():
    """Create a new speaker category from the ``categoryInput`` text box.

    The name is whitespace-trimmed. Blank names are ignored (matching
    ``addGlobalRename``), and names that already exist are rejected with a
    toast. Each category owns a ``multiselect_{name}`` session-state key, so a
    duplicate would share that key with the existing category, and removing
    both would later delete the same key twice.

    On success: seeds the category's multiselect with an empty selection,
    appends it to ``categories``, clears the input box and appends an empty
    selection slot to every file's ``categorySelect`` list.
    """
    new = st.session_state.categoryInput.strip()
    if not new:
        return
    if new in st.session_state.categories:
        st.toast(f"Category '{new}' already exists")
        return
    st.toast(f"Adding {new}")
    st.session_state[f"multiselect_{new}"] = []
    st.session_state.categories.append(new)
    st.session_state.categoryInput = ""
    # Keep every file's per-category selection list index-aligned with categories.
    for fname in st.session_state.categorySelect:
        st.session_state.categorySelect[fname].append([])
def removeCategory(index):
    """Delete the category at *index* together with its widget state.

    Args:
        index: position of the category in ``st.session_state.categories``.

    Removes the ``multiselect_{name}`` and ``remove_{name}`` session keys, the
    category itself, and the matching selection slot from every file's
    ``categorySelect`` list.
    """
    name = st.session_state.categories[index]
    st.toast(f"Removing {name}")
    # pop() rather than del: a missing widget key must not abort the callback
    # half-way and leave categories / categorySelect out of sync.
    st.session_state.pop(f"multiselect_{name}", None)
    st.session_state.pop(f"remove_{name}", None)
    del st.session_state.categories[index]
    for selections in st.session_state.categorySelect.values():
        # Guard against a file whose selection list is shorter than expected.
        if index < len(selections):
            del selections[index]
def updateCategoryOptions(fileName):
    """Copy the category multiselect widgets back into ``categorySelect[fileName]``.

    The widgets show display names (after renames), while ``categorySelect``
    stores raw speaker labels. Each chosen display name is therefore mapped
    back to *every* raw label that currently carries it: a global rename can
    give several raw speakers of one file the same display name, and all of
    them belong to the chosen category. Speakers not placed in any category
    are written to ``unusedSpeakers[fileName]`` in their original label order.

    Does nothing while ``resetResult`` is set (``updateMultiSelect`` sets it;
    ``analyze`` clears it). It also does nothing if the file has no loaded
    result yet.
    """
    if st.session_state.resetResult:
        return
    result = st.session_state.results.get(fileName)
    if not result:
        # register_file() seeds results with [] until the file is loaded.
        return
    currAnnotation, _ = result
    speakerNames = list(currAnnotation.labels())
    saved_renames = st.session_state.speakerRenames.get(fileName, {})

    # display name -> all raw labels shown under that name
    display_to_raws = {}
    for sp in speakerNames:
        display_to_raws.setdefault(saved_renames.get(sp, sp), []).append(sp)

    used = set()
    for i, category in enumerate(st.session_state.categories):
        display_choices = st.session_state.get(f"multiselect_{category}", [])
        raw_choices = []
        for display in display_choices:
            # Unknown display names fall back to themselves, as before.
            for raw in display_to_raws.get(display, [display]):
                if raw not in raw_choices:
                    raw_choices.append(raw)
        st.session_state.categorySelect[fileName][i] = raw_choices
        used.update(raw_choices)

    st.session_state.unusedSpeakers[fileName] = [
        sp for sp in speakerNames if sp not in used
    ]
| # --------------------------------------------------------------------------- | |
| # Global rename callbacks | |
| # --------------------------------------------------------------------------- | |
def _global_rename_key(index):
    """Return the session-state key of the speaker multiselect for global rename *index*."""
    return "grename_speakers_{}".format(index)
def applyGlobalRenames():
    """Rebuild every file's ``speakerRenames`` map from ``globalRenames``.

    Each known file's map is first reset to ``{}``, so any rename not backed by
    a global entry is discarded. Then each ``"file: SPEAKER_##"`` token of
    every entry maps that raw speaker to the entry's name. Tokens without the
    ``": "`` separator, or naming a file absent from ``speakerRenames``, are
    skipped. Finally, the ``rename_{file}_{speaker}`` widget values of the
    currently selected file are re-seeded so its text boxes match.
    """
    renames = st.session_state.speakerRenames
    for fname in renames:
        renames[fname] = {}

    for entry in st.session_state.globalRenames:
        target = entry["name"]
        for token in entry["speakers"]:
            file_part, sep, speaker_part = token.partition(": ")
            if sep and file_part in renames:
                renames[file_part][speaker_part] = target

    current = st.session_state.get("select_currFile")
    if not current or current not in renames:
        return
    loaded = st.session_state.results.get(current)
    if not loaded:
        return
    saved = renames[current]
    for sp in loaded[0].labels():
        st.session_state[f"rename_{current}_{sp}"] = saved.get(sp, "")
def addGlobalRename():
    """Append an empty global-rename entry named from ``globalRenameInput``.

    Whitespace-only input is ignored. Otherwise the entry's multiselect key is
    seeded with an empty selection and the input box is cleared.
    """
    typed = st.session_state.globalRenameInput.strip()
    if typed:
        st.toast(f"Adding rename '{typed}'")
        entries = st.session_state.globalRenames
        entries.append({"name": typed, "speakers": []})
        st.session_state[_global_rename_key(len(entries) - 1)] = []
        st.session_state.globalRenameInput = ""
def removeGlobalRename(index):
    """Delete the global rename at *index* and re-apply the remaining ones.

    Entries after *index* shift down one slot. Their ``grename_speakers_{i}``
    widget keys are therefore re-seeded from each entry's stored speaker list
    before ``applyGlobalRenames()`` rebuilds the per-file rename maps.
    """
    entries = st.session_state.globalRenames
    st.toast(f"Removing rename '{entries[index]['name']}'")
    del entries[index]
    for pos in range(index, len(entries)):
        st.session_state[_global_rename_key(pos)] = list(entries[pos]["speakers"])
    applyGlobalRenames()
| # --------------------------------------------------------------------------- | |
| # File-switch callback | |
| # --------------------------------------------------------------------------- | |
def updateMultiSelect():
    """Callback for the current-file selector: push that file's state into widgets.

    Sets ``resetResult``, which makes ``updateCategoryOptions`` a no-op. If the
    newly selected file has a result, it then seeds each speaker's
    ``rename_{file}_{speaker}`` text box with its saved rename (or ""). It also
    seeds each category multiselect with the display names of that file's
    stored raw selections.
    """
    fileName = st.session_state["select_currFile"]
    st.session_state.resetResult = True
    result = st.session_state.results.get(fileName)
    if not result:
        return
    annotation, _ = result
    renames = st.session_state.speakerRenames.get(fileName, {})

    shown_as = {}
    for speaker in list(annotation.labels()):
        saved = renames.get(speaker, "")
        st.session_state[f"rename_{fileName}_{speaker}"] = saved
        shown_as[speaker] = saved or speaker

    for idx, category in enumerate(st.session_state.categories):
        stored = st.session_state.categorySelect[fileName][idx]
        st.session_state[f"multiselect_{category}"] = [
            shown_as.get(sp, sp) for sp in stored
        ]
| # --------------------------------------------------------------------------- | |
| # Speaker-clip session-state helpers | |
| # --------------------------------------------------------------------------- | |
def store_speaker_clips(fname, annotations, waveform, sample_rate):
    """Build per-speaker preview clips for *fname* and cache them in session state.

    Stores the clips and segment lists returned by ``utils.build_speaker_clips``.
    It also stores the ``(waveform, sample_rate)`` pair, which
    ``randomize_speaker_clip`` later reads to cut new clips.
    """
    clips, segments = utils.build_speaker_clips(annotations, waveform, sample_rate)
    state = st.session_state
    state.speakerClips[fname] = clips
    state.speakerSegments[fname] = segments
    state.speakerWaveforms[fname] = (waveform, sample_rate)
    print(f"Generated {len(clips)} speaker clips for {fname}")
def randomize_speaker_clip(file_index, speaker):
    """Replace *speaker*'s cached clip with a newly randomised excerpt.

    *file_index* is the key of the ``speaker*`` session dicts (the file name
    that ``store_speaker_clips`` used). Does nothing when the speaker has no
    segments or no waveform is cached for the file.
    """
    state = st.session_state
    segments = state.speakerSegments.get(file_index, {}).get(speaker)
    cached = state.speakerWaveforms.get(file_index)
    if segments and cached is not None:
        waveform, sample_rate = cached
        fresh = utils.get_randomized_clip(waveform, sample_rate, segments)
        state.speakerClips[file_index][speaker] = fresh
        print(f"Randomized clip for {speaker} in {file_index}")
| # --------------------------------------------------------------------------- | |
| # Per-file registration helper (keeps Demo / upload code DRY) | |
| # --------------------------------------------------------------------------- | |
def register_file(fname):
    """Give *fname* a placeholder entry in every per-file session dict.

    Existing entries are left untouched. ``results`` is seeded with an empty
    list, which ``analyze()`` treats as "not loaded yet". ``categorySelect``
    gets one empty selection per existing category. *fname* is appended to
    ``file_names`` only if it is not already listed.
    """
    state = st.session_state
    state.results.setdefault(fname, [])
    state.summaries.setdefault(fname, {})
    state.unusedSpeakers.setdefault(fname, [])
    empty_selection = [[] for _ in state.categories]
    state.categorySelect.setdefault(fname, empty_selection)
    state.speakerRenames.setdefault(fname, {})
    state.speakerClips.setdefault(fname, {})
    if fname not in state.file_names:
        state.file_names.append(fname)
| # --------------------------------------------------------------------------- | |
| # File loading helpers | |
| # --------------------------------------------------------------------------- | |
def load_annotation_file(fname, fpath):
    """Parse an annotation-only file and store the result under *fname*.

    The ``su`` loader is picked from *fpath*'s extension (case-insensitive):
    ``.txt`` -> loadAudioTXT, ``.rttm`` -> loadAudioRTTM,
    ``.csv`` -> loadAudioCSV. Only the second of the two values the loader
    returns is kept, as the annotation.

    Writes ``results[fname] = (annotations, totalSeconds)``, resets
    ``summaries[fname]`` to ``{}`` and lists every label in
    ``unusedSpeakers[fname]``.

    Returns:
        ``(annotations, totalSeconds)``. totalSeconds is the latest segment
        end, or 0 when there are no segments.

    Raises:
        ValueError: if the extension is not one of the three above.
    """
    lowered = fpath.lower()
    loader_name = None
    for suffix, name in (
        (".txt", "loadAudioTXT"),
        (".rttm", "loadAudioRTTM"),
        (".csv", "loadAudioCSV"),
    ):
        if lowered.endswith(suffix):
            loader_name = name
            break
    if loader_name is None:
        raise ValueError(f"Unsupported annotation format: {fpath}")

    _, annotations = getattr(su, loader_name)(fpath)
    totalSeconds = max((seg.end for seg in annotations.itersegments()), default=0)

    state = st.session_state
    state.results[fname] = (annotations, totalSeconds)
    state.summaries[fname] = {}
    state.unusedSpeakers[fname] = list(annotations.labels())
    return annotations, totalSeconds
def load_demo_single(demo_path):
    """Load one demo annotation file, analyse it and make it the current file.

    The session-state key is the last "/"-separated component of *demo_path*,
    and that key is returned. Shows the elapsed time with ``st.success``.
    """
    import time

    dname = demo_path.rsplit("/", 1)[-1]
    register_file(dname)
    st.session_state.file_paths[dname] = demo_path

    began = time.time()
    with st.spinner("Loading Demo Sample"):
        load_annotation_file(dname, demo_path)
    with st.spinner("Analyzing Demo Data"):
        analyze(dname)
    elapsed = time.time() - began
    st.success(f"Took {elapsed:.1f}s to analyze the demo file!")

    st.session_state.select_currFile = dname
    return dname
def load_demo_multi(demo_paths):
    """Register and parse several demo annotation files without analysing them.

    Each file is keyed by the last "/"-separated component of its path. At the
    end ``analyzeAllToggle`` is set. Presumably the batch-analysis UI acts on
    this flag; ``run_analysis_loop`` clears it.
    """
    for path in demo_paths:
        name = path.rsplit("/", 1)[-1]
        register_file(name)
        st.session_state.file_paths[name] = path
        with st.spinner(f"Loading: {name}"):
            load_annotation_file(name, path)
    st.session_state.analyzeAllToggle = True
def run_analysis_loop(file_names, file_paths_dict, pipeline,
                      enable_denoise, early_cleanup,
                      gain_window, minimum_gain, maximum_gain,
                      df_model, df_state, atten_lim_db):
    """Load every file in *file_names*, then run ``analyze()`` on each one.

    For each name, the path comes from *file_paths_dict* (missing names fall
    back to ""). Paths ending in .txt/.rttm/.csv (case-insensitive) are parsed
    with ``load_annotation_file``. Everything else is treated as audio: it goes
    through ``utils.processFile`` with the remaining arguments, and the
    resulting annotation is stored in ``results`` / ``summaries`` /
    ``unusedSpeakers``. Per-speaker preview clips are then generated.

    ``analyzeAllToggle`` is cleared in a ``finally`` block, so it is reset even
    when a file raises, instead of staying set after a failed batch.
    Exceptions still propagate to the caller.
    """
    import time

    start_time = time.time()
    totalFiles = len(file_names)
    try:
        for i, fname in enumerate(file_names):
            progress = f"{i + 1}/{totalFiles}"
            fpath = file_paths_dict.get(fname, "")
            ext = fpath.lower()
            if ext.endswith((".txt", ".rttm", ".csv")):
                label = ext.rsplit(".", 1)[-1].upper()
                with st.spinner(f"Loading {label} {progress}"):
                    load_annotation_file(fname, fpath)
            else:
                with st.spinner(f"Processing Audio {progress}"):
                    annotations, totalSeconds, waveform, sample_rate = utils.processFile(
                        fpath, pipeline, enable_denoise, early_cleanup,
                        gain_window, minimum_gain, maximum_gain,
                        df_model, df_state, atten_lim_db,
                    )
                st.session_state.results[fname] = (annotations, totalSeconds)
                st.session_state.summaries[fname] = {}
                st.session_state.unusedSpeakers[fname] = list(annotations.labels())
                # No explicit `del waveform`: store_speaker_clips keeps the
                # waveform in speakerWaveforms (for clip re-randomisation), so
                # dropping the local name would not release the audio.
                with st.spinner(f"Generating clips {progress}"):
                    store_speaker_clips(fname, annotations, waveform, sample_rate)
            with st.spinner(f"Analyzing {progress}"):
                analyze(fname)
        st.success(f"Analyzed {totalFiles} file(s) in {time.time() - start_time:.1f}s")
    finally:
        st.session_state.analyzeAllToggle = False
def build_table_df(displayDF):
    """Return a display-only copy of *displayDF* with cosmetic transforms applied.

    - Rename 'Resource' -> 'Speaker'.
    - Drop the 'Task' column if present.
    - Format 'Start' / 'Finish' values (seconds) as ``HH:MM:SS.cs`` strings.
      Values that cannot be read as a finite number are shown via ``str()``.
      Negative times get a leading ``-``.

    The input frame is not modified.
    """
    import math

    def _fmt(val):
        try:
            secs = float(val)
        except (TypeError, ValueError):
            return str(val)
        if not math.isfinite(secs):
            # NaN / inf would crash int(); show them as-is.
            return str(val)
        # Round once to whole centiseconds so the carry propagates: 59.996 s
        # becomes 00:01:00.00, not 00:00:59.100.
        total_cs = round(secs * 100)
        sign = "-" if total_cs < 0 else ""
        total_s, cs = divmod(abs(total_cs), 100)
        total_m, s = divmod(total_s, 60)
        h, m = divmod(total_m, 60)
        return f"{sign}{h:02d}:{m:02d}:{s:02d}.{cs:02d}"

    df = displayDF.copy()
    if "Task" in df.columns:
        df = df.drop(columns=["Task"])
    for col in ("Start", "Finish"):
        if col in df.columns:
            df[col] = df[col].apply(_fmt)
    return df.rename(columns={"Resource": "Speaker"})
| # --------------------------------------------------------------------------- | |
| # analyze() — build and cache all DataFrames for one file | |
| # --------------------------------------------------------------------------- | |
def analyze(inFileName):
    """Compute all summary DataFrames for *inFileName* and cache them in session state.

    Reads ``results[inFileName]`` (annotation, total time) and the file's
    ``categorySelect`` entry. Builds ``df2``-``df5`` plus
    ``speakers_dataFrame`` / ``speakers_times`` and merges them into
    ``summaries[inFileName]``.

    ``resetResult`` is always cleared first. Returns early without computing
    anything when the file has no loaded result or no summaries slot.

    The new frames are committed only after all of them have been built, so a
    failure part-way leaves the previous summaries intact rather than a mix of
    fresh and stale DataFrames. Any exception is printed with its traceback
    and reported via ``st.error`` instead of being propagated.
    """
    try:
        printV(f"Start analyzing {inFileName}", 4)
        st.session_state.resetResult = False
        if not (
            inFileName in st.session_state.results
            and inFileName in st.session_state.summaries
            and len(st.session_state.results[inFileName]) > 0
        ):
            # e.g. registered but not loaded yet (register_file seeds []).
            return
        currAnnotation, currTotalTime = st.session_state.results[inFileName]
        speakerNames = currAnnotation.labels()
        categorySelections = st.session_state.categorySelect[inFileName]
        printV("Loaded results", 4)

        # Build everything locally; commit to session state only on success.
        computed = {}

        noVoice, oneVoice, multiVoice = su.calcSpeakingTypes(currAnnotation, currTotalTime)
        sumNoVoice = su.sumTimes(noVoice)
        sumOneVoice = su.sumTimes(oneVoice)
        sumMultiVoice = su.sumTimes(multiVoice)

        # df3
        computed["df3"] = utils.build_df3(noVoice, oneVoice, multiVoice)
        printV("Set df3", 4)

        # df4 (its name/value lists are reused for df2 below)
        df4, nameList, valueList, extraNames, extraValues = utils.build_df4(
            speakerNames, categorySelections, st.session_state.categories, currAnnotation
        )
        computed["df4"] = df4
        printV("Set df4", 4)

        # df5
        computed["df5"] = utils.build_df5(
            oneVoice, multiVoice,
            sumNoVoice, sumOneVoice, sumMultiVoice,
            currTotalTime,
        )
        printV("Set df5", 4)

        # speakers_dataFrame, df2
        speakers_dataFrame, speakers_times = su.annotationToDataFrame(currAnnotation)
        computed["speakers_dataFrame"] = speakers_dataFrame
        computed["speakers_times"] = speakers_times
        computed["df2"] = utils.build_df2(
            nameList + extraNames,
            valueList + extraValues,
            currTotalTime,
        )
        printV("Set df2", 4)

        # update() keeps any other keys other code may have stored here.
        st.session_state.summaries[inFileName].update(computed)
    except Exception as e:
        print(f"Error in analyze: {e}")
        traceback.print_exc()
        st.error(f"Debug - analyze() failed: {e}")