Galuh Sahid commited on
Commit
9554cf5
1 Parent(s): 6dd057c

fix reloading

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -103,26 +103,26 @@ if model_name in ["GPT-2 Small", "GPT-2 Medium"]:
103
  elif model_name in ["GPT-2 Small finetuned on Indonesian academic journals"]:
104
  prompt_group_name = "Indonesian Journals"
105
 
106
- ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
107
- prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
108
 
109
- session_state = SessionState.get(prompt_box=None)
 
110
 
111
- if prompt == "Custom":
112
  session_state.prompt_box = "Enter your text here"
113
  else:
114
- session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][prompt])
115
 
116
- text = st.text_area("Enter text", session_state.prompt_box)
117
 
118
  if st.button("Run"):
119
  with st.spinner(text="Getting results..."):
120
- lang_predictions, lang_probability = ft_model.predict(text.replace("\n", " "), k=3)
121
  if "__label__id" in lang_predictions:
122
  lang = "id"
123
  else:
124
  lang = lang_predictions[0].replace("__label__", "")
125
- text = translate(text, "id", lang)
126
 
127
  st.subheader("Result")
128
  model = load_gpt(model_name)
 
103
  elif model_name in ["GPT-2 Small finetuned on Indonesian academic journals"]:
104
  prompt_group_name = "Indonesian Journals"
105
 
106
+ session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
 
107
 
108
+ ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
109
+ session_state.prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
110
 
111
+ if session_state.prompt == "Custom":
112
  session_state.prompt_box = "Enter your text here"
113
  else:
114
+ session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])
115
 
116
+ session_state.text = st.text_area("Enter text", session_state.prompt_box)
117
 
118
  if st.button("Run"):
119
  with st.spinner(text="Getting results..."):
120
+ lang_predictions, lang_probability = ft_model.predict(session_state.text.replace("\n", " "), k=3)
121
  if "__label__id" in lang_predictions:
122
  lang = "id"
123
  else:
124
  lang = lang_predictions[0].replace("__label__", "")
125
+ text = translate(session_state.text, "id", lang)
126
 
127
  st.subheader("Result")
128
  model = load_gpt(model_name)