"""Streamlit demo: Sinhala text generation with GPT-2.

Lets the user pick one of two GPT-2 checkpoints (pretrained Sinhala-gpt2 or a
news-finetuned variant), sample continuations of a seed string, and view each
generated Sinhala sequence alongside a Google-Translate English rendering.
"""

import streamlit as st
import streamlit.components.v1 as component
from googletrans import Translator

from model import load_model

# --- Sidebar controls -------------------------------------------------------
page = st.sidebar.selectbox("Model ", ["Finetuned on News data", "Pretrained GPT2"])
translator = Translator()
seed = st.sidebar.text_input('Starting text', 'ආයුබෝවන්')
seq_num = st.sidebar.number_input('Number of sequences to generate ', 1, 20, 5)
max_len = st.sidebar.number_input('Length of a sequence ', 5, 300, 100)
gen_bt = st.sidebar.button('Generate')


def generate(model, tokenizer, seed, seq_num, max_len):
    """Sample ``seq_num`` continuations of ``seed``, each up to ``max_len`` tokens.

    Args:
        model: a HuggingFace causal-LM with a ``.generate`` method.
        tokenizer: the matching tokenizer (``encode``/``decode``).
        seed: prompt text to continue.
        seq_num: number of independent samples to return.
        max_len: maximum token length per sample.

    Returns:
        list[str]: decoded sequences with special tokens stripped.
    """
    input_ids = tokenizer.encode(seed, return_tensors='pt')
    outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_len,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=seq_num,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    return [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]


def html(body):
    """Render a raw HTML fragment via st.markdown."""
    st.markdown(body, unsafe_allow_html=True)


def card_begin_str(sinhala_sentence):
    # NOTE(review): the original HTML tags were lost in extraction; this
    # reconstructs a minimal Bootstrap-style card opening that wraps the
    # Sinhala sentence. Confirm against ./app.css class names.
    return (
        '<div class="card">'
        '<div class="card-body">'
        f'<p class="card-text">{sinhala_sentence}</p>'
    )


def card_end_str():
    """Closing tags matching card_begin_str()."""
    return '</div></div>'


def card(sinhala_sentence, english_sentence):
    """Render one Sinhala sentence plus its English translation as a card."""
    # NOTE(review): inner markup reconstructed — original tags were lost.
    lines = [
        card_begin_str(sinhala_sentence),
        f'<p class="card-text">{english_sentence}</p>',
        card_end_str(),
    ]
    html("".join(lines))


def br(n):
    """Insert n HTML line breaks."""
    html(n * "<br>")


def card_html(sinhala_sentence, english_sentence):
    """Render a styled card via the components API, inlining ./app.css."""
    with open('./app.css') as f:
        css_file = f.read()
    # NOTE(review): markup reconstructed from context — the original tags
    # were lost in extraction; only the visible text is certain.
    return component.html(
        f"""
        <style>{css_file}</style>
        <div class="card">
          <div class="card-body">
            <p class="card-text">{sinhala_sentence}</p>
            <p class="card-note">English Translations are by Google Translate</p>
            <p class="card-text">{english_sentence}</p>
          </div>
        </div>
        """
    )


def _show_generations(model, tokenizer):
    """Generate sequences with the given model and render each one with its
    Google-Translate English rendering. Shares the code path that both page
    branches previously duplicated verbatim."""
    try:
        with st.spinner('Generating...'):
            seqs = generate(model, tokenizer, seed, seq_num, max_len)
            st.warning("English sentences were translated by Google Translate.")
            for seq in seqs:
                english_sentence = translator.translate(seq, src='si', dest='en').text
                html(card_begin_str(seq))
                st.info(english_sentence)
                html(card_end_str())
    except Exception as e:
        # st.exception expects the exception object itself (it renders the
        # traceback); passing a formatted string loses that information.
        st.exception(e)


if page == 'Pretrained GPT2':
    st.title('Sinhala Text generation with GPT2')
    st.markdown('A simple demo using [Sinhala-gpt2 model](https://huggingface.co/flax-community/Sinhala-gpt2) trained during hf-flax week')
    model, tokenizer = load_model('flax-community/Sinhala-gpt2')
    if gen_bt:
        _show_generations(model, tokenizer)
else:
    st.title('Sinhala Text generation with Finetuned GPT2')
    st.markdown('This model has been [finetuned Sinhala-gpt2 model](https://huggingface.co/keshan/sinhala-gpt2-newswire) with 6000 news articles(~12MB)')
    model, tokenizer = load_model('keshan/sinhala-gpt2-newswire')
    if gen_bt:
        _show_generations(model, tokenizer)

st.markdown('____________')