mariagrandury commited on
Commit
832ee1c
1 Parent(s): 3954682

Update application

Browse files
Files changed (1) hide show
  1. app.py +12 -11
app.py CHANGED
@@ -8,22 +8,22 @@ LOGO = "https://raw.githubusercontent.com/nlp-en-es/assets/main/logo.png"
8
 
9
  MODELS = {
10
  "RoBERTa Base": {
11
- "url": "https://api-inference.huggingface.co/models/bertin-project/bertin-roberta-base-spanish"
12
  },
13
  "RoBERTa Base Gaussian": {
14
- "url": "https://api-inference.huggingface.co/models/bertin-project/bertin-base-gaussian"
15
  },
16
  "RoBERTa Base Random": {
17
- "url": "https://api-inference.huggingface.co/models/bertin-project/bertin-base-random"
18
  },
19
  "RoBERTa Base Stepwise": {
20
- "url": "https://api-inference.huggingface.co/models/bertin-project/bertin-base-stepwise"
21
  },
22
  "RoBERTa Base Gaussian Experiment": {
23
- "url": "https://api-inference.huggingface.co/models/bertin-project/bertin-base-gaussian-exp-512seqlen"
24
  },
25
  "RoBERTa Base Random Experiment": {
26
- "url": "https://api-inference.huggingface.co/models/bertin-project/bertin-base-random-exp-512seqlen"
27
  }
28
  }
29
 
@@ -41,9 +41,9 @@ PROMPT_LIST = [
41
 
42
 
43
  @st.cache(show_spinner=False, persist=True)
44
- def load_model(masked_text, model_name):
45
- model = AutoModelForMaskedLM.from_pretrained(model_name)
46
- tokenizer = AutoTokenizer.from_pretrained(model_name)
47
  nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)
48
  result = nlp(masked_text)
49
  return result
@@ -66,7 +66,8 @@ st.markdown(
66
  """
67
  )
68
 
69
- model_name = st.selectbox("Model",(MODELS.keys()))
 
70
 
71
  prompt = st.selectbox("Prompt", ["Random", "Custom"])
72
  if prompt == "Custom":
@@ -78,7 +79,7 @@ text = st.text_area("Enter text", prompt_box)
78
  if st.button("Fill the mask"):
79
  with st.spinner(text="Getting results..."):
80
  st.subheader("Result")
81
- result = load_model(text, model_name)
82
  if "error" in result:
83
  if type(result["error"]) is str:
84
  st.write(f'{result["error"]}.', end=" ")
 
8
 
9
  MODELS = {
10
  "RoBERTa Base": {
11
+ "url": "bertin-project/bertin-roberta-base-spanish"
12
  },
13
  "RoBERTa Base Gaussian": {
14
+ "url": "bertin-project/bertin-base-gaussian"
15
  },
16
  "RoBERTa Base Random": {
17
+ "url": "bertin-project/bertin-base-random"
18
  },
19
  "RoBERTa Base Stepwise": {
20
+ "url": "bertin-project/bertin-base-stepwise"
21
  },
22
  "RoBERTa Base Gaussian Experiment": {
23
+ "url": "bertin-project/bertin-base-gaussian-exp-512seqlen"
24
  },
25
  "RoBERTa Base Random Experiment": {
26
+ "url": "bertin-project/bertin-base-random-exp-512seqlen"
27
  }
28
  }
29
 
 
41
 
42
 
43
  @st.cache(show_spinner=False, persist=True)
44
+ def load_model(masked_text, model_url):
45
+ model = AutoModelForMaskedLM.from_pretrained(model_url)
46
+ tokenizer = AutoTokenizer.from_pretrained(model_url)
47
  nlp = pipeline("fill-mask", model=model, tokenizer=tokenizer)
48
  result = nlp(masked_text)
49
  return result
 
66
  """
67
  )
68
 
69
+ model_name = st.selectbox("Model",MODELS.keys())
70
+ model_url = MODELS[model_name]["url"]
71
 
72
  prompt = st.selectbox("Prompt", ["Random", "Custom"])
73
  if prompt == "Custom":
 
79
  if st.button("Fill the mask"):
80
  with st.spinner(text="Getting results..."):
81
  st.subheader("Result")
82
+ result = load_model(text, model_url)
83
  if "error" in result:
84
  if type(result["error"]) is str:
85
  st.write(f'{result["error"]}.', end=" ")