import streamlit as st
from streamlit_extras.stylable_container import stylable_container
import os
import time
import pathlib
from datetime import timedelta
import requests
# Disable Streamlit's file watcher before importing torch/whisper: the watcher
# trips over torch's lazy module inspection and can crash the app.
os.environ['STREAMLIT_SERVER_ENABLE_FILE_WATCHER'] = 'false'
import whisper  # openai-whisper
import torch  # used to check for GPU availability
# from models.loader import load_model_sst
from transcriber import Transcription
import matplotlib.colors as mcolors
######
from utils import load_config, get_secret_api

# Resolve the private inference endpoint URL (kept out of the repo).
st.session_state.secret_api = get_secret_api()

# Earlier inline approach, superseded by utils.get_secret_api():
# import gdown
# import tempfile
# # create & close the temp file first so gdown can write to it without a lock
# with tempfile.NamedTemporaryFile(delete=False) as tmp:
#     tmp_path = tmp.name
# gdown.download(id=load_config()['links']['secret_api_id'], output=tmp_path, quiet=True)
# with open(tmp_path) as f:
#     st.session_state.secret_api = f.read()
# os.remove(tmp_path)
######
# Known Whisper hallucination line ("Subtitles created by DimaTorzok") to strip from transcripts.
trash_str = 'Субтитры создавал DimaTorzok'

st.title('🎙️ Step 2: Speech-to-Text (ASR/STT)')

# Check that an audio path exists from the previous step
if 'audio_path' not in st.session_state or not st.session_state['audio_path'] or not os.path.exists(st.session_state['audio_path']):
    st.warning('Audio file not found. Please go back to the "**📤 Upload**" page and process a video first.')
    st.stop()

st.write(f'Processing audio `{st.session_state.video_input_title}` from video input')
if 'start_time' not in st.session_state:
    st.session_state.start_time = 0
# start_time is updated by the segment timecode buttons below, so a click
# replays the audio from the selected segment.
st.audio(st.session_state.audio_path, start_time=st.session_state.start_time)
#
# ==================================================================
#
col_model, col_config = st.columns(2)
# --- Model ---
with col_model.container(border=True):
    model_option = st.selectbox(
        'STT model:',
        ['whisper', 'faster-whisper', 'distil-whisper', 'giga'],
        index=0
    )
    # model_option is reserved for backend selection; only the remote
    # whisper endpoint is wired up so far.
    # sst_model = load_model_sst(model_option)
# --- Configuration ---
with col_config.expander('**CONFIG**', expanded=True):
    # Determine compute device
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = st.radio(
        'Compute device:',
        ('cuda', 'cpu'),
        index=0 if default_device == 'cuda' else 1,
        horizontal=True,
        disabled=not torch.cuda.is_available()
    )
    if device == 'cuda' and not torch.cuda.is_available():
        st.warning('CUDA selected but not available, falling back to CPU')
        device = 'cpu'
    whisper_model_option = st.selectbox(
        'Whisper model type:',
        ['tiny', 'base', 'small', 'medium', 'large-v3', 'turbo'],
        index=5
    )
    pauses = st.checkbox(
        'pauses',
        value=False,
        help='Insert {N sec} markers for pauses in speech longer than 3 seconds'
    )
    # from models.models_sst import Whisper
    # Whisper.config()
##
## --- Transcription ---
##
_, col_button_transcribe, _ = st.columns([2, 1, 2])
if col_button_transcribe.button('Transcribe', type='primary', use_container_width=True):
    st.session_state.transcript = None  # clear any previous transcript
    col_info, col_complete, col_next = st.columns(3)
    try:
        with st.spinner(f'Loading Whisper `{whisper_model_option}` model and transcribing...'):
            # #-- Load whisper locally (kept for reference; transcription is
            # #   delegated to the remote endpoint instead)
            # start = time.time()
            # model = whisper.load_model(whisper_model_option, device=device)
            # col_info.info(f'Model loaded in {time.time() - start:.2f} seconds.')
            #-- Perform transcription on the remote endpoint
            start = time.time()
            with open(st.session_state.audio_path, 'rb') as f:
                response = requests.post(
                    # f'{st.session_state.secret_api}/transcribe_faster_whisper',
                    f'{st.session_state.secret_api}/transcribe',
                    params={'model': whisper_model_option},
                    files={'file': f}
                )
            response.raise_for_status()
            response = response.json()
            # response also carries 'inference_time' and 'model_name'
            st.session_state.transcript = Transcription(st.session_state.audio_path)
            st.session_state.transcript.output = response['output']
            transcribe_time = time.time() - start
        col_complete.success(f'Transcription complete! (Took {transcribe_time:.2f}s)')
        col_next.page_link('ui_video.py', label='Next Step: **🖼️ Analyze Video**', icon='➡️')
    except Exception as e:
        st.error(f'An error occurred during transcription: {e}')
        # Unload the model (if one was loaded locally) to free memory
        if 'model' in locals():
            del model
        if device == 'cuda':
            torch.cuda.empty_cache()
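# --- Reference only: a minimal sketch of the contract this client assumes the
# remote `/transcribe` endpoint fulfils. This is a hypothetical FastAPI server
# (not part of this repo): multipart `file` upload plus a `model` query
# parameter in; JSON with `model_name`, `inference_time` and a Whisper-style
# `output` (with 'text', 'language', 'segments' and per-word data) out.
#
# from fastapi import FastAPI, File, UploadFile
# import shutil, tempfile, time
# import whisper
#
# app = FastAPI()
#
# @app.post('/transcribe')
# async def transcribe(model: str = 'turbo', file: UploadFile = File(...)):
#     with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
#         shutil.copyfileobj(file.file, tmp)  # persist the upload to disk
#         path = tmp.name
#     m = whisper.load_model(model)
#     t0 = time.time()
#     # word_timestamps=True adds the per-word 'start'/'end'/'probability'
#     # fields that the UI below relies on
#     result = m.transcribe(path, word_timestamps=True)
#     return {'model_name': model,
#             'inference_time': time.time() - t0,
#             'output': result}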
if 'transcript' in st.session_state and st.session_state['transcript']:
    # --- Video Player ---
    with st.expander('**Video Player**', expanded=True):
        col_video, col_segments = st.columns(2)
        col_video.video(st.session_state.video_path, start_time=st.session_state.start_time)
    # --- Display Transcript ---
    prev_word_end = -1
    text = ''
    html_text = ''
    output = st.session_state.transcript.output
    # doc = docx.Document()
    avg_confidence_score = 0
    amount_words = 0
    save_dir = str(pathlib.Path(__file__).parent.absolute()) + '/transcripts/'  # used by the commented .docx export below
    # Average per-word confidence across the whole transcript
    for segment in output['segments']:
        for w in segment['words']:
            amount_words += 1
            avg_confidence_score += w['probability']
    avg_confidence = avg_confidence_score / max(amount_words, 1)  # guard against an empty transcript
    # Color map for word-level confidence
    colors = [(0.6, 0, 0), (1, 0.7, 0), (0, 0.6, 0)]
    cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors)
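    # cmap maps a probability in [0, 1] onto this gradient:
    # cmap(0.0) -> dark red (low confidence), cmap(0.5) -> orange, cmap(1.0) -> green (high)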
    with st.expander('**TRANSCRIPT**', expanded=True):
        st.badge(
            f'whisper model: **`{whisper_model_option}`** | ' +
            f'language: **`{output["language"]}`** | ' +
            f'confidence score: **`{round(avg_confidence, 3)}`**'
        )
        color_coding = st.checkbox(
            'color coding',
            value=True,
            help='Color-code each word by recognition confidence: from green (good) to red (poor)'
        )
        # https://docs.streamlit.io/develop/api-reference/layout/st.container
        with st.container(height=300, border=False):
            for segment in output['segments']:
                # skip known hallucinated segments
                if trash_str in segment['text']:
                    continue
                for w in segment['words']:
                    # mark pauses in speech longer than 3s
                    if pauses and prev_word_end != -1 and w['start'] - prev_word_end >= 3:
                        pause_int = int(w['start'] - prev_word_end)
                        html_text += f'{"." * pause_int}{{{pause_int}sec}}'
                        text += f'{"." * pause_int}{{{pause_int}sec}}'
                    prev_word_end = w['end']
                    if color_coding:
                        rgba_color = cmap(w['probability'])
                        rgb_color = tuple(round(x * 255) for x in rgba_color[:3])
                    else:
                        rgb_color = (0, 0, 0)
                    html_text += f"<span style='color:rgb{rgb_color}'>{w['word']}</span>"
                    text += w['word']
                    # insert a paragraph break after sentence-ending punctuation
                    if any(c in w['word'] for c in '!?.') and not any(c.isdigit() for c in w['word']):
                        html_text += '<br><br>'
                        text += '\n\n'
            st.markdown(html_text, unsafe_allow_html=True)
    # doc.add_paragraph(text)
    # if (translation):
    #     with st.expander("English translation"):
    #         st.markdown(output["translation"], unsafe_allow_html=True)
    # # save transcript as .docx in the local transcripts folder
    # file_name = output['name'] + "-" + whisper_model_option + \
    #     "-" + datetime.today().strftime('%d-%m-%y') + ".docx"
    # doc.save(save_dir + file_name)
    # bio = io.BytesIO()
    # doc.save(bio)
    # st.download_button(
    #     label="Download Transcription",
    #     data=bio.getvalue(),
    #     file_name=file_name,
    #     mime="docx"
    # )
    # --- Display segments with timestamps ---
    # with st.expander('Detailed segments (with timestamps)'):
    #     st.json(st.session_state.transcript.output['segments'])
    def format_time(s):
        return str(timedelta(seconds=int(s)))
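    # e.g. format_time(125) -> '0:02:05'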
    # https://discuss.streamlit.io/t/replaying-an-audio-file-with-a-timecode-click/48892/9
    # https://docs.streamlit.io/develop/api-reference/layout/st.container
    st.session_state['transcript_segments'] = ''
    with col_segments.container(height=400, border=False):
        # Style buttons as links
        with stylable_container(
            key='link_buttons',
            css_styles='''
                button {
                    background: none!important;
                    border: none;
                    padding: 0!important;
                    font-family: arial, sans-serif;
                    color: #069;
                    cursor: pointer;
                }
            ''',
        ):
            for i, segment in enumerate(st.session_state.transcript.output['segments']):
                start = format_time(segment['start'])
                end = format_time(segment['end'])
                text = segment['text'].strip()
                col_timecode, col_text = st.columns([1, 5], vertical_alignment='center')
                # clicking a timecode jumps the audio/video players above to this segment
                if col_timecode.button(f':violet-badge[**{start} – {end}**]', key=f'seg_{i}', use_container_width=True):
                    st.session_state['start_time'] = int(segment['start'])
                    st.rerun()
                st.session_state.transcript_segments += f'[**{start} – {end}**] {text}\n'
                col_text.text(text)
    # strip known hallucinated lines from the collected segments
    st.session_state.transcript_segments = st.session_state.transcript_segments.replace(trash_str, '')
# else:
#     st.info('Transcript has not been generated yet.')