import streamlit as st import torch @st.cache def Model(): from transformers import DebertaTokenizer, DebertaForSequenceClassification tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base") model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base", num_labels=8) bn_state_dict = torch.load('model_weights.pt') model.load_state_dict(bn_state_dict) return model, tokenizer def Predict(model, tokenizer, text): res = tokenizer(s, padding=True, truncation=True, return_tensors="pt", max_length=512) #var.to("cuda:0") res = model(**res) logits = res.logits.softmax(dim=1) logits = logits.numpy()[0]#logits.cpu().detach().numpy()[0] return logits def Print(logits, dictionary): z = zip(logits, np.arange(0, 8)) z = sorted(z, key=lambda x: x[0], reverse=True) sum, idx = 0, 0 while sum < 0.95: st.markdown(f"{idx + 1}. ", dictionary[z[idx][1]]) sum += z[idx][0] idx += 1 def filter(title, abstract): return True st.title('Классификация статьи по названию и описанию') # ^-- можно показывать пользователю текст, картинки, ограниченное подмножество html - всё как в jupyter title = st.text_area("Введите название статьи:") abstract = st.text_area("Введите описание статьи:") # ^-- показать текстовое поле. В поле text лежит строка, которая находится там в данный момент text = title + '. ' + abstract dictionary = ['computer science', 'economics', 'Electrical Engineering and Systems Science', 'math', 'physics', 'quantitative biology', 'quantitative finance', 'statistics'] if filter(title, abstract): model, tokenizer = Model() logits = Predict(model, tokenizer, text) Print(logits, dictionary)