prateekagrawal commited on
Commit
2fe741d
1 Parent(s): 41f5a3a

Updated requirements.txt

Browse files
Files changed (2) hide show
  1. apps/inference.py +3 -3
  2. requirements.txt +3 -1
apps/inference.py CHANGED
@@ -11,13 +11,13 @@ def load_model(masked_text,model_name):
11
 
12
  model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True)
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
- tokenizer.save_pretrained('exported_pytorch_model')
15
- model.save_pretrained('exported_pytorch_model')
16
  nlp = pipeline('fill-mask', model="exported_pytorch_model")
17
 
18
  result_sentence = nlp(masked_text)
19
 
20
- return result_sentence
21
 
22
  def app():
23
  st.markdown("<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>", unsafe_allow_html=True)
 
11
 
12
  model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True)
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ # tokenizer.save_pretrained('exported_pytorch_model')
15
+ # model.save_pretrained('exported_pytorch_model')
16
  nlp = pipeline('fill-mask', model="exported_pytorch_model")
17
 
18
  result_sentence = nlp(masked_text)
19
 
20
+ return result_sentence[0]['sequence']
21
 
22
  def app():
23
  st.markdown("<h1 style='text-align: center; color: green;'>RoBERTa Hindi</h1>", unsafe_allow_html=True)
requirements.txt CHANGED
@@ -1,4 +1,6 @@
1
  streamlit
2
  torch
3
  transformers
4
- jax
 
 
 
1
  streamlit
2
  torch
3
  transformers
4
+ jax
5
+ jaxlib
6
+ flax