mlkorra commited on
Commit
d36be85
2 Parent(s): 378e869 90dc487

Merge branch 'main' of https://huggingface.co/spaces/flax-community/roberta-hindi into main

Browse files
Files changed (1) hide show
  1. app.py +8 -10
app.py CHANGED
@@ -11,19 +11,19 @@ def load_model(masked_text,model_name):
11
 
12
  model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True)
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- tokenizer.save_pretrained('exported_pytorch_model')
15
- model.save_pretrained('exported_pytorch_model')
16
- nlp = pipeline('fill-mask', model="exported_pytorch_model")
17
 
18
  result_sentence = nlp(masked_text)
19
 
20
- return result_sentence[0]['sequence']
21
 
22
  def main():
23
 
24
- st.title("RoBERTa-Hindi")
25
  st.markdown(
26
- "This demo uses pretrained RoBERTa variants for Mask Language Modelling (MLM)"
27
  )
28
 
29
  models = st.multiselect(
@@ -48,11 +48,9 @@ def main():
48
  selected_model = models[0]
49
 
50
  if st.button('Fill the Mask!'):
51
-
52
- with st.spinner("Filling ... "):
53
  filled_sentence = load_model(masked_text,selected_model)
54
-
55
- st.write(filled_sentence)
56
 
57
 
58
  if __name__ == "__main__":
11
 
12
  model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True)
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ # tokenizer.save_pretrained('exported_pytorch_model')
15
+ # model.save_pretrained('exported_pytorch_model')
16
+ nlp = pipeline('fill-mask', model=model, tokenizer=tokenizer)
17
 
18
  result_sentence = nlp(masked_text)
19
 
20
+ return result_sentence
21
 
22
  def main():
23
 
24
+ st.title("RoBERTa Hindi")
25
  st.markdown(
26
+ "This demo uses pretrained RoBERTa variants for Mask Language Modeling (MLM)"
27
  )
28
 
29
  models = st.multiselect(
48
  selected_model = models[0]
49
 
50
  if st.button('Fill the Mask!'):
51
+ with st.spinner("Filling the Mask..."):
 
52
  filled_sentence = load_model(masked_text,selected_model)
53
+ st.write(filled_sentence)
 
54
 
55
 
56
  if __name__ == "__main__":