import streamlit as st
from transformers import AutoTokenizer, DistilBertForSequenceClassification
import torch
from torch.nn.functional import softmax

base_model_name = 'distilbert-base-uncased'


@st.cache
def load_tags_info():
    # Map class id -> human-readable tag description, one description per line in tags.txt.
    id_to_description = {}
    with open('tags.txt', 'r') as file:
        for i, line in enumerate(file):
            id_to_description[i] = line.rstrip('\n')
    return id_to_description


id_to_description = load_tags_info()


@st.cache(allow_output_mutation=True)  # skip re-hashing the model object on every rerun
def load_model():
    # Fine-tuned DistilBERT classifier saved in the current directory.
    return DistilBertForSequenceClassification.from_pretrained('./')


def load_tokenizer():
    return AutoTokenizer.from_pretrained(base_model_name)


def top_xx(preds, xx=95):
    # Return tag descriptions, most probable first, until their cumulative
    # probability covers at least xx percent of the mass (top-p style cut-off).
    tops = torch.argsort(preds, 1, descending=True)
    total = 0.0
    index = 0
    result = []
    while total < xx / 100:
        next_id = tops[0, index].item()
        total += preds[0, next_id].item()
        index += 1
        result.append(id_to_description[next_id])
    return result


model = load_model()
tokenizer = load_tokenizer()
temperature = 1

st.title('ArXivTaxonomizer')
st.caption('Enter the paper title (Title) and a paragraph from the article (Abstract). '
           'The fields must be non-empty for correct classification.')

with st.form("Taxonomizer"):
    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    st.caption('Tags will be listed in order from the most probable to the least probable.')
    submitted = st.form_submit_button("Taxonomize")
    if submitted:
        if title == '':
            st.markdown("Please enter at least a title.")
        else:
            prompt = 'Title: ' + title + ' Abstract: ' + abstract
            tokens = tokenizer(prompt, truncation=True, padding='max_length',
                               return_tensors='pt')['input_ids']
            with torch.no_grad():  # inference only, no gradients needed
                preds = softmax(model(tokens).logits / temperature, dim=1)
            tags = top_xx(preds)
            st.header('Inferred tags:')
            for tag_description in tags:
                st.markdown('* ' + tag_description)
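
# Usage sketch (assumptions: this script is saved as app.py, the fine-tuned DistilBERT
# weights sit next to it in ./, and tags.txt contains one tag description per line):
#   streamlit run app.py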