Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| import datasets | |
| import torch | |
# Hub identifier of the checkpoint whose tokenizer is loaded by `load_model`.
model_name = "distilbert-base-cased"
def load_model():
    """Load the tokenizer and the sequence-classification model.

    The tokenizer is fetched by the name stored in ``model_name``; the
    classifier is read from the current directory (``'./'``).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    classifier = AutoModelForSequenceClassification.from_pretrained('./')
    return loaded_tokenizer, classifier
# Load the tokenizer/model pair only when this namespace does not already
# define `tokenizer`; otherwise keep using the objects that are present.
# NOTE(review): confirm that module globals actually survive between Streamlit
# script runs -- if they do not, this guard never skips the load.
if "tokenizer" not in globals():
    tokenizer, model = load_model()
# Two free-text inputs, rendered in this order: the title, then the abstract.
title, abstract = st.text_area("Title"), st.text_area("Abstract")
# Maps the classifier's integer label (the index into its logits) to the
# human-readable topic name shown to the user.
label_to_topic_dict = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics',
}
# Reverse lookup of `label_to_topic_dict`: topic name -> integer label.
topic_to_label_dict = {
    topic: label for label, topic in label_to_topic_dict.items()
}
# Use the first CUDA GPU when PyTorch reports one, otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Keep the model's weights on the selected device so that inference runs
# there; leaving them on the CPU while feeding CUDA tensors raises a
# device-mismatch error on GPU hosts.
model.to(device)
def predict(title, abstract, threshold=0.95):
    """Rank topics for a paper and keep the most probable ones.

    The title and abstract are tokenized together as a text pair, passed
    through the module-level ``model`` and turned into probabilities with a
    softmax over the label axis. Topics are then collected in order of
    decreasing probability until their cumulative probability exceeds
    ``threshold``.

    Args:
        title: Paper title text.
        abstract: Paper abstract text.
        threshold: Cumulative probability above which no further topics are
            added. Defaults to 0.95.

    Returns:
        A list of ``(topic_name, probability)`` tuples, most probable first.
        The list always contains at least one entry.
    """
    # One-element lists form a batch of a single (title, abstract) pair.
    encoded = tokenizer([title], [abstract], padding="max_length",
                        truncation=True, return_tensors='pt')

    # Follow the model's own device so inputs and weights can never end up
    # on different devices.
    target = model.device

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids=encoded['input_ids'].to(target),
                       attention_mask=encoded['attention_mask'].to(target))['logits']

    # Softmax over the label axis; [0] selects the only item in the batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu()

    preds = []
    cumulative = 0.0
    for index in probs.argsort(descending=True).tolist():
        probability = probs[index].item()
        preds.append((label_to_topic_dict[index], probability))
        cumulative += probability
        if cumulative > threshold:
            break
    return preds
# Classify only once the user has typed something into at least one field;
# input consisting solely of whitespace is treated as empty.
if title.strip() or abstract.strip():
    output = predict(title, abstract)
    st.text("Top 95% topics:")
    for topic, proba in output:
        st.text(f"{topic}: {proba*100:.0f}%")