import streamlit as st
from streamlit_extras.stylable_container import stylable_container
import os
import time
import pathlib
from datetime import timedelta
import requests
# Disable Streamlit's file watcher before importing torch/whisper
# (the watcher can crash when it walks torch's C-extension modules)
os.environ['STREAMLIT_SERVER_ENABLE_FILE_WATCHER'] = 'false'
import whisper  # openai-whisper
import torch  # check for GPU availability
# from models.loader import load_model_sst
from transcriber import Transcription
import matplotlib.colors as mcolors
######
# import gdown
# import tempfile
from utils import load_config, get_secret_api
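# get_secret_api() is assumed to return the base URL of the remote
# transcription service; it is used below as f'{secret_api}/transcribe'.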
st.session_state.secret_api = get_secret_api()
# Previous approach: download the URL from Google Drive into a temp file
# with tempfile.NamedTemporaryFile(delete=False) as tmp:
#     gdown.download(id=load_config()['links']['secret_api_id'], output=tmp.name, quiet=True)
#     tmp.seek(0)
#     st.session_state.secret_api = tmp.read().decode('utf-8')
#     tmp_path = tmp.name
# os.remove(tmp_path)
######
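# Known Whisper hallucination watermark (Russian for "Subtitles created by
# DimaTorzok") that gets stripped from the transcript segments below.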
trash_str = 'Субтитры создавал DimaTorzok'
st.title('🎙️ Step 2: Speech-to-Text (ASR/STT)')
# Check if audio path exists from previous step
if 'audio_path' not in st.session_state or not st.session_state['audio_path'] or not os.path.exists(st.session_state['audio_path']):
    st.warning('Audio file not found. Please go back to the "**📤 Upload**" page and process a video first.')
    st.stop()
# st.write(f'Audio file to process: `{os.path.basename(audio_path)}`')
st.write(f'Processing audio extracted from video `{st.session_state.video_input_title}`')
if 'start_time' not in st.session_state:
    st.session_state.start_time = 0
# st.audio(audio_path)
#     format='audio/wav',
st.audio(st.session_state.audio_path, start_time=st.session_state.start_time)
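# start_time (in seconds) lets the audio/video players jump to a timecode;
# it is updated when a segment's timecode button is clicked further down.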
#
# ==================================================================
#
col_model, col_config = st.columns(2)
# --- Model ---
# with col_model.expander('**MODEL**', expanded=True):
with col_model.container(border=True):
    model_option = st.selectbox(
        'STT Model:',
        ['whisper', 'faster-whisper', 'distil-whisper', 'giga'],
        index=0
    )
    # sst_model = load_model_sst(model_option)
with col_config.expander('**CONFIG**', expanded=True): | |
# Determine device | |
default_device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
device = st.radio( | |
'Compute device:', | |
('cuda', 'cpu'), | |
index=0 if default_device == 'cuda' else 1, | |
horizontal=True, | |
disabled=not torch.cuda.is_available() | |
) | |
if device == 'cuda' and not torch.cuda.is_available(): | |
st.warning('CUDA selected but not available, falling back to CPU') | |
device = 'cpu' | |
whisper_model_option = st.selectbox( | |
'Whisper model type:', | |
['tiny', 'base', 'small', 'medium', 'large-v3', 'turbo'], | |
index=5 | |
) | |
pauses = st.checkbox('pauses', value=False) | |
# from models.models_sst import Whisper | |
# Whisper.config() | |
##
## --- Transcription ---
##
_, col_button_transcribe, _ = st.columns([2, 1, 2])
if col_button_transcribe.button('Transcribe', type='primary', use_container_width=True):
    # if input_files:
    #     pass
    # else:
    #     st.error('Please select a file')
    st.session_state.transcript = None  # clear previous transcript
    col_info, col_complete, col_next = st.columns(3)
    try:
        with st.spinner(f'Loading Whisper `{whisper_model_option}` model and transcribing...'):
            # #-- Alternative: load a local Whisper model instead of the remote API
            # start = time.time()
            # # Let Whisper handle device placement if possible
            # model = whisper.load_model(whisper_model_option, device=device)
            # col_info.info(f'Model loaded in {time.time() - start:.2f} seconds.')
            #-- Perform transcription
            start = time.time()
            with open(st.session_state.audio_path, 'rb') as f:
                response = requests.post(
                    # f'{st.session_state.secret_api}/transcribe_faster_whisper',
                    f'{st.session_state.secret_api}/transcribe',
                    params={'model': whisper_model_option},
                    files={'file': f}
                )
            st.write(response)
            response = response.json()
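            # Assumed response shape (inferred from the accesses below): the
            # 'output' field follows Whisper's word-timestamped format, e.g.
            #   {'output': {'language': 'en',
            #               'segments': [{'start': 0.0, 'end': 4.2, 'text': '...',
            #                             'words': [{'word': '...', 'start': 0.0,
            #                                        'end': 0.4, 'probability': 0.98}]}]},
            #    'inference_time': ..., 'model_name': ...}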
            # st.write(response['inference_time'])
            # st.write(response['model_name'])
            # Wrap the raw output in a Transcription object for the steps below
            st.session_state.transcript = Transcription(st.session_state.audio_path)
            # Local-model alternative:
            # st.session_state.transcript.transcribe(model)
            # result = model.transcribe(audio_path, fp16=(device == 'cuda'))  # fp16 on GPU for speed/memory
            st.session_state.transcript.output = response['output']
            transcribe_time = time.time() - start
        col_complete.success(f'Transcription complete! (Took {transcribe_time:.2f}s)')
        col_next.page_link('ui_video.py', label='Next Step: **🖼️ Analyze Video**', icon='➡️')
    except Exception as e:
        st.error(f'An error occurred during transcription: {e}')
        # Unload the local model (if any) to free memory
        if 'model' in locals():
            del model
        if device == 'cuda':
            torch.cuda.empty_cache()
if 'transcript' in st.session_state and st.session_state['transcript']:
    # --- Video Player ---
    with st.expander('**Video Player**', expanded=True):
        col_video, col_segments = st.columns(2)
        col_video.video(st.session_state.video_path, start_time=st.session_state.start_time)
    # --- Display Transcript ---
    prev_word_end = -1
    text = ''
    html_text = ''
    output = st.session_state.transcript.output
    # doc = docx.Document()
    avg_confidence_score = 0
    amount_words = 0
    save_dir = str(pathlib.Path(__file__).parent.absolute()) + '/transcripts/'
    for idx, segment in enumerate(output['segments']):
        for w in segment['words']:
            amount_words += 1
            avg_confidence_score += w['probability']
    # Define the color map
    colors = [(0.6, 0, 0), (1, 0.7, 0), (0, 0.6, 0)]
    cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors)
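    # cmap maps a word's probability in [0, 1] onto dark red -> orange -> green,
    # e.g. cmap(0.0) ≈ dark red, cmap(1.0) ≈ green; each call returns an RGBA
    # tuple with components in [0, 1].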
    with st.expander('**TRANSCRIPT**', expanded=True):
        st.badge(
            f'whisper model: **`{whisper_model_option}`** | ' +
            f'language: **`{output["language"]}`** | ' +
            f'confidence score: **`{round(avg_confidence_score / max(amount_words, 1), 3)}`**'
        )
        color_coding = st.checkbox(
            'color coding',
            value=True,
            help='Color-code each word by recognition confidence: from green (good) to red (bad)'
        )
        # https://docs.streamlit.io/develop/api-reference/layout/st.container
        with st.container(height=300, border=False):
            for idx, segment in enumerate(output['segments']):
                for w in segment['words']:
                    # mark pauses in speech longer than 3s
                    if pauses and prev_word_end != -1 and w['start'] - prev_word_end >= 3:
                        pause_int = int(w['start'] - prev_word_end)
                        html_text += f'{"." * pause_int}{{{pause_int}sec}}'
                        text += f'{"." * pause_int}{{{pause_int}sec}}'
                    prev_word_end = w['end']
                    if color_coding:
                        rgba_color = cmap(w['probability'])
                        rgb_color = tuple(round(x * 255) for x in rgba_color[:3])
                    else:
                        rgb_color = (0, 0, 0)
                    html_text += f"<span style='color:rgb{rgb_color}'>{w['word']}</span>"
                    text += w['word']
                    # insert a line break after sentence-ending punctuation (skip numbers like "3.14")
                    if any(c in w['word'] for c in '!?.') and not any(c.isdigit() for c in w['word']):
                        html_text += '<br><br>'
                        text += '\n\n'
            st.markdown(html_text, unsafe_allow_html=True)
        # doc.add_paragraph(text)
        # if translation:
        #     with st.expander('English translation'):
        #         st.markdown(output['translation'], unsafe_allow_html=True)
        # # save transcript as .docx in the local folder
        # file_name = output['name'] + '-' + whisper_model + \
        #     '-' + datetime.today().strftime('%d-%m-%y') + '.docx'
        # doc.save(save_dir + file_name)
        # bio = io.BytesIO()
        # doc.save(bio)
        # st.download_button(
        #     label='Download Transcription',
        #     data=bio.getvalue(),
        #     file_name=file_name,
        #     mime='docx'
        # )
    # --- Display Segments with timestamps ---
    # if 'segments' in st.session_state.transcript:
    #     with st.expander('Detailed segments (with timestamps)'):
    #         st.json(st.session_state.transcript['segments'])
    format_time = lambda s: str(timedelta(seconds=int(s)))
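    # e.g. format_time(83.6) -> '0:01:23'; used for the clickable timecodes below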
    # https://discuss.streamlit.io/t/replaying-an-audio-file-with-a-timecode-click/48892/9
    # https://docs.streamlit.io/develop/api-reference/layout/st.container
    st.session_state['transcript_segments'] = ''
    with col_segments.container(height=400, border=False):
        # Style buttons as links
        with stylable_container(
            key='link_buttons',
            css_styles='''
                button {
                    background: none!important;
                    border: none;
                    padding: 0!important;
                    font-family: arial, sans-serif;
                    color: #069;
                    cursor: pointer;
                }
            ''',
        ):
            for i, segment in enumerate(st.session_state.transcript.output['segments']):
                start = format_time(segment['start'])
                end = format_time(segment['end'])
                text = segment['text'].strip()
                col_timecode, col_text = st.columns([1, 5], vertical_alignment='center')
                if col_timecode.button(f':violet-badge[**{start} – {end}**]', key=f'segment_{i}', use_container_width=True):
                    # the players expect seconds, not the formatted time string
                    st.session_state['start_time'] = int(segment['start'])
                    st.rerun()
                st.session_state.transcript_segments += f'[**{start} – {end}**] {text}\n'
                col_text.text(text)
    if trash_str in st.session_state.transcript_segments:
        # str.replace returns a new string, so the result must be reassigned
        st.session_state.transcript_segments = st.session_state.transcript_segments.replace(trash_str, '')
# else:
#     st.info('Transcript has not been generated yet.')