ilyshi commited on
Commit
88367f6
1 Parent(s): 3286385
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -29,40 +29,35 @@ def make_prediction(text):
29
  st.markdown("### Category prediction:")
30
  print_probs(pred_logits[0])
31
 
32
-
33
- @st.cache
34
- def model_init():
35
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
36
-
37
- model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
38
- model_name = "trained_model2"
39
- model_path = model_name + '.zip'
40
- model.load_state_dict(
41
- torch.load(
42
- model_path,
43
- map_location=torch.device("cpu")
44
- )
45
- )
46
- return tokenizer, model
47
 
48
  # MAIN
49
  from PIL import Image
50
  image = Image.open('logo.png')
51
 
52
  st.image(image)
53
-
54
- # st.markdown("<img src='https://centroderecursosmarista.org/wp-content/uploads/2013/05/arvix.jpg' class='center'>", unsafe_allow_html=True)
55
- # st.markdown("# Arxiv.org category classifier")
56
  st.markdown("# ")
57
-
58
  st.markdown("### Article Title")
 
59
  text1 = st.text_area("Введите название научной статьи для классификации", height=20)
60
 
61
  st.markdown("### Article Abstract")
62
 
63
  text2 = st.text_area("Введите описание статьи", height=200)
 
64
  common_text = text1 + text2
 
65
  if common_text != "":
66
- tokenizer, model = model_init()
67
  make_prediction(common_text)
68
 
 
29
  st.markdown("### Category prediction:")
30
  print_probs(pred_logits[0])
31
 
32
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
33
+
34
+ model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
35
+ model_name = "trained_model2"
36
+ model_path = model_name + '.zip'
37
+ model.load_state_dict(
38
+ torch.load(
39
+ model_path,
40
+ map_location=torch.device("cpu")
41
+ )
42
+ )
 
 
 
 
43
 
44
  # MAIN
45
  from PIL import Image
46
  image = Image.open('logo.png')
47
 
48
  st.image(image)
 
 
 
49
  st.markdown("# ")
 
50
  st.markdown("### Article Title")
51
+
52
  text1 = st.text_area("Введите название научной статьи для классификации", height=20)
53
 
54
  st.markdown("### Article Abstract")
55
 
56
  text2 = st.text_area("Введите описание статьи", height=200)
57
+
58
  common_text = text1 + text2
59
+
60
  if common_text != "":
 
61
  make_prediction(common_text)
62
 
63
+