poseg's picture
remove akinator
2a1e3a9
import transformers
import torch
import streamlit as st
from transformers import BertTokenizer
st.markdown("### Из какой области статья? Введите название и аннотация научной статьи и я попробую угадать из какой она области)")
# link = 'https://www.clipartmax.com/png/middle/87-873210_akinator-with-transparent-background.png'
# st.markdown(f"<img width=200px src='{link}'>", unsafe_allow_html=True)
# st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)
# from transformers import
# pipe = pipeline("ner", "Davlan/distilbert-base-multilingual-cased-ner-hrl")
num_classes = 8
class BERTClass(torch.nn.Module):
def __init__(self, n_hid1 = 1024, n_out=num_classes, bert_path='bert-base-uncased'):
super(BERTClass, self).__init__()
self.l1 = transformers.BertModel.from_pretrained(bert_path)
self.l2 = torch.nn.Dropout(0.3)
self.l3 = torch.nn.Linear(768, n_hid1)
self.l4 = torch.nn.ReLU()
self.l5 = torch.nn.Dropout(0.2)
self.l6 = torch.nn.Linear(n_hid1, n_out)
def forward(self, ids, mask, token_type_ids):
# _, output_1= self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids)
out = self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids)
out = self.l2(out[1])
out = self.l3(out)
out = self.l4(out)
out = self.l5(out)
out = self.l6(out)
return out
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_bert():
model = BERTClass(bert_path='bert_pretrained')
model.load_state_dict(torch.load('bert_pretrained.pt'))
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert_tokenizer')
return model, tokenizer
def apply_bert(text, model, tokenizer):
"""returns probabilities"""
MAX_LEN = 200
ins = tokenizer.encode_plus(text, None, add_special_tokens=True,
max_length=MAX_LEN,
pad_to_max_length=True,
return_token_type_ids=True
)
ids = torch.tensor(ins['input_ids']).unsqueeze(0)
mask = torch.tensor(ins['attention_mask']).unsqueeze(0)
token_type_ids = torch.tensor(ins["token_type_ids"])
out = model(ids, mask, token_type_ids)
return torch.sigmoid(out).flatten().detach()
class TinyBERTClass(torch.nn.Module):
def __init__(self, n_hid1 = 1024, n_out=num_classes, path='distilbert-base-uncased'):
super(TinyBERTClass, self).__init__()
self.l1 = transformers.DistilBertModel.from_pretrained(path)
self.l2 = torch.nn.Dropout(0.3)
self.l3 = torch.nn.Linear(768, n_hid1)
self.l4 = torch.nn.ReLU()
self.l5 = torch.nn.Dropout(0.2)
self.l6 = torch.nn.Linear(n_hid1, n_out)
def forward(self, ids, mask):
# _, output_1= self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids)
out = self.l1(ids, attention_mask = mask)
out = self.l2(out.last_hidden_state[:,0,:])
out = self.l3(out)
out = self.l4(out)
out = self.l5(out)
out = self.l6(out)
return out
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def load_tiny_bert():
model = TinyBERTClass(path = 'tiny_bert_pretrained')
model.load_state_dict(torch.load('tiny_bert.pt'))
model.eval()
tokenizer = transformers.DistilBertTokenizer.from_pretrained('tiny_bert_tokenizer')
return model, tokenizer
def apply_tiny_bert(text, model, tokenizer):
encoded_input = tokenizer(text, return_tensors='pt')
out = model(encoded_input['input_ids'], encoded_input['attention_mask'])
return torch.sigmoid(out).flatten().detach()
title = st.text_area("Название статьи")
if not title.endswith('.') and title:
title += '.'
summary = st.text_area("Аннотация статьи")
calc_button = st.button('Угадать тематику')
bert_model, bert_tokenizer = load_bert()
tiny_bert, tiny_bert_tokenizer = load_tiny_bert()
# calculate ================================================================
if calc_button:
print('title')
print(title)
print('=' * 80)
# print(text)
if summary:
text = title + summary
out = apply_bert(text, bert_model, bert_tokenizer)
else:
out = apply_tiny_bert(title, tiny_bert, tiny_bert_tokenizer)
RU_NAMES = ['компьютерным наукам'
,'экономике'
,'электротехнике и системотехнике'
,'математике'
,'физике'
,'количественной биологии'
,'количественным финансам'
,'статистике'
]
def get_classes(out, bandwidth = 0.5):
res = []
for i in range(out.size()[0]):
if out[i] >= bandwidth:
res.append(i)
ans = ''
total = 0
for i in res:
total += out[i].item()
if not ans:
ans += f'\nэто статья по {RU_NAMES[i]} с вероятностью {out[i].item():.2f}'
else:
ans += f',\nтакже она по {RU_NAMES[i]} с вероятностью {out[i].item():.2f}'
ans = 'Э' + ans[2:]
if total >= 1.0:
ans += '.\n(Решалась задача мультиклассификации, поэтому сумма вероятностей получилась больше 1.)'
if ans == 'Э':
return 'Не похоже на что-то научное, Вы уверены что это взято из статьи?'
return ans
res = get_classes(out)
st.markdown(f"{res}")