duongthienz's picture
Update state.py
1aaacea verified
"""
state.py — Streamlit session-state management and UI callbacks.
Covers:
- init_session_state()
- Speaker rename helpers (get_display_name, apply_speaker_renames_to_df)
- Category callbacks (addCategory, removeCategory, updateCategoryOptions)
- Global rename callbacks (addGlobalRename, removeGlobalRename, applyGlobalRenames)
- File-switch callback (updateMultiSelect)
- analyze() — builds and caches all DataFrames for a single file
- convert_df(), printV()
"""
import copy
import traceback
import pandas as pd
import streamlit as st
import sonogram_utility as su
import utils
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
verbosity = 4  # 0=None 1=Low 2=Medium 3=High 4=Debug


def printV(message, level):
    """Print ``message`` only when ``verbosity`` is at least ``level``."""
    if verbosity < level:
        return
    print(message)
# ---------------------------------------------------------------------------
# Session state initialisation
# ---------------------------------------------------------------------------
def init_session_state():
    """Give every session-state key the app uses a default value.

    Keys already present in ``st.session_state`` are left untouched, so the
    function can be called on every rerun.
    """
    defaults = {
        "results": {},            # {filename: (annotations, totalSeconds)}
        "speakerRenames": {},     # {filename: {speaker: name}}
        "summaries": {},          # {filename: {df2, df3, ...}}
        "categories": [],
        "categorySelect": {},     # {filename: [[], [], ...]}
        "removeCategory": None,
        "resetResult": False,
        "unusedSpeakers": {},     # {filename: [speaker, ...]}
        "file_names": [],
        "valid_files": [],
        "file_paths": {},         # {filename: path}
        "showSummary": "No",
        "speakerClips": {},       # {filename: {speaker: wav_bytes}}
        "speakerSegments": {},    # {filename: {speaker: [(start,end), ...]}}
        "speakerWaveforms": {},   # {filename: (waveform_tensor, sample_rate)}
        "globalRenames": [],      # [{"name": str, "speakers": ["file: SPEAKER_##", ...]}]
        "analyzeAllToggle": False,
    }
    # Work out which keys are absent first, then fill only those in.
    missing = {
        key: value
        for key, value in defaults.items()
        if key not in st.session_state
    }
    for key, value in missing.items():
        st.session_state[key] = value
# ---------------------------------------------------------------------------
# Display-name helpers
# ---------------------------------------------------------------------------
def get_display_name(speaker, fileName):
    """Return the display name for ``speaker`` in ``fileName``.

    Looks up ``st.session_state.speakerRenames[fileName][speaker]``. The
    original speaker label is returned when the file or speaker has no entry,
    and also when the stored rename is empty: ``updateMultiSelect`` applies
    the same "empty means not renamed" rule when it builds its display
    mapping, so both places now agree.
    """
    renamed = st.session_state.speakerRenames.get(fileName, {}).get(speaker)
    return renamed if renamed else speaker
def apply_speaker_renames_to_df(df, fileName, column="task"):
    """Return ``df`` with the speaker labels in ``column`` swapped for display names.

    When ``column`` is not present the input frame itself is returned.
    Otherwise the frame is copied and only the copy is modified.
    """
    if column in df.columns:
        renamed = df.copy()
        renamed[column] = renamed[column].apply(
            lambda label: get_display_name(label, fileName)
        )
        return renamed
    return df
@st.cache_data
def convert_df(df):
    """Serialise ``df`` to CSV without the index and return it as UTF-8 bytes."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")
# ---------------------------------------------------------------------------
# Category callbacks
# ---------------------------------------------------------------------------
def addCategory():
    """Callback: add the category typed into the ``categoryInput`` widget.

    The name is stripped of surrounding whitespace. Blank names and names
    that already exist are rejected. Category names are used to build the
    session-state keys ``multiselect_{name}`` and ``remove_{name}``, and
    ``removeCategory`` deletes those keys by name, so two categories with the
    same name would share one piece of state and then lose it.

    On success the category is appended to ``st.session_state.categories``,
    its multiselect state is initialised to an empty list, and every file in
    ``categorySelect`` gains an empty selection slot for it. The input box is
    cleared in every case.
    """
    new = st.session_state.categoryInput.strip()
    # Clear the text box whether or not the name is accepted.
    st.session_state.categoryInput = ""
    if not new:
        return
    if new in st.session_state.categories:
        st.toast(f"Category '{new}' already exists")
        return
    st.toast(f"Adding {new}")
    st.session_state[f"multiselect_{new}"] = []
    st.session_state.categories.append(new)
    # Keep each file's selection list aligned with the category list.
    for selections in st.session_state.categorySelect.values():
        selections.append([])
def removeCategory(index):
    """Callback: delete the category at position ``index``.

    Removes the category's ``multiselect_`` and ``remove_`` session-state
    keys, drops it from ``st.session_state.categories``, and deletes the
    matching selection slot from every file in ``categorySelect``.
    """
    categories = st.session_state.categories
    name = categories[index]
    st.toast(f"Removing {name}")
    for prefix in ("multiselect_", "remove_"):
        del st.session_state[f"{prefix}{name}"]
    del categories[index]
    for selections in st.session_state.categorySelect.values():
        del selections[index]
def updateCategoryOptions(fileName):
    """Callback: copy the category multiselect widgets into per-file state.

    Does nothing while ``resetResult`` is set. Otherwise each category's
    ``multiselect_{category}`` value (display names) is translated back to
    raw speaker labels and stored in ``categorySelect[fileName]``. Labels not
    picked in any category are stored in ``unusedSpeakers[fileName]``.
    """
    if st.session_state.resetResult:
        return

    annotation, _ = st.session_state.results[fileName]
    labels = list(annotation.labels())
    renames = st.session_state.speakerRenames.get(fileName, {})

    # Display name -> raw label. Unrenamed speakers map to themselves.
    raw_by_display = {}
    for label in labels:
        raw_by_display[renames.get(label, label)] = label

    remaining = copy.deepcopy(labels)
    for slot, category in enumerate(st.session_state.categories):
        picked = [
            raw_by_display.get(shown, shown)
            for shown in st.session_state[f"multiselect_{category}"]
        ]
        st.session_state.categorySelect[fileName][slot] = picked
        for label in picked:
            # A label may already be gone (picked twice) or never present.
            if label in remaining:
                remaining.remove(label)

    st.session_state.unusedSpeakers[fileName] = remaining
# ---------------------------------------------------------------------------
# Global rename callbacks
# ---------------------------------------------------------------------------
def _global_rename_key(index):
    """Return the session-state key for the global-rename entry at ``index``."""
    return "grename_speakers_{}".format(index)
def applyGlobalRenames():
    """Rebuild ``speakerRenames`` from ``globalRenames`` and refresh rename widgets.

    Every file's rename map is reset to empty, then each ``"file: speaker"``
    token of each global-rename entry assigns that entry's name to the
    speaker. Tokens without ``": "`` and tokens for files that have no
    ``speakerRenames`` entry are skipped. Finally, if the currently selected
    file has results, its ``rename_{file}_{speaker}`` keys are updated to the
    new values ("" when a speaker has no rename).
    """
    renames = st.session_state.speakerRenames
    for fname in renames:
        renames[fname] = {}

    for entry in st.session_state.globalRenames:
        display_name = entry["name"]
        for token in entry["speakers"]:
            # Split on the first ": " only; the rest is the speaker label.
            fname, separator, raw_sp = token.partition(": ")
            if not separator or fname not in renames:
                continue
            renames[fname][raw_sp] = display_name

    curr = st.session_state.get("select_currFile")
    if not curr or curr not in renames:
        return
    results = st.session_state.results.get(curr)
    if not results:
        return
    saved = renames[curr]
    for sp in results[0].labels():
        st.session_state[f"rename_{curr}_{sp}"] = saved.get(sp, "")
def addGlobalRename():
    """Callback: add a global rename named after the ``globalRenameInput`` text.

    Blank input (after stripping) is ignored. Otherwise a new entry with no
    speakers is appended to ``globalRenames``, its speaker-selection key is
    initialised to an empty list, and the input box is cleared.
    """
    new_name = st.session_state.globalRenameInput.strip()
    if new_name:
        st.toast(f"Adding rename '{new_name}'")
        entries = st.session_state.globalRenames
        entries.append({"name": new_name, "speakers": []})
        newest_index = len(entries) - 1
        st.session_state[_global_rename_key(newest_index)] = []
        st.session_state.globalRenameInput = ""
def removeGlobalRename(index):
    """Callback: delete the global rename at ``index`` and re-apply the rest.

    Entries after the removed one shift down by one position, so their
    index-based speaker-selection keys are rewritten from the stored
    ``speakers`` lists before ``applyGlobalRenames`` runs.
    """
    entries = st.session_state.globalRenames
    removed_name = entries[index]["name"]
    st.toast(f"Removing rename '{removed_name}'")
    del entries[index]
    for position in range(index, len(entries)):
        shifted_speakers = entries[position]["speakers"]
        st.session_state[_global_rename_key(position)] = list(shifted_speakers)
    applyGlobalRenames()
# ---------------------------------------------------------------------------
# File-switch callback
# ---------------------------------------------------------------------------
def updateMultiSelect():
    """Callback: load the newly selected file's saved state into the widgets.

    Sets ``resetResult`` to True, then, if the selected file has results,
    restores each speaker's ``rename_{file}_{speaker}`` key and each
    category's ``multiselect_{category}`` key. Multiselect values are shown
    as display names: the saved rename when non-empty, otherwise the raw
    speaker label.
    """
    fileName = st.session_state["select_currFile"]
    st.session_state.resetResult = True

    result = st.session_state.results.get(fileName)
    if result:
        annotation, _ = result
        renames = st.session_state.speakerRenames.get(fileName, {})

        display_by_raw = {}
        for label in list(annotation.labels()):
            saved = renames.get(label, "")
            st.session_state[f"rename_{fileName}_{label}"] = saved
            display_by_raw[label] = saved or label

        for slot, category in enumerate(st.session_state.categories):
            stored_raw = st.session_state.categorySelect[fileName][slot]
            shown = [display_by_raw.get(label, label) for label in stored_raw]
            st.session_state[f"multiselect_{category}"] = shown
# ---------------------------------------------------------------------------
# Speaker-clip session-state helpers
# ---------------------------------------------------------------------------
def store_speaker_clips(fname, annotations, waveform, sample_rate):
    """Build speaker clips and segments for ``fname`` and store them in session state.

    The results of ``utils.build_speaker_clips`` go into ``speakerClips`` and
    ``speakerSegments``. The ``(waveform, sample_rate)`` pair is kept in
    ``speakerWaveforms``.
    """
    built = utils.build_speaker_clips(annotations, waveform, sample_rate)
    clips, segments = built
    state = st.session_state
    state.speakerClips[fname] = clips
    state.speakerSegments[fname] = segments
    state.speakerWaveforms[fname] = (waveform, sample_rate)
    print(f"Generated {len(clips)} speaker clips for {fname}")
def randomize_speaker_clip(file_index, speaker):
    """Swap a speaker's stored clip for a new one from ``utils.get_randomized_clip``.

    Does nothing when the speaker has no stored segments or the file has no
    stored waveform.
    """
    state = st.session_state
    segs = state.speakerSegments.get(file_index, {}).get(speaker)
    waveform_data = state.speakerWaveforms.get(file_index)
    if segs and waveform_data is not None:
        waveform, sample_rate = waveform_data
        state.speakerClips[file_index][speaker] = utils.get_randomized_clip(
            waveform, sample_rate, segs
        )
        print(f"Randomized clip for {speaker} in {file_index}")
# ---------------------------------------------------------------------------
# Per-file registration helper (keeps Demo / upload code DRY)
# ---------------------------------------------------------------------------
def register_file(fname):
    """Create any missing per-file session-state entries for ``fname``.

    Existing entries are kept. ``categorySelect`` defaults to one empty list
    per current category. ``fname`` is appended to ``file_names`` if it is
    not already listed.
    """
    state = st.session_state
    # (session-state attribute, factory for its per-file default), in order.
    per_file_defaults = (
        ("results", list),
        ("summaries", dict),
        ("unusedSpeakers", list),
        ("categorySelect", lambda: [[] for _ in state.categories]),
        ("speakerRenames", dict),
        ("speakerClips", dict),
    )
    for attr, make_default in per_file_defaults:
        getattr(state, attr).setdefault(fname, make_default())

    if fname not in state.file_names:
        state.file_names.append(fname)
# ---------------------------------------------------------------------------
# File loading helpers
# ---------------------------------------------------------------------------
def load_annotation_file(fname, fpath):
    """Load an annotation-only file (.txt / .rttm / .csv) into session state.

    The loader is chosen from the case-insensitive file suffix. The total
    duration is the largest segment end time (0 when there are no segments).
    Stores ``(annotations, totalSeconds)`` in ``results[fname]``, resets
    ``summaries[fname]`` and sets ``unusedSpeakers[fname]`` to all labels.

    Returns:
        ``(annotations, totalSeconds)``.

    Raises:
        ValueError: if the suffix is not one of the supported formats.
    """
    lowered = fpath.lower()
    if lowered.endswith(".txt"):
        loader = su.loadAudioTXT
    elif lowered.endswith(".rttm"):
        loader = su.loadAudioRTTM
    elif lowered.endswith(".csv"):
        loader = su.loadAudioCSV
    else:
        raise ValueError(f"Unsupported annotation format: {fpath}")
    _, annotations = loader(fpath)

    segment_ends = (segment.end for segment in annotations.itersegments())
    totalSeconds = max(segment_ends, default=0)

    state = st.session_state
    state.results[fname] = (annotations, totalSeconds)
    state.summaries[fname] = {}
    state.unusedSpeakers[fname] = list(annotations.labels())
    return annotations, totalSeconds
def load_demo_single(demo_path):
    """Register, load and analyze one demo annotation file.

    The file name is the part of ``demo_path`` after the last "/". After
    analysis the file becomes the selected one (``select_currFile``).

    Returns:
        The demo file name.
    """
    import time

    dname = demo_path.rsplit("/", 1)[-1]
    register_file(dname)
    st.session_state.file_paths[dname] = demo_path

    started = time.time()
    with st.spinner("Loading Demo Sample"):
        load_annotation_file(dname, demo_path)
    with st.spinner("Analyzing Demo Data"):
        analyze(dname)
    elapsed = time.time() - started
    st.success(f"Took {elapsed:.1f}s to analyze the demo file!")

    st.session_state.select_currFile = dname
    return dname
def load_demo_multi(demo_paths):
    """Register and load several demo annotation files, then request analysis.

    Each file is registered and loaded under the name after the last "/" of
    its path. ``analyzeAllToggle`` is set to True once all files are loaded.
    """
    for path in demo_paths:
        name = path.rsplit("/", 1)[-1]
        register_file(name)
        st.session_state.file_paths[name] = path
        with st.spinner(f"Loading: {name}"):
            load_annotation_file(name, path)
    st.session_state.analyzeAllToggle = True
def run_analysis_loop(file_names, file_paths_dict, pipeline,
                      enable_denoise, early_cleanup,
                      gain_window, minimum_gain, maximum_gain,
                      df_model, df_state, atten_lim_db):
    """Load or process every file in ``file_names``, then analyze each one.

    Files whose path ends in .txt / .rttm / .csv are loaded through
    ``load_annotation_file``. Every other file goes to ``utils.processFile``
    together with ``pipeline`` and the remaining parameters, which are passed
    through unchanged. For those files the results are stored and speaker
    clips are generated. ``analyze`` runs for every file. When the loop
    finishes the elapsed time is reported and ``analyzeAllToggle`` is cleared.
    """
    import time
    import utils as _utils

    annotation_suffixes = (".txt", ".rttm", ".csv")
    started = time.time()
    totalFiles = len(file_names)

    for position, fname in enumerate(file_names, start=1):
        progress = f"{position}/{totalFiles}"
        fpath = file_paths_dict.get(fname, "")
        lowered = fpath.lower()

        if lowered.endswith(annotation_suffixes):
            label = lowered.rsplit(".", 1)[-1].upper()
            with st.spinner(f"Loading {label} {progress}"):
                load_annotation_file(fname, fpath)
        else:
            with st.spinner(f"Processing Audio {progress}"):
                processed = _utils.processFile(
                    fpath, pipeline, enable_denoise, early_cleanup,
                    gain_window, minimum_gain, maximum_gain,
                    df_model, df_state, atten_lim_db,
                )
                annotations, totalSeconds, waveform, sample_rate = processed
                st.session_state.results[fname] = (annotations, totalSeconds)
                st.session_state.summaries[fname] = {}
                st.session_state.unusedSpeakers[fname] = list(annotations.labels())
            with st.spinner(f"Generating clips {progress}"):
                store_speaker_clips(fname, annotations, waveform, sample_rate)
            # Drop the local name only; store_speaker_clips kept the waveform
            # in session state.
            del waveform

        with st.spinner(f"Analyzing {progress}"):
            analyze(fname)

    elapsed = time.time() - started
    st.success(f"Analyzed {totalFiles} file(s) in {elapsed:.1f}s")
    st.session_state.analyzeAllToggle = False
def build_table_df(displayDF):
    """Return a display-only copy of ``displayDF`` with cosmetic transforms applied.

    - Drop the 'Task' column if present.
    - Format 'Start' / 'Finish' (seconds) as ``HH:MM:SS.cs`` strings.
    - Rename 'Resource' -> 'Speaker'.

    The input frame is not modified.
    """
    import math

    def _fmt(val):
        """Format seconds as HH:MM:SS.cs; return ``str(val)`` for anything else."""
        try:
            secs = float(val)
        except (TypeError, ValueError):
            return str(val)
        if not math.isfinite(secs):
            # NaN / inf cannot be converted to int below; show them verbatim.
            return str(val)
        negative = secs < 0
        secs = abs(secs)
        whole = int(secs // 1)
        cs = round((secs % 1) * 100)
        if cs == 100:
            # e.g. 1.999 s: carry into the seconds instead of printing ".100".
            whole += 1
            cs = 0
        hours, remainder = divmod(whole, 3600)
        minutes, seconds = divmod(remainder, 60)
        # Only show a sign when the rounded value is actually non-zero.
        sign = "-" if negative and (whole or cs) else ""
        return f"{sign}{hours:02d}:{minutes:02d}:{seconds:02d}.{cs:02d}"

    df = displayDF.copy()
    if "Task" in df.columns:
        df = df.drop(columns=["Task"])
    for time_column in ("Start", "Finish"):
        if time_column in df.columns:
            df[time_column] = df[time_column].apply(_fmt)
    return df.rename(columns={"Resource": "Speaker"})
# ---------------------------------------------------------------------------
# analyze() — build and cache all DataFrames for one file
# ---------------------------------------------------------------------------
def analyze(inFileName):
    """Compute the summary tables for ``inFileName`` and cache them in session state.

    Clears ``resetResult``, then returns early unless the file has an entry
    in both ``results`` and ``summaries`` and its results are non-empty.
    Otherwise it fills ``summaries[inFileName]`` with the keys ``df3``,
    ``df4``, ``df5``, ``speakers_dataFrame``, ``speakers_times`` and ``df2``.

    Any exception is printed with its traceback and shown through
    ``st.error``. It is not re-raised.
    """
    try:
        printV(f"Start analyzing {inFileName}", 4)
        st.session_state.resetResult = False

        # Bail out unless the file is registered and has loaded results.
        if inFileName not in st.session_state.results:
            return
        if inFileName not in st.session_state.summaries:
            return
        if len(st.session_state.results[inFileName]) == 0:
            return

        annotation, total_time = st.session_state.results[inFileName]
        labels = annotation.labels()
        selections = st.session_state.categorySelect[inFileName]
        printV("Loaded results", 4)

        summary = st.session_state.summaries[inFileName]

        no_voice, one_voice, multi_voice = su.calcSpeakingTypes(annotation, total_time)
        no_voice_sum = su.sumTimes(no_voice)
        one_voice_sum = su.sumTimes(one_voice)
        multi_voice_sum = su.sumTimes(multi_voice)

        # df3
        summary["df3"] = utils.build_df3(no_voice, one_voice, multi_voice)
        printV("Set df3", 4)

        # df4
        df4, names, values, extra_names, extra_values = utils.build_df4(
            labels, selections, st.session_state.categories, annotation
        )
        summary["df4"] = df4
        printV("Set df4", 4)

        # df5
        summary["df5"] = utils.build_df5(
            one_voice, multi_voice,
            no_voice_sum, one_voice_sum, multi_voice_sum,
            total_time,
        )
        printV("Set df5", 4)

        # speakers_dataFrame, speakers_times, df2
        frame, times = su.annotationToDataFrame(annotation)
        summary["speakers_dataFrame"] = frame
        summary["speakers_times"] = times
        summary["df2"] = utils.build_df2(
            names + extra_names,
            values + extra_values,
            total_time,
        )
        printV("Set df2", 4)
    except Exception as e:
        print(f"Error in analyze: {e}")
        traceback.print_exc()
        st.error(f"Debug - analyze() failed: {e}")