hw-devops / app.py
fed3rikko's picture
Update app.py
665a996
raw
history blame contribute delete
No virus
2.31 kB
import streamlit as st
from transformers import AutoTokenizer, DistilBertForSequenceClassification
import torch
from torch.nn.functional import softmax
base_model_name = 'distilbert-base-uncased'
@st.cache
def load_tags_info():
id_to_description = {}
with open('tags.txt', 'r') as file:
i = 0
for line in file:
description = line[:-1]
id_to_description[i] = description
i += 1
return id_to_description
id_to_description = load_tags_info()
@st.cache
def load_model():
return DistilBertForSequenceClassification.from_pretrained('./')
def load_tokenizer():
return AutoTokenizer.from_pretrained('distilbert-base-uncased')
def top_xx(preds, xx=95):
tops = torch.argsort(preds, 1, descending=True)
total = 0
index = 0
result = []
while total < xx / 100:
next_id = tops[0, index].item()
total += preds[0, next_id]
index += 1
result.append(id_to_description[next_id])
return result
model = load_model()
tokenizer = load_tokenizer()
temperature = 1
st.title('ArXivTaxonomizer')
st.caption('Напишите тему(Title) и параграф из статьи(Abstract). Поля должны быть непустыми для корректной классификации.')
with st.form("Taxonomizer"):
title = st.text_area(label='Title', height=30)
abstract = st.text_area(label='Abstract (optional)', height=200)
st.caption('Будут выведеты темы в порядке от наибольшей вероятности до наименьшей')
submitted = st.form_submit_button("Taxonomize")
if submitted:
if title == '':
st.markdown("Нужно хоть что-то написатб")
else:
prompt = 'Title: ' + title + ' Abstract: ' + abstract
tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
preds = softmax(model(tokens.reshape(1, -1)).logits / temperature, dim=1)
tags = top_xx(preds)
other_tags = []
st.header('Inferred tags:')
for i, tag_data in enumerate(tags):
st.markdown('* ' + tag_data)