keshan committed on
Commit
06452a1
1 Parent(s): 77b63e6

modifying generation pipeline

Browse files
Files changed (2) hide show
  1. app.py +40 -13
  2. requirements.txt +1 -2
app.py CHANGED
@@ -1,22 +1,43 @@
1
  import streamlit as st
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
  # from huggingface_hub import snapshot_download
4
 
5
  page = st.sidebar.selectbox("Model ", ["Finetuned on News data", "Pretrained GPT2"])
 
6
 
7
  def load_model(model_name):
8
  with st.spinner('Waiting for the model to load.....'):
9
  # snapshot_download('flax-community/Sinhala-gpt2')
10
- model = AutoModelForCausalLM.from_pretrained(model_name)
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
12
  st.success('Model loaded!!')
13
  return model, tokenizer
14
 
15
  seed = st.sidebar.text_input('Starting text', 'ආයුබෝවන්')
16
- seq_num = st.sidebar.number_input('Number of sentences to generate ', 1, 20, 5)
17
- max_len = st.sidebar.number_input('Length of the sentence ', 5, 300, 100)
18
  gen_bt = st.sidebar.button('Generate')
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  if page == 'Pretrained GPT2':
21
  st.title('Sinhala Text generation with GPT2')
22
  st.markdown('A simple demo using Sinhala-gpt2 model trained during hf-flax week')
@@ -27,9 +48,14 @@ if page == 'Pretrained GPT2':
27
  if gen_bt:
28
  try:
29
  with st.spinner('Generating...'):
30
- generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
31
- seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
32
- st.write(seqs)
 
 
 
 
 
33
  except Exception as e:
34
  st.exception(f'Exception: {e}')
35
  else:
@@ -43,13 +69,14 @@ else:
43
  if gen_bt:
44
  try:
45
  with st.spinner('Generating...'):
46
- generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
47
- seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
48
- st.write(seqs)
 
 
 
 
 
49
  except Exception as e:
50
  st.exception(f'Exception: {e}')
51
-
52
-
53
  st.markdown('____________')
54
- st.markdown('by Keshan with Flax Community')
55
-
1
  import streamlit as st
2
+
3
+ from googletrans import Translator
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
  # from huggingface_hub import snapshot_download
6
 
7
  page = st.sidebar.selectbox("Model ", ["Finetuned on News data", "Pretrained GPT2"])
8
+ translator = Translator()
9
 
10
  def load_model(model_name):
11
  with st.spinner('Waiting for the model to load.....'):
12
  # snapshot_download('flax-community/Sinhala-gpt2')
 
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
15
  st.success('Model loaded!!')
16
  return model, tokenizer
17
 
18
  seed = st.sidebar.text_input('Starting text', 'ආයුබෝවන්')
19
+ seq_num = st.sidebar.number_input('Number of sequences to generate ', 1, 20, 5)
20
+ max_len = st.sidebar.number_input('Length of a sequence ', 5, 300, 100)
21
  gen_bt = st.sidebar.button('Generate')
22
 
23
def generate(model, tokenizer, seed, seq_num, max_len):
    """Sample ``seq_num`` continuations of ``seed`` from a causal language model.

    Decoding is top-k / nucleus sampling (``top_k=50``, ``top_p=0.95``,
    ``temperature=0.7``) with repeated bigrams forbidden. No beam search is
    involved, so beam-only options are not passed.

    Args:
        model: Model exposing ``generate`` (as returned by ``load_model``).
        tokenizer: Tokenizer matching ``model``; used to encode the prompt
            and decode the sampled ids.
        seed: Prompt text that every returned sequence starts from.
        seq_num: Number of independent sequences to sample.
        max_len: Value forwarded as ``max_length`` to ``model.generate``.

    Returns:
        A list of decoded strings, one per sampled sequence, with special
        tokens stripped.

    Raises:
        ValueError: If ``seed`` is empty.
    """
    # Reject an empty prompt up front so the UI shows a clear message
    # instead of an error raised from inside generation.
    if not seed:
        raise ValueError('Starting text must not be empty.')

    # Calling the tokenizer (rather than ``encode``) also yields the
    # attention mask; passing it explicitly matters because the pad token is
    # configured to be the EOS token, so the mask cannot be inferred.
    encoded = tokenizer(seed, return_tensors='pt')
    sampled = model.generate(
        encoded['input_ids'],
        attention_mask=encoded.get('attention_mask'),
        do_sample=True,
        max_length=max_len,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=seq_num,
        no_repeat_ngram_size=2,
    )
    return [
        tokenizer.decode(token_ids, skip_special_tokens=True)
        for token_ids in sampled
    ]
40
+
41
  if page == 'Pretrained GPT2':
42
  st.title('Sinhala Text generation with GPT2')
43
  st.markdown('A simple demo using Sinhala-gpt2 model trained during hf-flax week')
48
  if gen_bt:
49
  try:
50
  with st.spinner('Generating...'):
51
+ # generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
52
+ # seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
53
+ seqs = generate(model, tokenizer, seed, seq_num, max_len)
54
+ for i, seq in enumerate(seqs):
55
+ st.info(f'Generated sequence {i+1}:')
56
+ st.write(seq)
57
+ st.info(f'English translation (by Google Translation):')
58
+ st.write(translator.translate(seq, src='si', dest='en').text)
59
  except Exception as e:
60
  st.exception(f'Exception: {e}')
61
  else:
69
  if gen_bt:
70
  try:
71
  with st.spinner('Generating...'):
72
+ # generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
73
+ # seqs = generator(seed, max_length=max_len, num_return_sequences=seq_num)
74
+ seqs = generate(model, tokenizer, seed, seq_num, max_len)
75
+ for i, seq in enumerate(seqs):
76
+ st.info(f'Generated sequence {i+1}:')
77
+ st.write(seq)
78
+ st.info(f'English translation (by Google Translation):')
79
+ st.write(translator.translate(seq, src='si', dest='en').text)
80
  except Exception as e:
81
  st.exception(f'Exception: {e}')
 
 
82
  st.markdown('____________')
 
 
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  transformers
2
  streamlit
3
- jax
4
  torch
5
- flax
1
  transformers
2
  streamlit
 
3
  torch
4
+ googletrans==3.1.0a