Abinaya Mahendiran committed on
Commit
36338f2
1 Parent(s): ebee998

Updated app

Browse files
Files changed (3) hide show
  1. app.py +61 -15
  2. config.json +5 -3
  3. images/tamil_logo.jpg +0 -0
app.py CHANGED
@@ -3,35 +3,81 @@
3
  """
4
 
5
  # Install necessary libraries
6
- from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline
7
  import streamlit as st
8
  from pprint import pprint
9
  import json
10
 
11
  # Read the config
12
  with open("config.json") as f:
13
- cfg = json.loads(f.read())
14
 
15
  # Set page layout
16
- st.set_page_config(layout="wide")
 
 
 
 
17
 
18
  # Load the model
19
  @st.cache(allow_output_mutation=True)
20
- def load_model():
21
- tokenizer = AutoTokenizer.from_pretrained(cfg["model_name_or_path"])
22
- model = GPT2LMHeadModel.from_pretrained(cfg["model_name_or_path"])
23
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
24
- return generator, tokenizer
 
25
 
26
- with st.spinner('Loading model...'):
27
- generator, tokenizer = load_model()
28
 
29
- # st.image("images/chef-transformer.png", width=400)
 
 
 
 
30
  st.header("Tamil Language Demos")
31
  st.markdown(
32
  "This demo uses [GPT2 trained on Oscar dataset](https://huggingface.co/flax-community/gpt-2-tamil) "
33
- "to show language generation and other downstream tasks"
 
34
  )
35
- img = st.sidebar.image("images/tamil_logo.png", width=100)
36
- add_text_sidebar = st.sidebar.title("Select demo:")
37
- sampling_mode = st.sidebar.selectbox("select a demo", index=0, options=["Text Generation", "Text Classification"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  """
4
 
5
  # Install necessary libraries
6
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline
7
  import streamlit as st
8
  from pprint import pprint
9
  import json
10
 
11
  # Read the config
12
  with open("config.json") as f:
13
+ config = json.loads(f.read())
14
 
15
  # Set page layout
16
+ st.set_page_config(
17
+ page_title="Tamil Language Models",
18
+ layout="wide",
19
+ initial_sidebar_state="expanded"
20
+ )
21
 
22
  # Load the model
23
  @st.cache(allow_output_mutation=True)
24
+ def load_model(model_name):
25
+ with st.spinner('Waiting for the model to load.....'):
26
+ model = AutoModelWithLMHead.from_pretrained(model_name)
27
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
28
+ st.success('Model loaded!!')
29
+ return model, tokenizer
30
 
31
+ # Side bar
32
+ img = st.sidebar.image("images/tamil_logo.jpg", width=380)
33
 
34
+ # Choose the model based on selection
35
+ page = st.sidebar.selectbox("Model", config["models"])
36
+ data = st.sidebar.selectbox("Data", config[page])
37
+
38
+ # Main page
39
  st.header("Tamil Language Demos")
40
  st.markdown(
41
  "This demo uses [GPT2 trained on Oscar dataset](https://huggingface.co/flax-community/gpt-2-tamil) "
42
+ "and [GPT2 trained on Oscar & Indic Corpus dataset] (https://huggingface.co/abinayam/gpt-2-tamil) "
43
+ "to show language generation"
44
  )
45
+
46
+ if page == 'Text Generation' and data == 'Oscar':
47
+ st.title('Tamil text generation with GPT2')
48
+ st.markdown('A simple demo using gpt-2-tamil model trained on Oscar data')
49
+ model, tokenizer = load_model(config[data])
50
+ # Set default options
51
+ seed = st.text_input('Starting text', 'அகர முதல எழுதெல்லம்')
52
+ #seq_num = st.number_input('Number of sentences to generate ', 1, 20, 5)
53
+ max_len = st.number_input('Length of the sentence', 5, 300, 100)
54
+ gen_bt = st.button('Generate')
55
+ if gen_bt:
56
+ try:
57
+ with st.spinner('Generating...'):
58
+ generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
59
+ seqs = generator(seed, max_length=max_len) # num_return_sequences=seq_num)
60
+ st.write(seqs)
61
+ except Exception as e:
62
+ st.exception(f'Exception: {e}')
63
+ elif page == 'Text Generation' and data == "Oscar + Indic Corpus":
64
+ st.title('Tamil text generation with GPT2')
65
+ st.markdown('A simple demo using gpt-2-tamil model trained on Oscar data')
66
+ model, tokenizer = load_model(config[data])
67
+ # Set default options
68
+ seed = st.text_input('Starting text', 'அகர முதல எழுதெல்லம்')
69
+ #seq_num = st.number_input('Number of sentences to generate ', 1, 20, 5)
70
+ max_len = st.number_input('Length of the sentence', 5, 300, 100)
71
+ gen_bt = st.button('Generate')
72
+ if gen_bt:
73
+ try:
74
+ with st.spinner('Generating...'):
75
+ generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
76
+ seqs = generator(seed, max_length=max_len) #num_return_sequences=seq_num)
77
+ st.write(seqs)
78
+ except Exception as e:
79
+ st.exception(f'Exception: {e}')
80
+ else:
81
+ st.title('Tamil News classification with Finetuned GPT2')
82
+ st.markdown('In progress')
83
+
config.json CHANGED
@@ -1,5 +1,7 @@
1
  {
2
- "model_name_or_path": "flax-community/gpt-2-tamil",
3
- "Text Generation": ["example_1", "example_2"],
4
- "Text Classification": ["example_2", "example_2"]
 
 
5
  }
1
  {
2
+ "models": ["Text Generation", "Text Classification"],
3
+ "Text Generation": ["Oscar", "Oscar + Indic Corpus"],
4
+ "Text Classification": ["News Data"],
5
+ "Oscar": "flax-community/gpt-2-tamil",
6
+ "Oscar + Indic Corpus": "abinayam/gpt-2-tamil"
7
  }
images/tamil_logo.jpg ADDED