|
import streamlit as st |
|
from huggingface_hub import InferenceClient |
|
import time |
|
import re |
|
import edge_tts |
|
import asyncio |
|
from concurrent.futures import ThreadPoolExecutor |
|
import tempfile |
|
from pydub import AudioSegment |
|
|
|
|
|
client_hf = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1") |
|
|
|
|
|
async def text_to_speech_edge(text, language_code): |
|
voice = {"fr": "fr-FR-RemyMultilingualNeural"}[language_code] |
|
communicate = edge_tts.Communicate(text, voice) |
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file: |
|
tmp_path = tmp_file.name |
|
await communicate.save(tmp_path) |
|
return tmp_path |
|
|
|
|
|
def run_in_threadpool(func, *args, **kwargs): |
|
loop = asyncio.new_event_loop() |
|
asyncio.set_event_loop(loop) |
|
future = asyncio.ensure_future(func(*args, **kwargs)) |
|
return loop.run_until_complete(future) |
|
|
|
def concatenate_audio(paths): |
|
combined = AudioSegment.empty() |
|
for path in paths: |
|
audio = AudioSegment.from_mp3(path) |
|
combined += audio |
|
combined_path = tempfile.mktemp(suffix=".mp3") |
|
combined.export(combined_path, format="mp3") |
|
return combined_path |
|
|
|
|
|
def dictee_to_audio_segmented(dictee): |
|
sentences = segmenter_texte(dictee) |
|
audio_urls = [] |
|
with ThreadPoolExecutor() as executor: |
|
for sentence in sentences: |
|
processed_sentence = replace_punctuation(sentence) |
|
audio_path = executor.submit(run_in_threadpool, text_to_speech_edge, processed_sentence, "fr").result() |
|
audio_urls.append(audio_path) |
|
return audio_urls |
|
|
|
def generer_dictee(classe, longueur): |
|
prompt = f"Créer une dictée pour la classe {classe} d'une longueur d'environ {longueur} mots. Il est important de créer le texte uniquement de la dictée et de ne pas ajouter de consignes ou d'indications supplémentaires." |
|
generate_kwargs = { |
|
"temperature": 0.7, |
|
"max_new_tokens": 1000, |
|
"top_p": 0.95, |
|
"repetition_penalty": 1.2, |
|
"do_sample": True, |
|
} |
|
formatted_prompt = f"<s>[INST] {prompt} [/INST]" |
|
stream = client_hf.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False) |
|
dictee = "" |
|
for response in stream: |
|
dictee += response.token.text |
|
dictee = dictee.replace("</s>", "").strip() |
|
return dictee |
|
|
|
def correction_dictee(dictee, dictee_utilisateur): |
|
prompt = f"Voici une dictée crée: {dictee} | Voici la dictée faite par l'utilisateur : {dictee_utilisateur} - Corrige la dictée en donnant les explications, utilise les syntax du markdown pour une meilleur comprehesion de la correction." |
|
generate_kwargs = { |
|
"temperature": 0.7, |
|
"max_new_tokens": 2000, |
|
"top_p": 0.95, |
|
"repetition_penalty": 1.2, |
|
"do_sample": True, |
|
} |
|
formatted_prompt = f"<s>[INST] {prompt} [/INST]" |
|
stream = client_hf.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False) |
|
texte_ameliore = "" |
|
for response in stream: |
|
texte_ameliore += response.token.text |
|
texte_ameliore = texte_ameliore.replace("</s>", "").strip() |
|
return correction |
|
|
|
|
|
def replace_punctuation(text): |
|
replacements = { |
|
".": " point.", |
|
",": " virgule,", |
|
";": " point-virgule;", |
|
":": " deux-points:", |
|
"!": " point d'exclamation!", |
|
"?": " point d'interrogation?", |
|
"-": " tiret-", |
|
"'": " apostrophe'", |
|
} |
|
for key, value in replacements.items(): |
|
text = text.replace(key, value) |
|
return text |
|
|
|
def segmenter_texte(texte): |
|
sentences = re.split(r'(?<=[.!?]) +', texte) |
|
return sentences |
|
|
|
|
|
st.set_page_config(layout="wide") |
|
st.title('Générateur de Dictée') |
|
|
|
with st.expander("Paramètres de la dictée", expanded=True): |
|
mode = st.radio("Mode:", ["S'entrainer: Vous aurez uniquement les audios suivi d'une correction par IA (Pour 1 seul personne)", "Entrainer: Vous aurez uniquement le texte de la dictée pour entrainer quelqu'un d'autre (Pour 2 ou + personnes)"]) |
|
classe = st.selectbox("Classe", ["CP", "CE1", "CE2", "CM1", "CM2", "6ème", "5ème", "4ème", "3ème", "Seconde", "Premiere", "Terminale"], index=2) |
|
longueur = st.slider("Longueur de la dictée (nombre de mots)", 50, 500, 200) |
|
|
|
if st.button('Générer la Dictée'): |
|
with st.spinner("Génération de la dictée en cours..."): |
|
dictee = generer_dictee(classe, longueur) |
|
if mode == "S'entrainer: Vous aurez uniquement les audios suivi d'une correction par IA (Pour 1 seul personne)": |
|
audio_urls = dictee_to_audio_segmented(dictee) |
|
concatenated_audio_path = concatenate_audio(audio_urls) |
|
|
|
col1, col2 = st.columns(2) |
|
|
|
with col1: |
|
st.audio(concatenated_audio_path, format='audio/wav', start_time=0) |
|
with st.expander("Phrases de la Dictée"): |
|
for idx, url in enumerate(audio_urls, start=1): |
|
st.markdown(f"**Phrase {idx}:**") |
|
st.audio(url, format='audio/wav') |
|
|
|
with col2: |
|
dictee_utilisateur = st.text_area("Écrivez la dictée ici:", height=300) |
|
if st.button('Correction'): |
|
st.write("Dictée originale:") |
|
correction = correction_dictee(dictee, dictee_utilisateur) |
|
st.text_area("Voici la correction :", correction, height=500) |
|
|
|
elif mode == "Entrainer: Vous aurez uniquement le texte de la dictée pour entrainer quelqu'un d'autre (Pour 2 ou + personnes)": |
|
st.text_area("Voici votre dictée :", dictee, height=300) |