import streamlit as st
from .services import TextGeneration
from tokenizers import Tokenizer  # kept: used by st.cache hash_funcs if caching is switched back
from functools import lru_cache


# lru_cache(maxsize=1) makes this a process-wide singleton loader, replacing
# the earlier st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str}).
@lru_cache(maxsize=1)
def load_text_generator():
    """Instantiate and load the text-generation backend exactly once."""
    generator = TextGeneration()
    generator.load()
    return generator


generator = load_text_generator()

# Prompt fragments that wrap the user's question for zero-shot QA.
# (Arabic: "Answer the following question:" / "The answer is" / "in the year:")
qa_prompt = """
أجب عن السؤال التالي:
"""
qa_prompt_post = """
الجواب هو
"""
qa_prompt_post_year = """
في سنة:
"""


def write():
    """Render the Streamlit page: a free-form generation section and a QA section.

    Side effects only (Streamlit widget rendering); returns None.
    """
    st.markdown(
        """
Use the generation parameters on the sidebar to adjust generation quality.
""",
        unsafe_allow_html=True,
    )

    # Sidebar controls for the generation parameters.
    # BUG FIX: these names were referenced in generator.generate(...) below but
    # never defined in this module, raising NameError as soon as either button
    # was clicked. The page text above already promised sidebar controls.
    # NOTE(review): model names and parameter ranges are reasonable defaults —
    # confirm against what TextGeneration.generate() actually accepts.
    st.sidebar.header("Generation Parameters")
    model_name = st.sidebar.selectbox(
        "Model",
        ["aragpt2-base", "aragpt2-medium", "aragpt2-large", "aragpt2-mega"],
    )
    max_new_tokens = st.sidebar.slider("Max new tokens", 8, 256, 64)
    temp = st.sidebar.slider("Temperature", 0.1, 2.0, 1.0)
    top_k = st.sidebar.slider("Top-k", 0, 100, 50)
    top_p = st.sidebar.slider("Top-p", 0.0, 1.0, 0.95)
    repetition_penalty = st.sidebar.slider("Repetition penalty", 1.0, 3.0, 1.0)
    no_repeat_ngram_size = st.sidebar.slider("No-repeat n-gram size", 0, 10, 3)
    do_sample = st.sidebar.checkbox("Use sampling", value=True)
    num_beams = st.sidebar.slider("Number of beams", 1, 10, 1)

    st.markdown(
        """
""",
        unsafe_allow_html=True,
    )
    prompt = st.text_area(
        "Prompt",
        "يحكى أن مزارعا مخادعا قام ببيع بئر الماء الموجود في أرضه لجاره مقابل مبلغ كبير من المال",
    )
    if st.button("Generate"):
        with st.spinner("Generating..."):
            generated_text = generator.generate(
                prompt=prompt,
                model_name=model_name,
                max_new_tokens=max_new_tokens,
                temperature=temp,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                num_beams=num_beams,
                no_repeat_ngram_size=no_repeat_ngram_size,
            )
        st.write(generated_text)

    st.markdown("---")
    st.subheader("")
    st.markdown(
        """
Adjust the maximum length to closely match the expected output length. Setting the Sampling parameter to False is recommended
""",
        unsafe_allow_html=True,
    )
    question = st.text_input(
        "Question", "من كان رئيس ألمانيا النازية في الحرب العالمية الثانية ؟"
    )
    is_date = st.checkbox("Help the model: Is the answer a date?")
    if st.button("Answer"):
        # Wrap the question in the QA prompt template; append the year cue
        # when the user marks the answer as a date, otherwise a plain colon.
        prompt2 = qa_prompt + question + qa_prompt_post
        if is_date:
            prompt2 += qa_prompt_post_year
        else:
            prompt2 += " : "
        with st.spinner("Thinking..."):
            answer = generator.generate(
                prompt=prompt2,
                model_name=model_name,
                max_new_tokens=max_new_tokens,
                temperature=temp,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                num_beams=num_beams,
                no_repeat_ngram_size=no_repeat_ngram_size,
            )
        st.write(answer)