keshan commited on
Commit
d6f4621
1 Parent(s): 615ba9a

Adding finetuned model as a demo

Browse files
Files changed (1) hide show
  1. app.py +49 -18
app.py CHANGED
@@ -1,28 +1,59 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
3
 
4
- st.title('Sinhala Text generation with GPT2')
5
- st.markdown('A simple demo using Sinhala-gpt2 model trained during hf-flax week')
6
 
7
- seed = st.text_input('Starting text', 'ආයුබෝවන්')
8
- seq_num = st.number_input('Number of sentences to generate ', 1, 20, 5)
9
- max_len = st.number_input('Length of the sentence ', 5, 300, 100)
 
 
 
 
10
 
11
- go = st.button('Generate')
12
- with st.spinner('Waiting for the model to load.....'):
13
- model = AutoModelForCausalLM.from_pretrained('flax-community/Sinhala-gpt2')
14
- tokenizer = AutoTokenizer.from_pretrained('flax-community/Sinhala-gpt2')
15
- st.success('Model loaded!!')
16
 
 
 
 
 
 
 
 
 
17
 
18
- if go:
19
- try:
20
- with st.spinner('Generating...'):
21
- generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
22
- seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
23
- st.write(seqs)
24
- except Exception as e:
25
- st.exception(f'Exception: {e}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  st.markdown('____________')
28
  st.markdown('by Keshan with Flax Community')
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
+ # from huggingface_hub import snapshot_download
4
 
5
+ page = st.sidebar.selectbox("Model ", ["Pretrained GPT2", "Finetuned on News data"])
 
6
 
7
+ def load_model(model_name):
8
+ with st.spinner('Waiting for the model to load.....'):
9
+ # snapshot_download('flax-community/Sinhala-gpt2')
10
+ model = AutoModelForCausalLM.from_pretrained(model_name)
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ st.success('Model loaded!!')
13
+ return model, tokenizer
14
 
15
+ seed = st.sidebar.text_input('Starting text', 'ආයුබෝවන්')
16
+ seq_num = st.sidebar.number_input('Number of sentences to generate ', 1, 20, 5)
17
+ max_len = st.sidebar.number_input('Length of the sentence ', 5, 300, 100)
 
 
18
 
19
+ if page == "Finetuned on News data":
20
+
21
+ st.title('Sinhala Text generation with Finetuned GPT2')
22
+ st.markdown('This model has been finetuned Sinhala-gpt2 model with 6000 news articles(~12MB)')
23
+
24
+ # seed = st.text_input('Starting text', 'ආයුබෝවන්')
25
+ # seq_num = st.number_input('Number of sentences to generate ', 1, 20, 5)
26
+ # max_len = st.number_input('Length of the sentence ', 5, 300, 100)
27
 
28
+ gen_news = st.button('Generate')
29
+ model, tokenizer = load_model('keshan/sinhala-gpt2-newswire')
30
+
31
+
32
+ if gen_news:
33
+ try:
34
+ with st.spinner('Generating...'):
35
+ generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
36
+ seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
37
+ st.write(seqs)
38
+ except Exception as e:
39
+ st.exception(f'Exception: {e}')
40
+ else:
41
+ st.title('Sinhala Text generation with GPT2')
42
+ st.markdown('A simple demo using Sinhala-gpt2 model trained during hf-flax week')
43
+
44
+ gen_gpt2 = st.button('Generate')
45
+ model, tokenizer = load_model('flax-community/Sinhala-gpt2')
46
+
47
+
48
+ if gen_gpt2:
49
+ try:
50
+ with st.spinner('Generating...'):
51
+ generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
52
+ seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
53
+ st.write(seqs)
54
+ except Exception as e:
55
+ st.exception(f'Exception: {e}')
56
+
57
 
58
  st.markdown('____________')
59
  st.markdown('by Keshan with Flax Community')