import transformers
import torch
import streamlit as st
from transformers import BertTokenizer

# "Which field is the article from? Enter the title and abstract of a paper
# and I will try to guess its field."
st.markdown("### Из какой области статья? Введите название и аннотацию научной статьи, и я попробую угадать, из какой она области)")

num_classes = 8


class BERTClass(torch.nn.Module):
    """BERT encoder with a two-layer head producing one logit per class."""

    def __init__(self, n_hid1=1024, n_out=num_classes, bert_path='bert-base-uncased'):
        super(BERTClass, self).__init__()
        self.l1 = transformers.BertModel.from_pretrained(bert_path)
        self.l2 = torch.nn.Dropout(0.3)
        self.l3 = torch.nn.Linear(768, n_hid1)
        self.l4 = torch.nn.ReLU()
        self.l5 = torch.nn.Dropout(0.2)
        self.l6 = torch.nn.Linear(n_hid1, n_out)

    def forward(self, ids, mask, token_type_ids):
        out = self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids)
        out = self.l2(out[1])  # out[1] is the pooler output ([CLS] representation)
        out = self.l3(out)
        out = self.l4(out)
        out = self.l5(out)
        out = self.l6(out)
        return out


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_bert():
    model = BERTClass(bert_path='bert_pretrained')
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only host
    model.load_state_dict(torch.load('bert_pretrained.pt', map_location='cpu'))
    model.eval()
    tokenizer = BertTokenizer.from_pretrained('bert_tokenizer')
    return model, tokenizer


def apply_bert(text, model, tokenizer):
    """Returns the per-class probabilities for `text`."""
    MAX_LEN = 200
    ins = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',  # replaces the deprecated pad_to_max_length=True
        truncation=True,       # inputs longer than MAX_LEN are cut, not rejected
        return_token_type_ids=True,
    )
    ids = torch.tensor(ins['input_ids']).unsqueeze(0)
    mask = torch.tensor(ins['attention_mask']).unsqueeze(0)
    # unsqueeze(0) adds the batch dimension, consistent with ids and mask
    token_type_ids = torch.tensor(ins['token_type_ids']).unsqueeze(0)
    out = model(ids, mask, token_type_ids)
    return torch.sigmoid(out).flatten().detach()


class TinyBERTClass(torch.nn.Module):
    """DistilBERT encoder with the same classification head as BERTClass."""

    def __init__(self, n_hid1=1024, n_out=num_classes, path='distilbert-base-uncased'):
        super(TinyBERTClass, self).__init__()
        self.l1 = transformers.DistilBertModel.from_pretrained(path)
        self.l2 = torch.nn.Dropout(0.3)
        self.l3 = torch.nn.Linear(768, n_hid1)
        self.l4 = torch.nn.ReLU()
        self.l5 = torch.nn.Dropout(0.2)
        self.l6 = torch.nn.Linear(n_hid1, n_out)

    def forward(self, ids, mask):
        # DistilBERT has no pooler and no token_type_ids; use the [CLS] token.
        out = self.l1(ids, attention_mask=mask)
        out = self.l2(out.last_hidden_state[:, 0, :])
        out = self.l3(out)
        out = self.l4(out)
        out = self.l5(out)
        out = self.l6(out)
        return out


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_tiny_bert():
    model = TinyBERTClass(path='tiny_bert_pretrained')
    model.load_state_dict(torch.load('tiny_bert.pt', map_location='cpu'))
    model.eval()
    tokenizer = transformers.DistilBertTokenizer.from_pretrained('tiny_bert_tokenizer')
    return model, tokenizer


def apply_tiny_bert(text, model, tokenizer):
    """Returns the per-class probabilities for `text`."""
    # truncation guards against inputs longer than the model's 512-token limit
    encoded_input = tokenizer(text, return_tensors='pt', truncation=True)
    out = model(encoded_input['input_ids'], encoded_input['attention_mask'])
    return torch.sigmoid(out).flatten().detach()
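
# A minimal offline smoke-test sketch, not wired into the Streamlit UI; the
# function name and the sample title are hypothetical. It assumes the same
# checkpoint files used by load_tiny_bert() are available locally, and simply
# prints the eight per-class probabilities for a title-only input.
def _smoke_test(sample_title='Deep residual learning for image recognition.'):
    model, tokenizer = load_tiny_bert()
    probs = apply_tiny_bert(sample_title, model, tokenizer)
    print(probs)  # tensor of 8 probabilities, one per subject area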
title = st.text_area("Название статьи")  # "Article title"
# Make sure the title ends with a sentence terminator before concatenation.
if title and not title.endswith('.'):
    title += '.'

summary = st.text_area("Аннотация статьи")  # "Article abstract"
calc_button = st.button('Угадать тематику')  # "Guess the topic"

bert_model, bert_tokenizer = load_bert()
tiny_bert, tiny_bert_tokenizer = load_tiny_bert()

# calculate =================================================================
if calc_button:
    # server-side debug logging
    print('title')
    print(title)
    print('=' * 80)
    if summary:
        # With both title and abstract available, use the full BERT model.
        text = title + ' ' + summary
        out = apply_bert(text, bert_model, bert_tokenizer)
    else:
        # Title only: fall back to the lighter DistilBERT model.
        out = apply_tiny_bert(title, tiny_bert, tiny_bert_tokenizer)

    # Russian names (dative case) of the eight arXiv subject areas:
    # computer science, economics, electrical engineering and systems science,
    # mathematics, physics, quantitative biology, quantitative finance,
    # statistics.
    RU_NAMES = ['компьютерным наукам',
                'экономике',
                'электротехнике и системотехнике',
                'математике',
                'физике',
                'количественной биологии',
                'количественным финансам',
                'статистике']

    def get_classes(out, threshold=0.5):
        """Formats a Russian answer listing every class whose probability
        passes the threshold."""
        res = [i for i in range(out.size(0)) if out[i] >= threshold]
        ans = ''
        total = 0.0
        for i in res:
            total += out[i].item()
            if not ans:
                ans += f'Это статья по {RU_NAMES[i]} с вероятностью {out[i].item():.2f}'
            else:
                ans += f',\nтакже она по {RU_NAMES[i]} с вероятностью {out[i].item():.2f}'
        if not ans:
            # "Doesn't look like anything scientific, are you sure this is from a paper?"
            return 'Не похоже на что-то научное. Вы уверены, что это взято из статьи?'
        if total >= 1.0:
            # This is multi-label classification (independent sigmoid per
            # class), so the probabilities need not sum to 1.
            ans += ('.\n(Решалась задача мультилейбл-классификации, поэтому '
                    'сумма вероятностей получилась больше 1.)')
        return ans

    res = get_classes(out)
    st.markdown(res)
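
# Usage sketch (the file name app.py is an assumption):
#
#   streamlit run app.py
#
# Worked example of get_classes with hypothetical scores: for
# out = tensor([0.91, 0.02, 0.03, 0.64, 0.05, 0.01, 0.02, 0.04]) and the
# default threshold of 0.5, indices 0 and 3 pass, so the answer begins
# 'Это статья по компьютерным наукам с вероятностью 0.91, также она по
# математике с вероятностью 0.64', and since 0.91 + 0.64 > 1 the
# multi-label disclaimer is appended.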