# app.py — last updated by keshan in commit 4fc8074 ("Update app.py").
import streamlit as st
import streamlit.components.v1 as component
from googletrans import Translator
from model import load_model
# from huggingface_hub import snapshot_download
# Google Translate client, reused for every generated sequence.
translator = Translator()

# Sidebar controls. They are created top-to-bottom, and that is the order
# in which Streamlit renders them, so `page` must come first.
page = st.sidebar.selectbox(
    "Model ",
    options=["Finetuned on News data", "Pretrained GPT2"],
)
seed = st.sidebar.text_input('Starting text', value='ආයුබෝවන්')
seq_num = st.sidebar.number_input(
    'Number of sequences to generate ', min_value=1, max_value=20, value=5
)
max_len = st.sidebar.number_input(
    'Length of a sequence ', min_value=5, max_value=300, value=100
)
gen_bt = st.sidebar.button('Generate')
def generate(model, tokenizer, seed, seq_num, max_len,
             top_k=50, top_p=0.95, temperature=0.7, no_repeat_ngram_size=2):
    """Sample ``seq_num`` text continuations of ``seed`` from a language model.

    Args:
        model: object exposing a Hugging Face-style ``generate`` method.
        tokenizer: object exposing ``encode`` and ``decode``.
        seed: prompt text that every continuation starts from.
        seq_num: number of sequences to sample (``num_return_sequences``).
        max_len: forwarded as ``max_length`` to ``model.generate``.
        top_k, top_p, temperature, no_repeat_ngram_size: sampling settings
            forwarded to ``model.generate``. The defaults equal the values
            this function previously hard-coded.

    Returns:
        list[str]: one decoded string per sampled sequence, with special
        tokens stripped.
    """
    input_ids = tokenizer.encode(seed, return_tensors='pt')
    # Pure sampling (do_sample=True, no num_beams given). `early_stopping`
    # was removed: it only affects beam search, and recent transformers
    # versions warn when it is set without beams.
    outputs = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_len,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        num_return_sequences=seq_num,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    return [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]
def html(body):
    """Render ``body`` through Streamlit markdown with raw HTML enabled.

    No escaping happens here (``unsafe_allow_html=True``). Callers are
    responsible for escaping any text they embed in ``body``.
    """
    st.markdown(body=body, unsafe_allow_html=True)
def card_begin_str(Sinhala_sentence):
    """Return the opening markup of a styled card that shows ``Sinhala_sentence``.

    The markup contains:
    - an inline ``<style>`` rule for ``div.card`` and ``small``;
    - the opening card and container ``<div>`` tags, which ``card_end_str``
      closes;
    - the sentence in small white text.

    The sentence is HTML-escaped. It carries user-typed seed text and model
    output, and this file renders the result via ``html()``, which uses
    ``unsafe_allow_html=True``.
    """
    # Local import: the module-level name `html` is this file's render
    # helper, which shadows the stdlib `html` module at global scope.
    from html import escape

    return (
        "<style>div.card{background-color:#023b1d;border-radius: 5px;box-shadow: 0 4px 8px 0 rgba(0,0,0,0.2);transition: 0.3s;} small{ margin: 5px;}</style>"
        '<div class="card">'
        '<div class="container">'
        f'<small><font color="white">{escape(Sinhala_sentence)}</font></small>'
    )
def card_end_str():
    """Return the markup that closes the container and card ``<div>``s opened by ``card_begin_str``."""
    closing_tag = "</div>"
    return closing_tag * 2
def card(sinhala_sentence, english_sentence):
    """Render one card: the Sinhala sentence followed by its English text.

    ``english_sentence`` (e.g. a Google Translate result) is HTML-escaped
    before it is embedded, because ``html()`` renders with
    ``unsafe_allow_html=True``. ``sinhala_sentence`` is passed through
    unchanged to ``card_begin_str``, which builds the card header.
    """
    # Local import: the module-level name `html` is this file's render
    # helper, which shadows the stdlib `html` module at global scope.
    from html import escape

    parts = [
        card_begin_str(sinhala_sentence),
        f"<p>{escape(english_sentence)}</p>",
        card_end_str(),
    ]
    html("".join(parts))
def br(n):
    """Emit ``n`` HTML line breaks. Nothing is emitted when ``n`` <= 0."""
    breaks = "<br>" * n
    html(breaks)
def card_html(sinhala_sentence, english_sentence):
    """Render a translation card as a Streamlit HTML component.

    On every call, reads ``./app.css`` (relative to the current working
    directory) and inlines it in a ``<style>`` tag. Both sentences are
    HTML-escaped before they are embedded. Returns whatever
    ``component.html`` returns.
    """
    # Local import: the module-level name `html` is this file's render
    # helper, which shadows the stdlib `html` module at global scope.
    from html import escape

    # Explicit UTF-8, so that reading the stylesheet does not depend on the
    # platform's locale encoding.
    with open('./app.css', encoding='utf-8') as f:
        css_file = f.read()
    sinhala_sentence = escape(sinhala_sentence)
    english_sentence = escape(english_sentence)
    return component.html(
        f"""
<style>{css_file}</style>
<article class="class_1 bg-white rounded-lg p-4 relative">
<p class="font-bold items-center text-sm text-primary relative mb-1">{sinhala_sentence}</p>
<div class="flex items-center text-white-400 mb-4">
<i class="fab fa-google mx-2"></i>
<small class="text-white-400">English Translations are by Google Translate</small>
</div>
<p class="not-italic items-center text-sm text-primary relative mb-4">
{english_sentence}
</p>
</article>
"""
    )
# Per-page settings. The two pages differ only in title, description and
# model id; the generate/translate/render flow below is shared by both.
# Any page other than 'Pretrained GPT2' uses the finetuned model.
if page == 'Pretrained GPT2':
    page_title = 'Sinhala Text generation with GPT2'
    page_description = 'A simple demo using [Sinhala-gpt2 model](https://huggingface.co/flax-community/Sinhala-gpt2) trained during hf-flax week'
    model_id = 'flax-community/Sinhala-gpt2'
else:
    page_title = 'Sinhala Text generation with Finetuned GPT2'
    page_description = 'This model has been [finetuned Sinhala-gpt2 model](https://huggingface.co/keshan/sinhala-gpt2-newswire) with 6000 news articles(~12MB)'
    model_id = 'keshan/sinhala-gpt2-newswire'

st.title(page_title)
st.markdown(page_description)
# NOTE(review): this line runs on every Streamlit rerun. Unless load_model
# caches internally (not visible here), the model is reloaded each time.
model, tokenizer = load_model(model_id)

if gen_bt:
    try:
        with st.spinner('Generating...'):
            seqs = generate(model, tokenizer, seed, seq_num, max_len)
            st.warning("English sentences were translated by Google Translate.")
            for seq in seqs:
                english_sentence = translator.translate(seq, src='si', dest='en').text
                html(card_begin_str(seq))
                st.info(english_sentence)
                html(card_end_str())
    except Exception as e:
        # Top-level UI boundary. st.exception expects the exception object
        # itself; passing it (rather than a formatted string) shows the real
        # exception type and traceback.
        st.exception(e)

st.markdown('____________')