prateekagrawal commited on
Commit
c320b57
1 Parent(s): e668d73

Updated flax error in inference

Browse files
Files changed (1) hide show
  1. apps/inference.py +6 -2
apps/inference.py CHANGED
@@ -15,7 +15,11 @@ predicted_sentence = []
15
  @st.cache(show_spinner=False, persist=True)
16
  def load_model(masked_text, model_name):
17
 
18
- model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=True)
 
 
 
 
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
  nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)
21
 
@@ -59,7 +63,7 @@ def app():
59
 
60
  for i in range(len(selected_models)):
61
  filled_sentence = load_model(masked_text, selected_models[i])
62
- st.write(filled_sentence)
63
  models.append(selected_models[i])
64
  predicted_tokens.append(filled_sentence[0]["token_str"])
65
  predicted_sentence.append(filled_sentence[0]["sequence"])
 
15
  @st.cache(show_spinner=False, persist=True)
16
  def load_model(masked_text, model_name):
17
 
18
+ from_flax = False
19
+ if model_name == "flax-community/roberta-hindi":
20
+ from_flax = True
21
+
22
+ model = AutoModelForMaskedLM.from_pretrained(model_name, from_flax=from_flax)
23
  tokenizer = AutoTokenizer.from_pretrained(model_name)
24
  nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)
25
 
 
63
 
64
  for i in range(len(selected_models)):
65
  filled_sentence = load_model(masked_text, selected_models[i])
66
+ # st.write(filled_sentence)
67
  models.append(selected_models[i])
68
  predicted_tokens.append(filled_sentence[0]["token_str"])
69
  predicted_sentence.append(filled_sentence[0]["sequence"])