Dzhamb commited on
Commit
7c5a283
1 Parent(s): 541cdd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -8
app.py CHANGED
@@ -1,17 +1,27 @@
1
  import streamlit as st
 
 
 
 
 
2
 
3
  st.markdown("## Классификатор статей")
4
  st.markdown("Сервис классифицирует статьи по названию и аннотации. Нужно ввести в каждое окошко свою сущность и вам выдадут к какому классу относится статья")
5
- # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter
6
 
7
  title = st.text_area("Введите название статьи")
8
- # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент
9
  abstract = st.text_area("Введите аннотацию к статье, abstract статьи")
10
 
11
- from transformers import pipeline
12
- pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
13
- raw_predictions = pipe(title + ' ' + abstract)
14
- # тут уже знакомый вам код с huggingface.transformers -- его можно заменить на что угодно от fairseq до catboost
 
 
 
 
 
 
 
15
 
16
- st.markdown(f"{raw_predictions}")
17
- # выводим результаты модели в текстовое поле, на потеху пользователю
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
4
+ from utils import get_text, get_label
5
+
6
+ count_labels = 8
7
 
8
  st.markdown("## Классификатор статей")
9
  st.markdown("Сервис классифицирует статьи по названию и аннотации. Нужно ввести в каждое окошко свою сущность и вам выдадут к какому классу относится статья")
10
+
11
 
12
  title = st.text_area("Введите название статьи")
13
+
14
  abstract = st.text_area("Введите аннотацию к статье, abstract статьи")
15
 
16
+ tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
17
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-cased", num_labels=count_labels)
18
+ model.load_state_dict(torch.load('weight_model'))
19
+
20
+ text = get_text(title, abstract)
21
+ if text:
22
+ raw_predictions = get_label(text)
23
+ st.markdown(f"{raw_predictions}")
24
+ else:
25
+ st.markdown("Ваш запрос пуст. Введите хотя бы название")
26
+
27