prateekagrawal committed on
Commit
a28a899
1 Parent(s): ce82499

updated inference.py

Browse files
Files changed (1) hide show
  1. apps/inference.py +30 -22
apps/inference.py CHANGED
@@ -1,50 +1,58 @@
1
  from pandas.io.formats.format import return_docstring
2
  import streamlit as st
3
  import pandas as pd
4
- from transformers import AutoTokenizer,AutoModelForMaskedLM
5
  from transformers import pipeline
6
  import os
7
  import json
8
 
@st.cache(show_spinner=False,persist=True)
def load_model(masked_text,model_name):
    """Fill the mask in ``masked_text`` using the checkpoint ``model_name``.

    A masked-LM model (loaded with ``from_flax=True``) and its tokenizer are
    fetched for ``model_name`` and wrapped in a ``fill-mask`` pipeline, which
    is then applied to ``masked_text``.

    The function is wrapped in ``st.cache`` with the spinner disabled and
    ``persist=True``.

    Args:
        masked_text: Text passed to the fill-mask pipeline.
        model_name: Model identifier handed to ``from_pretrained``.

    Returns:
        The ``'sequence'`` entry of the first element of the pipeline output.
    """
    # Model is loaded before the tokenizer, matching the original call order.
    fill_mask = pipeline(
        'fill-mask',
        model=AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
    )

    predictions = fill_mask(masked_text)
    best_prediction = predictions[0]
    return best_prediction['sequence']
 
19
 
def app():
    """Render the RoBERTa Hindi mask-filling demo page.

    The page shows a title and a short description, reads the candidate
    sentences from the ``text`` column of
    ``./mlm_custom/mlm_targeted_text.csv``, lets the user pick one sentence
    and one or more models, and, when the button is pressed, writes the
    output of ``load_model`` for the chosen sentence.

    Only the first selected model is used. If the selection is empty, a
    warning is shown and nothing else is rendered, instead of failing with an
    ``IndexError`` on ``models[0]``.
    """
    st.markdown("<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>", unsafe_allow_html=True)
    st.markdown(
        "This demo uses pretrained RoBERTa variants for Mask Language Modelling (MLM)"
    )

    target_text_path = './mlm_custom/mlm_targeted_text.csv'
    target_text_df = pd.read_csv(target_text_path)

    texts = target_text_df['text']

    st.markdown("""## Select any of the following text : """)
    masked_text = st.selectbox('',
                               texts)

    st.write('You selected:', masked_text)

    models = st.multiselect(
        "Choose models",
        ['flax-community/roberta-hindi','mrm8488/HindiBERTa','ai4bharat/indic-bert',
         'neuralspace-reverie/indic-transformers-hi-bert',
         'surajp/RoBERTa-hindi-guj-san'],
        ["flax-community/roberta-hindi"]
    )

    # The user can remove every entry from the multiselect; indexing an empty
    # selection would raise IndexError, so stop early with a hint instead.
    if not models:
        st.warning("Please choose at least one model.")
        return

    selected_model = models[0]

    if st.button('Fill the Mask!'):
        with st.spinner("Filling the Mask..."):
            filled_sentence = load_model(masked_text,selected_model)
        # NOTE(review): kept inside the button branch because
        # ``filled_sentence`` is only bound here -- confirm intended nesting.
        st.write(filled_sentence)
 
1
  from pandas.io.formats.format import return_docstring
2
  import streamlit as st
3
  import pandas as pd
4
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
5
  from transformers import pipeline
6
  import os
7
  import json
8
 
9
+
@st.cache(show_spinner=False, persist=True)
def load_model(masked_text, model_name):
    """Run a fill-mask pipeline for ``model_name`` on ``masked_text``.

    A masked-LM model (loaded with ``from_flax=True``) and its tokenizer are
    fetched for ``model_name`` and wrapped in a ``fill-mask`` pipeline, which
    is then applied to ``masked_text``.

    The function is wrapped in ``st.cache`` with the spinner disabled and
    ``persist=True``.

    Args:
        masked_text: Text passed to the fill-mask pipeline.
        model_name: Model identifier handed to ``from_pretrained``.

    Returns:
        The pipeline output for ``masked_text``, returned unmodified.
    """
    # Model is loaded before the tokenizer, matching the original call order.
    fill_mask = pipeline(
        "fill-mask",
        model=AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
    )
    return fill_mask(masked_text)
20
+
21
 
def app():
    """Render the RoBERTa Hindi mask-filling demo page.

    The page shows a title and a short description, reads the candidate
    sentences from the ``text`` column of
    ``./mlm_custom/mlm_targeted_text.csv``, lets the user pick one sentence
    and one or more models, and, when the button is pressed, writes the
    output of ``load_model`` for the chosen sentence.

    Only the first selected model is used. If the selection is empty, a
    warning is shown and nothing else is rendered, instead of failing with an
    ``IndexError`` on ``models[0]``.
    """
    st.markdown(
        "<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "This demo uses pretrained RoBERTa variants for Mask Language Modelling (MLM)"
    )

    target_text_path = "./mlm_custom/mlm_targeted_text.csv"
    target_text_df = pd.read_csv(target_text_path)

    texts = target_text_df["text"]

    st.markdown("""## Select any of the following text : """)
    masked_text = st.selectbox("", texts)

    st.write("You selected:", masked_text)

    models = st.multiselect(
        "Choose models",
        [
            "flax-community/roberta-hindi",
            "mrm8488/HindiBERTa",
            "ai4bharat/indic-bert",
            "neuralspace-reverie/indic-transformers-hi-bert",
            "surajp/RoBERTa-hindi-guj-san",
        ],
        ["flax-community/roberta-hindi"],
    )

    # The user can remove every entry from the multiselect; indexing an empty
    # selection would raise IndexError, so stop early with a hint instead.
    if not models:
        st.warning("Please choose at least one model.")
        return

    selected_model = models[0]

    if st.button("Fill the Mask!"):
        with st.spinner("Filling the Mask..."):
            filled_sentence = load_model(masked_text, selected_model)
        # NOTE(review): kept inside the button branch because
        # ``filled_sentence`` is only bound here -- confirm intended nesting.
        st.write(filled_sentence)