File size: 5,088 Bytes
8fb0be5
ee531be
59f6126
 
1c8f6d7
e8a4c9c
09b358f
024f740
4b331f0
 
 
e8a4c9c
 
 
 
ee531be
09b358f
 
7dc42bb
09b358f
 
 
29a10e5
c9323c5
 
791adc1
4b331f0
ee531be
4b331f0
ee531be
4b331f0
09b358f
 
 
 
 
060a1e0
 
 
 
 
09b358f
4b331f0
 
 
ee531be
4b331f0
 
09b358f
8f0bb70
09b358f
ee531be
09b358f
 
 
 
 
 
 
4380489
09b358f
 
 
4380489
09b358f
 
 
ee531be
1c8f6d7
060a1e0
09b358f
623c1fa
09b358f
 
623c1fa
4b331f0
e86a1dc
 
 
a8c8823
c9323c5
 
46f1669
c9323c5
 
 
 
 
 
e8a4c9c
c9323c5
 
791adc1
ee531be
791adc1
4b331f0
09b358f
 
 
7bbc5c5
ee531be
4b331f0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import streamlit as st
import whisperx
import torch
from utils import convert_segments_object_to_text, check_password, convert_segments_object_to_text_simple
from gigiachat_requests import get_access_token, get_completion_from_gigachat, get_number_of_tokens
from openai_requests import get_completion_from_openai

if check_password():    
    st.title('Audio Transcription App')
    st.sidebar.title("Settings")
    
    device = os.getenv('DEVICE')
    batch_size = int(os.getenv('BATCH_SIZE'))
    compute_type = os.getenv('COMPUTE_TYPE')

    initial_base_prompt = os.getenv('BASE_PROMPT')
    initial_processing_prompt = os.getenv('PROCCESS_PROMPT')

    llm = st.sidebar.selectbox("LLM", ["GigaChat", "Chat GPT"], index=0)
    base_prompt = st.sidebar.text_area("Промпт для резюмирования", value=initial_base_prompt)
    max_tokens_summary = st.sidebar.number_input("Максимальное количество токенов при резюмировании", min_value=1, value=1024)

    enable_processing = st.sidebar.checkbox("Добавить обработку транскрибации", value=False)
    processing_prompt = st.sidebar.text_area("Промпт для обработки транскрибации", value=initial_processing_prompt)

    ACCESS_TOKEN = st.secrets["HF_TOKEN"]

    uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])

    if uploaded_file is not None:
        file_name = uploaded_file.name

        if 'file_name' not in st.session_state or st.session_state.file_name != file_name:
            st.session_state.transcript = ''
            st.session_state.file_name = file_name
            print(st.session_state.file_name)
            print(st.session_state.transcript)

        print(st.session_state.file_name)
        print(st.session_state.transcript)
            
        st.audio(uploaded_file)
        file_extension = uploaded_file.name.split(".")[-1]  # Получаем расширение файла
        temp_file_path = f"temp_file.{file_extension}"  # Создаем временное имя файла с правильным расширением
    
        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        print(st.session_state.transcript)
        if 'transcript' not in st.session_state or st.session_state.transcript == '':
    
            with st.spinner('Транскрибируем...'):
                # Load model
                model = whisperx.load_model(os.getenv('WHISPER_MODEL_SIZE'), device, compute_type=compute_type)
                # Load and transcribe audio
                audio = whisperx.load_audio(temp_file_path)
                result = model.transcribe(audio, batch_size=batch_size, language="ru")
                print('Transcribed, now aligning')
        
                model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
                result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
                print('Aligned, now diarizing')
        
                diarize_model = whisperx.DiarizationPipeline(use_auth_token=st.secrets["HF_TOKEN"], device=device)
                diarize_segments = diarize_model(audio)
                result_diar = whisperx.assign_word_speakers(diarize_segments, result)
        
            transcript = convert_segments_object_to_text_simple(result_diar)
            st.session_state.transcript = transcript
        else:
            
            transcript = st.session_state.transcript
            
        st.write("Результат транскрибации:")
        st.text(transcript)

        if (llm == 'GigaChat'):
            access_token = get_access_token()
    
        if (enable_processing):
            with st.spinner('Обрабатываем транскрибацию...'):

                if (llm == 'GigaChat'):
                    number_of_tokens = get_number_of_tokens(transcript, access_token)
                    print('Количество токенов в транскрибации: ' + str(number_of_tokens))
                    transcript = get_completion_from_gigachat(processing_prompt + transcript, number_of_tokens + 1000, access_token)
                elif (llm == 'Chat GPT'):
                    transcript = get_completion_from_openai(processing_prompt + transcript)
                
                st.write("Результат обработки:")
                st.text(transcript)

        
    
        with st.spinner('Резюмируем...'):
            if (llm == 'GigaChat'):
                summary_answer = get_completion_from_gigachat(base_prompt + transcript, max_tokens_summary, access_token)
            elif (llm == 'Chat GPT'):
                summary_answer = get_completion_from_openai(base_prompt + transcript, max_tokens_summary)
        
            st.write("Результат резюмирования:")
            st.text(summary_answer)