File size: 4,194 Bytes
7d2c473
23076a4
3784695
56f4168
719c3e8
d1b8f9e
eeaa144
83801b3
4773685
eb683c5
d3a21fe
a75c652
 
8b2bfa8
d1b8f9e
 
 
2306b39
 
 
f2c3ba6
2306b39
 
59b484e
2306b39
 
686ef78
 
 
2306b39
686ef78
1659f5a
686ef78
 
d3a21fe
686ef78
f2c3ba6
686ef78
eeaa144
1659f5a
3e64240
093b41a
4d25eaf
 
 
 
093b41a
a75c652
 
093b41a
a75c652
3e64240
8307fd0
07a4ffd
3625f99
 
23076a4
 
 
 
 
 
 
 
 
 
 
d3a21fe
 
23076a4
093b41a
23076a4
07a4ffd
ae41022
07a4ffd
 
 
 
 
 
 
 
 
 
 
 
2cf4176
07a4ffd
 
 
 
 
8307fd0
5a73194
 
 
 
 
 
 
 
 
 
 
3784695
b39afbd
d1b8f9e
 
2306b39
2f2107c
d1b8f9e
 
 
 
 
 
 
2306b39
 
5a73194
2306b39
c8a52a4
c7b9359
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import streamlit as st
import base64
import io
from huggingface_hub import InferenceClient
from gtts import gTTS
from audio_recorder_streamlit import audio_recorder
import speech_recognition as sr
from pydub import AudioSegment

# System-style instruction injected once per session (Spanish stoic assistant).
pre_prompt_text = "Hablarás español, tus principios el estoicismo, eres una IA conductual, tus respuestas serán breves."

# Initialize per-session state with defaults on first run.
for _key, _default in (("history", []), ("pre_prompt_sent", False)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

def recognize_speech(audio_data, show_messages=True):
    """Transcribe a WAV audio buffer to Spanish text via Google Speech Recognition.

    Parameters
    ----------
    audio_data : file-like object
        Buffer containing WAV audio (as produced by the recorder widget).
    show_messages : bool
        When True, render recognition status and the recognized text in the UI.

    Returns
    -------
    str
        The recognized text, or "" when recognition fails.
    """
    recognizer = sr.Recognizer()

    with sr.AudioFile(audio_data) as source:
        audio = recognizer.record(source)

    audio_text = ""
    try:
        audio_text = recognizer.recognize_google(audio, language="es-ES")
        if show_messages:
            st.subheader("Texto Reconocido:")
            st.write(audio_text)
            st.success("Reconocimiento de voz completado.")
    except sr.UnknownValueError:
        # Fix: honor show_messages here too — previously these UI messages
        # were emitted even when the caller asked for a silent run.
        if show_messages:
            st.warning("No se pudo reconocer el audio. ¿Intentaste grabar algo?")
    except sr.RequestError:
        if show_messages:
            st.error("Háblame para comenzar!")

    return audio_text

def format_prompt(message, history):
    """Assemble a Mixtral-style ``[INST]`` prompt from the session pre-prompt,
    the conversation history, and the new user message."""
    parts = ["<s>"]

    # Inject the behavioural pre-prompt exactly once per session.
    if not st.session_state.pre_prompt_sent:
        parts.append(f"[INST] {pre_prompt_text} [/INST]")
        st.session_state.pre_prompt_sent = True

    for user_prompt, bot_response in history:
        parts.append(f"[INST] {user_prompt} [/INST]")
        parts.append(f" {bot_response}</s> ")

    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)

def generate(audio_text, history, temperature=None, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Stream a completion from Mixtral for the recognized text and voice it.

    Returns a tuple ``(response_text, mp3_buffer)`` where ``mp3_buffer`` is a
    file-like object produced by :func:`text_to_speech`.
    """
    client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Default the sampling temperature and clamp it away from zero.
    temperature = 0.9 if temperature is None else float(temperature)
    temperature = max(temperature, 1e-2)
    top_p = float(top_p)

    stream = client.text_generation(
        format_prompt(audio_text, history),
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
        stream=True,
        details=True,
        return_full_text=True,
    )

    # Accumulate streamed tokens, then normalize whitespace and strip the
    # end-of-sequence marker.
    pieces = [item.token.text for item in stream]
    response = ' '.join("".join(pieces).split()).replace('</s>', '')

    audio_file = text_to_speech(response, speed=1.3)
    return response, audio_file

def text_to_speech(text, speed=1.3):
    """Synthesize *text* as Spanish speech and return a sped-up MP3 buffer.

    The gTTS output is re-encoded through pydub at ``speed`` times the
    original playback rate; the returned BytesIO is rewound and ready to read.
    """
    raw_mp3 = io.BytesIO()
    gTTS(text=text, lang='es').write_to_fp(raw_mp3)
    raw_mp3.seek(0)

    faster = AudioSegment.from_mp3(raw_mp3).speedup(playback_speed=speed)

    out = io.BytesIO()
    faster.export(out, format="mp3")
    out.seek(0)
    return out

def audio_play(audio_fp):
    """Play an MP3 buffer through Streamlit's audio widget from the start."""
    data = audio_fp.read()
    st.audio(data, format="audio/mp3", start_time=0)

def display_recognition_result(audio_text, output, audio_file):
    """Record the exchange in session history and auto-play the spoken reply.

    The reply is embedded as a base64 data-URI <audio> tag so the browser
    starts playback without user interaction.
    """
    if audio_text:
        st.session_state.history.append((audio_text, output))

    if audio_file is not None:
        encoded = base64.b64encode(audio_file.read()).decode()
        st.markdown(
            f"""<audio autoplay="autoplay" controls="controls" src="data:audio/mp3;base64,{encoded}" type="audio/mp3" id="audio_player"></audio>""",
            unsafe_allow_html=True)

def main():
    """Streamlit entry point: record audio, transcribe it, and answer with voice.

    Fix: the original set ``st.session_state.pre_prompt_sent = True`` at the
    top of every run, which marked the pre-prompt as sent before any prompt
    was ever built — so ``format_prompt`` never actually injected it. The
    flag is now managed exclusively by ``format_prompt``, which sets it the
    first time the pre-prompt is included.
    """
    audio_bytes = audio_recorder(text="", recording_color="#b81414", neutral_color="#808080")

    if audio_bytes:
        # Echo the recording back to the user, then transcribe it.
        st.audio(audio_bytes, format="audio/wav")
        audio_data = io.BytesIO(audio_bytes)
        audio_data.seek(0)
        audio_text = recognize_speech(audio_data)

        if audio_text:
            output, audio_file = generate(audio_text, history=st.session_state.history)
            display_recognition_result(audio_text, output, audio_file)

if __name__ == "__main__":
    main()