File size: 4,883 Bytes
f8987ba
ca404cd
06452a1
 
ca404cd
d6f4621
f8987ba
77b63e6
06452a1
f8987ba
d6f4621
06452a1
 
77b63e6
8f192a0
06452a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca404cd
 
 
 
 
 
 
 
 
 
 
4fc8074
ca404cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06452a1
77b63e6
 
549927d
f8987ba
77b63e6
d6f4621
 
77b63e6
d6f4621
 
06452a1
 
 
ca404cd
06452a1
ca404cd
 
 
 
 
d6f4621
 
 
77b63e6
 
549927d
77b63e6
 
d6f4621
 
77b63e6
d6f4621
 
06452a1
 
 
ca404cd
06452a1
ca404cd
 
 
 
 
 
 
 
 
d6f4621
 
f8987ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import streamlit as st
import streamlit.components.v1 as component

from googletrans import Translator
from model import load_model
# from huggingface_hub import snapshot_download

# Sidebar controls: model choice plus the generation parameters shared by both pages.
page = st.sidebar.selectbox("Model ", ["Finetuned on News data", "Pretrained GPT2"])
translator = Translator()  # googletrans client used to translate generated Sinhala to English

seed = st.sidebar.text_input('Starting text', 'ආයුබෝවන්')  # default seed is a Sinhala greeting
seq_num = st.sidebar.number_input('Number of sequences to generate ', 1, 20, 5)
max_len = st.sidebar.number_input('Length of a sequence ', 5, 300, 100)
gen_bt = st.sidebar.button('Generate')  # True only on the rerun triggered by a click

def generate(model, tokenizer, seed, seq_num, max_len):
    """Sample `seq_num` text sequences from `seed` with the given model.

    Args:
        model: a HuggingFace causal LM exposing `.generate`.
        tokenizer: the tokenizer matching `model` (used to encode/decode).
        seed: starting text the generation is conditioned on.
        seq_num: number of sequences to return.
        max_len: maximum token length of each generated sequence.

    Returns:
        A list of `seq_num` decoded strings (special tokens stripped).
    """
    input_ids = tokenizer.encode(seed, return_tensors='pt')
    beam_outputs = model.generate(
        input_ids,
        do_sample=True,          # sample instead of deterministic decoding
        max_length=max_len,
        top_k=50,                # restrict sampling to the 50 most likely tokens
        top_p=0.95,              # nucleus-sampling probability mass
        temperature=0.7,
        num_return_sequences=seq_num,
        no_repeat_ngram_size=2,  # forbid repeating any 2-gram
        early_stopping=True,
    )
    # Comprehension instead of the original manual append loop.
    return [tokenizer.decode(out, skip_special_tokens=True) for out in beam_outputs]


def html(body):
    """Inject *body* into the page as raw (unescaped) HTML via st.markdown."""
    markup = body
    st.markdown(markup, unsafe_allow_html=True)


def card_begin_str(sinhala_sentence):
    """Return the opening HTML of a result card showing *sinhala_sentence*.

    The markup is deliberately left unclosed; callers insert extra content
    and finish the card with `card_end_str()`.  (Parameter renamed to
    snake_case for consistency with `card`/`card_html`; all call sites in
    this file pass it positionally.)
    """
    return (
        "<style>div.card{background-color:#023b1d;border-radius: 5px;box-shadow: 0 4px 8px 0 rgba(0,0,0,0.2);transition: 0.3s;} small{ margin: 5px;}</style>"
        '<div class="card">'
        '<div class="container">'
        f'<small><font color="white">{sinhala_sentence}</font></small>'
    )


def card_end_str():
    """Close the two divs opened by `card_begin_str`."""
    return "</div>" * 2


def card(sinhala_sentence, english_sentence):
    """Render one complete card: Sinhala text plus its English translation."""
    markup = card_begin_str(sinhala_sentence)
    markup += f"<p>{english_sentence}</p>"
    markup += card_end_str()
    html(markup)


def br(n):
    """Insert *n* HTML line breaks into the page."""
    html("<br>" * n)

def card_html(sinhala_sentence, english_sentence):
    """Render a styled card via a Streamlit HTML component.

    Reads ./app.css on every call and inlines it into the component markup.

    Args:
        sinhala_sentence: generated Sinhala text shown as the card title.
        english_sentence: translated English text shown below it.

    Returns:
        The value returned by `component.html` for the rendered component.
    """
    # Explicit encoding: the CSS is inlined into HTML, and the platform
    # default encoding is not guaranteed to be UTF-8.
    with open('./app.css', encoding='utf-8') as f:
        css_file = f.read()
    return component.html(
    f"""
    <style>{css_file}</style>
    <article class="class_1 bg-white rounded-lg p-4 relative">
    <p class="font-bold items-center text-sm text-primary relative mb-1">{sinhala_sentence}</p>
    
    <div class="flex items-center text-white-400 mb-4">
    <i class="fab fa-google mx-2"></i>
      <small class="text-white-400">English Translations are by Google Translate</small>   
    </div>
     
    <p class="not-italic items-center text-sm text-primary relative mb-4">
      {english_sentence}
    </p>
  </article>
    """
    )
def _show_generations(model_name):
    """Load *model_name*, generate sequences on button press, and render each
    one as a card with its Google-translated English text."""
    model, tokenizer = load_model(model_name)

    if gen_bt:
        try:
            with st.spinner('Generating...'):
                seqs = generate(model, tokenizer, seed, seq_num, max_len)
            st.warning("English sentences were translated by Google Translate.")
            for seq in seqs:
                english_sentence = translator.translate(seq, src='si', dest='en').text
                html(card_begin_str(seq))
                st.info(english_sentence)
                html(card_end_str())
        except Exception as e:
            # Pass the exception object itself: st.exception expects an
            # Exception and renders its traceback (a formatted string loses it).
            st.exception(e)


# The two pages differed only in title/blurb and model name; the shared
# generation/rendering logic now lives in _show_generations above.
if page == 'Pretrained GPT2':
    st.title('Sinhala Text generation with GPT2')
    st.markdown('A simple demo using [Sinhala-gpt2 model](https://huggingface.co/flax-community/Sinhala-gpt2) trained during hf-flax week')
    _show_generations('flax-community/Sinhala-gpt2')
else:
    st.title('Sinhala Text generation with Finetuned GPT2')
    st.markdown('This model has been [finetuned Sinhala-gpt2 model](https://huggingface.co/keshan/sinhala-gpt2-newswire) with 6000 news articles(~12MB)')
    _show_generations('keshan/sinhala-gpt2-newswire')

st.markdown('____________')