Spaces:
Runtime error
Runtime error
import streamlit as st | |
from sampler import GPT2SentencesGenerator, SamplingSentencesGenerator, GreedySentencesGenerator, BeamSentencesGenerator | |
import meta | |
st.set_page_config(layout="wide") | |
def get_sentences_generator(method): | |
model_dir = "antoiloui/belgpt2" | |
if method =='sampling': | |
return SamplingSentencesGenerator(model_dir) | |
elif method == 'greedy': | |
return GreedySentencesGenerator(model_dir) | |
elif method == 'beam': | |
return BeamSentencesGenerator(model_dir) | |
else: | |
return NotImplementedError | |
def display_parameters_from_generation_method(methode): | |
user_input = st.text_input('Texte de départ (peut être vide)', "les modèles génératifs sont cool !") | |
if methode == 'sampling': | |
user_nsamples = st.number_input("Nombre de phrases", min_value=1, max_value = 20) | |
user_temperature = st.slider("Choisissez une température : le degré de 'folie' du texte", min_value = 0.01, max_value = 1.5, value = 0.7) | |
user_top_k = st.slider(" TOP K : choisissez parmi les K mots les plus probables dans la génération", min_value = 0, max_value = 1000, value=0) | |
user_top_p = st.slider(" TOP P : choisissez parmi le pourcentage des mots les plus probables dans la génération", min_value = 0.5, max_value = 0.99, value = 0.9) | |
args_dict= {"contexte":user_input, "nsamples":user_nsamples, "temperature":user_temperature, "top_p":user_top_p, "top_k":user_top_k} | |
elif methode == 'greedy': | |
args_dict= {"contexte":user_input} | |
elif methode == 'beam': | |
user_num_beams = st.number_input("Nombre de faisceaux de probabilités", min_value = 2, max_value = 10) | |
user_nsamples = st.number_input("Nombre de phrases", min_value=1, max_value = 10) | |
args_dict= {"contexte":user_input, "num_beams":user_num_beams, "nsamples":user_nsamples} | |
else: | |
st.write("Les autres méthodes arrivent !!!") | |
return args_dict | |
def display_principles(): | |
for k,sentences in meta.body.items(): | |
st.header(k) | |
for s in sentences: | |
st.write(s) | |
def main(): | |
st.title(meta.TITLE) | |
display_principles() | |
methode = st.selectbox("Choisissez votre méthode de génération", ['sampling', 'greedy', 'beam']) | |
generator = get_sentences_generator(methode) | |
st.header(f"Paramètres de la méthode __{methode}__") | |
args_dict = display_parameters_from_generation_method(methode) | |
if st.button('Parle, beau parleur !'): | |
res = generator.generate(**args_dict) | |
for texte in res: | |
st.write(texte) | |
if __name__=='__main__': | |
main() | |