adorkin commited on
Commit
e5b30dc
1 Parent(s): 8bcc77d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -8
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import torch
5
 
6
  TOP_N = 5
 
7
 
8
  def preprocess(text):
9
  new_text = []
@@ -25,7 +26,11 @@ def get_top_emojis(text, tokenizer, model, top_n=TOP_N):
25
  return '\t'.join(map(str, emojis))
26
 
27
  def main():
28
-
 
 
 
 
29
 
30
  st.set_page_config( # Alternate names: setup_page, page, layout
31
  layout="centered", # Can be "centered" or "wide". In the future also "dashboard", etc.
@@ -53,10 +58,12 @@ def main():
53
  "AlekseyDorkin/xlm-roberta-en-ru-emoji"
54
  ]
55
 
56
- BASE_MODEL = st.selectbox("Choose a model", models_to_choose)
57
-
58
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
59
- model = AutoModelForSequenceClassification.from_pretrained(BASE_MODEL)
 
 
60
 
61
  # Define function to run when submit is clicked
62
  def submit(message):
@@ -70,9 +77,11 @@ def main():
70
  submit(message)
71
 
72
  st.text('')
73
- st.markdown('''<span style="color:blue; font-size:10px">App created by [@AlekseyDorkin](https://huggingface.co/AlekseyDorkin)
74
- and [@akshay7](https://huggingface.co/akshay7)</span>''',
75
- unsafe_allow_html=True)
 
 
76
 
77
 
78
  if __name__ == "__main__":
 
4
  import torch
5
 
6
  TOP_N = 5
7
+ DEFAULT_MODEL = "amazon-sagemaker-community/xlm-roberta-en-ru-emoji-v2"
8
 
9
  def preprocess(text):
10
  new_text = []
 
26
  return '\t'.join(map(str, emojis))
27
 
28
  def main():
29
+
30
+ cur_model_name = DEFAULT_MODEL
31
+
32
+ tokenizer = AutoTokenizer.from_pretrained(cur_model_name)
33
+ model = AutoModelForSequenceClassification.from_pretrained(cur_model_name)
34
 
35
  st.set_page_config( # Alternate names: setup_page, page, layout
36
  layout="centered", # Can be "centered" or "wide". In the future also "dashboard", etc.
 
58
  "AlekseyDorkin/xlm-roberta-en-ru-emoji"
59
  ]
60
 
61
+ model_name = st.selectbox("Choose a model", models_to_choose)
62
+ if model_name != cur_model_name:
63
+ cur_model_name = model_name
64
+ tokenizer = AutoTokenizer.from_pretrained(cur_model_name)
65
+ model = AutoModelForSequenceClassification.from_pretrained(cur_model_name)
66
+
67
 
68
  # Define function to run when submit is clicked
69
  def submit(message):
 
77
  submit(message)
78
 
79
  st.text('')
80
+ st.markdown(
81
+ '''<span style="color:blue; font-size:10px">App created by [@AlekseyDorkin](https://huggingface.co/AlekseyDorkin)
82
+ and [@akshay7](https://huggingface.co/akshay7)</span>''',
83
+ unsafe_allow_html=True,
84
+ )
85
 
86
 
87
  if __name__ == "__main__":