|
import requests |
|
import base64 |
|
import os |
|
import json |
|
import streamlit as st |
|
import whisperx |
|
import torch |
|
|
|
def convert_segments_object_to_text(data):
    """Render a whisperx diarized result as a plain-text transcript.

    Parameters
    ----------
    data : dict
        Output of ``whisperx.assign_word_speakers``: a mapping with a
        ``'segments'`` list; each segment holds a ``'words'`` list (each word
        a dict with optional ``'word'``, ``'start'``, ``'end'``, ``'speaker'``)
        plus optional segment-level ``'speaker'``/``'start'``/``'end'``
        fallbacks.

    Returns
    -------
    str
        One line per continuous same-speaker run, formatted as
        ``"SPEAKER (start : end) : text"`` — the timestamp part is omitted
        when either bound is unknown — joined with newlines.

    Notes
    -----
    Fix vs. previous revision: the input ``data`` is no longer mutated;
    missing word fields are backfilled on private copies.
    """
    result = []

    for segment in data['segments']:
        words = _fill_missing_word_fields(
            segment['words'],
            segment.get('speaker'),
            segment.get('start'),
            segment.get('end'),
        )

        current_speaker = None
        current_start = None
        current_end = None
        current_text = []

        for word_info in words:
            word = word_info.get('word', '')
            start = word_info.get('start')
            end = word_info.get('end')
            speaker = word_info.get('speaker')

            if current_speaker is None:
                # First word of the segment (or right after a None-speaker
                # run was flushed): open a new run.
                current_speaker = speaker
                current_start = start

            if speaker == current_speaker:
                current_text.append(word)
                current_end = end
            else:
                # Speaker changed: flush the finished run, start a new one.
                result.append(_format_run(current_speaker, current_start,
                                          current_end, current_text))
                current_speaker = speaker
                current_start = start
                current_end = end
                current_text = [word]

        if current_text:
            result.append(_format_run(current_speaker, current_start,
                                      current_end, current_text))

    return '\n'.join(result)


def _fill_missing_word_fields(words, segment_speaker, segment_start, segment_end):
    """Return copies of *words* with missing speaker/start/end backfilled.

    Mirrors the original in-place pass exactly: look-back reads the
    already-filled previous copy, look-ahead reads the raw (not yet
    processed) next word — but leaves the caller's dicts untouched.
    """
    filled = [dict(w) for w in words]
    last = len(filled) - 1

    for i, word_info in enumerate(filled):
        if 'speaker' not in word_info:
            # Inherit from a neighbour; fall back to the segment speaker.
            if i > 0 and 'speaker' in filled[i - 1]:
                word_info['speaker'] = filled[i - 1]['speaker']
            elif i < last and 'speaker' in filled[i + 1]:
                word_info['speaker'] = filled[i + 1]['speaker']
            else:
                word_info['speaker'] = segment_speaker

        if 'start' not in word_info:
            # Borrow the previous word's end, else the segment start.
            if i > 0 and 'end' in filled[i - 1]:
                word_info['start'] = filled[i - 1]['end']
            else:
                word_info['start'] = segment_start

        if 'end' not in word_info:
            # Borrow the next word's start; the last word takes the segment
            # end; otherwise collapse to a zero-length interval.
            if i < last and 'start' in filled[i + 1]:
                word_info['end'] = filled[i + 1]['start']
            elif i == last:
                word_info['end'] = segment_end
            else:
                word_info['end'] = word_info['start']

    return filled


def _format_run(speaker, start, end, text_parts):
    """Format one same-speaker run; drop timestamps when either is unknown."""
    if start is not None and end is not None:
        return f'{speaker} ({start} : {end}) : {" ".join(text_parts)}'
    return f'{speaker} : {" ".join(text_parts)}'
|
|
|
# --- Page layout and sidebar configuration --------------------------------
st.title('Audio Transcription App')
st.sidebar.title("Settings")

# whisperx inference settings.
# NOTE(review): index=1 preselects "cuda" — this will fail at model load on a
# CPU-only host unless the user switches to "cpu"; confirm deployment target.
device = st.sidebar.selectbox("Device", ["cpu", "cuda"], index=1)
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=16)
compute_type = st.sidebar.selectbox("Compute Type", ["float16", "int8"], index=0)

# Hugging Face access token (used for the pyannote diarization pipeline).
ACCESS_TOKEN = st.secrets["HF_TOKEN"]

# Upload widget; label is Russian for "Upload an audio file".
uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])
|
|
|
if uploaded_file is not None:
    st.audio(uploaded_file)

    # Persist the upload to a local file: whisperx.load_audio reads from a
    # filesystem path, not a file-like object.
    file_extension = uploaded_file.name.split(".")[-1]
    temp_file_path = f"temp_file.{file_extension}"

    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    try:
        with st.spinner('Транскрибируем...'):
            # 1) Speech-to-text with the configured Whisper model size.
            model = whisperx.load_model(os.getenv('WHISPER_MODEL_SIZE'), device, compute_type=compute_type)
            audio = whisperx.load_audio(temp_file_path)
            result = model.transcribe(audio, batch_size=batch_size, language="ru")
            print('Transcribed, now aligning')

            # 2) Word-level alignment for precise per-word timestamps.
            model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
            result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
            print('Aligned, now diarizing')

            # 3) Speaker diarization, then attach a speaker to each word.
            # Reuses ACCESS_TOKEN instead of re-reading st.secrets["HF_TOKEN"].
            diarize_model = whisperx.DiarizationPipeline(use_auth_token=ACCESS_TOKEN, device=device)
            diarize_segments = diarize_model(audio)
            result_diar = whisperx.assign_word_speakers(diarize_segments, result)

            st.write("Результат транскрибации:")
            transcript = convert_segments_object_to_text(result_diar)
            st.text(transcript)
    finally:
        # Fix: the temp file used to be left behind after every run.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)

    with st.spinner('Резюмируем...'):
        # Exchange Basic-auth credentials for a short-lived GigaChat token.
        username = st.secrets["GIGA_USERNAME"]
        password = st.secrets["GIGA_SECRET"]
        auth_base64 = base64.b64encode(f'{username}:{password}'.encode('utf-8')).decode('utf-8')

        auth_headers = {
            'Authorization': f'Basic {auth_base64}',
            'RqUID': os.getenv('GIGA_rquid'),
            'Content-Type': 'application/x-www-form-urlencoded',
            'Accept': 'application/json'
        }
        # SECURITY NOTE(review): verify=False disables TLS certificate
        # verification. Kept to preserve behavior (the endpoint presumably
        # uses a CA absent from the default bundle — confirm), but prefer
        # shipping that CA and passing verify=<path-to-ca-bundle>.
        auth_response = requests.post(
            os.getenv('GIGA_AUTH_URL'),
            headers=auth_headers,
            data={'scope': os.getenv('GIGA_SCOPE')},
            verify=False,
        )
        # Fail fast with a clear HTTP error instead of a KeyError on a
        # missing 'access_token' key.
        auth_response.raise_for_status()
        access_token = auth_response.json()['access_token']
        print('Got access token')

        # Build the chat-completion request: base prompt + transcript.
        completion_payload = json.dumps({
            "model": os.getenv('GIGA_MODEL'),
            "messages": [
                {
                    "role": "user",
                    "content": os.getenv('GIGA_BASE_PROMPT') + transcript
                }
            ],
            "stream": False,
            "max_tokens": int(os.getenv('GIGA_MAX_TOKENS')),
        })
        completion_headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Authorization': 'Bearer ' + access_token
        }
        completion_response = requests.post(
            os.getenv('GIGA_COMPLETION_URL'),
            headers=completion_headers,
            data=completion_payload,
            verify=False,  # see SECURITY NOTE above
        )
        completion_response.raise_for_status()
        answer_from_llm = completion_response.json()['choices'][0]['message']['content']

        st.write("Результат резюмирования:")
        st.text(answer_from_llm)