MikhailPugachev committed on
Commit
0c47f30
·
1 Parent(s): bb1ee88

Исправлен путь к модели

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -4,18 +4,21 @@ import torch.nn.functional as F
4
  from transformers import AutoTokenizer
5
  from model_SingleLabelClassifier import SingleLabelClassifier
6
  from safetensors.torch import load_file
 
 
7
 
8
- # --- Настройки ---
9
  MODEL_NAME = "allenai/scibert_scivocab_uncased"
10
- CHECKPOINT_PATH = "checkpoint-28553"
11
- NUM_CLASSES = 7
12
- MAX_LEN = 320
13
 
14
- # --- Загрузка меток ---
15
- label2id = {'cs.CV': 0, 'cs.LG': 1, 'cs.AI': 2, 'cs.CL': 3, 'stat.ML': 4, 'cs.NE': 5, '<OTHER>': 6}
16
- id2label = {v: k for k, v in label2id.items()}
 
 
17
 
18
- # --- Загрузка модели и токенизатора ---
19
  @st.cache_resource
20
  def load_model_and_tokenizer():
21
  tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
@@ -27,7 +30,7 @@ def load_model_and_tokenizer():
27
 
28
  model, tokenizer = load_model_and_tokenizer()
29
 
30
- # --- Функция предсказания ---
31
  def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3):
32
  model.eval()
33
  text = title + ". " + summary
@@ -48,9 +51,9 @@ def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3)
48
  top_indices = probs.argsort()[::-1][:top_k]
49
  return [(id2label[i], round(probs[i], 3)) for i in top_indices]
50
 
51
- # --- Интерфейс Streamlit ---
52
  st.title("ArXiv Tag Predictor")
53
- st.write("Вставьте заголовок и аннотацию статьи — получите предсказанный тег!")
54
 
55
  title = st.text_input("**Title**")
56
  summary = st.text_area("**Summary**", height=200)
 
4
  from transformers import AutoTokenizer
5
  from model_SingleLabelClassifier import SingleLabelClassifier
6
  from safetensors.torch import load_file
7
+ import json
8
+
9
 
 
10
  MODEL_NAME = "allenai/scibert_scivocab_uncased"
11
+ CHECKPOINT_PATH = "checkpoint-23985"
12
+ NUM_CLASSES = 65
13
+ MAX_LEN = 325020
14
 
15
+ # Загрузка меток
16
+ with open("label_mappings.json", "r") as f:
17
+ mappings = json.load(f)
18
+ label2id = mappings["label2id"]
19
+ id2label = {int(k): v for k, v in mappings["id2label"].items()}
20
 
21
+ # Загрузка модели и токенизатора
22
  @st.cache_resource
23
  def load_model_and_tokenizer():
24
  tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
 
30
 
31
  model, tokenizer = load_model_and_tokenizer()
32
 
33
+ # Функция предсказания
34
  def predict(title, summary, model, tokenizer, id2label, max_length=320, top_k=3):
35
  model.eval()
36
  text = title + ". " + summary
 
51
  top_indices = probs.argsort()[::-1][:top_k]
52
  return [(id2label[i], round(probs[i], 3)) for i in top_indices]
53
 
54
+ # Интерфейс Streamlit
55
  st.title("ArXiv Tag Predictor")
56
+ st.write("Вставьте заголовок и аннотацию статьи!")
57
 
58
  title = st.text_input("**Title**")
59
  summary = st.text_area("**Summary**", height=200)