Galuh Sahid commited on
Commit
7c82098
1 Parent(s): 7be8469

fix reloading

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -14,7 +14,7 @@ LOGO = "huggingwayang.png"
14
  MODELS = {
15
  "GPT-2 Small": "flax-community/gpt2-small-indonesian",
16
  "GPT-2 Medium": "flax-community/gpt2-medium-indonesian",
17
- "GPT-2 Small Finetuned on Indonesian Journals": "Galuh/id-journal-gpt2"
18
  }
19
 
20
  headers = {}
@@ -96,11 +96,11 @@ st.markdown(
96
  """
97
  )
98
 
99
- model_name = st.selectbox('Model',(['GPT-2 Small', 'GPT-2 Medium', 'GPT-2 Small Finetuned on Indonesian Journals']))
100
 
101
  if model_name in ["GPT-2 Small", "GPT-2 Medium"]:
102
  prompt_group_name = "GPT-2"
103
- elif model_name in ["GPT-2 Small Finetuned on Indonesian Journals"]:
104
  prompt_group_name = "Indonesian Journals"
105
 
106
  ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
@@ -118,6 +118,8 @@ session_state.prompt_box = prompt_box
118
  text = st.text_area("Enter text", session_state.prompt_box)
119
 
120
  if st.button("Run"):
 
 
121
  with st.spinner(text="Getting results..."):
122
  lang_predictions, lang_probability = ft_model.predict(text.replace("\n", " "), k=3)
123
  if "__label__id" in lang_predictions:
 
14
  MODELS = {
15
  "GPT-2 Small": "flax-community/gpt2-small-indonesian",
16
  "GPT-2 Medium": "flax-community/gpt2-medium-indonesian",
17
+ "GPT-2 Small finetuned on Indonesian academic journals": "Galuh/id-journal-gpt2"
18
  }
19
 
20
  headers = {}
 
96
  """
97
  )
98
 
99
+ model_name = st.selectbox('Model',(['GPT-2 Small', 'GPT-2 Medium', 'GPT-2 Small finetuned on Indonesian academic journals']))
100
 
101
  if model_name in ["GPT-2 Small", "GPT-2 Medium"]:
102
  prompt_group_name = "GPT-2"
103
+ elif model_name in ["GPT-2 Small finetuned on Indonesian academic journals"]:
104
  prompt_group_name = "Indonesian Journals"
105
 
106
  ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
 
118
  text = st.text_area("Enter text", session_state.prompt_box)
119
 
120
  if st.button("Run"):
121
+ text = st.text_area("Enter text", session_state.prompt_box)
122
+
123
  with st.spinner(text="Getting results..."):
124
  lang_predictions, lang_probability = ft_model.predict(text.replace("\n", " "), k=3)
125
  if "__label__id" in lang_predictions: