"""Streamlit app that predicts the most likely arXiv tag for a paper.

On first run it fine-tunes DistilBERT on ``./arxivData.json`` (the trained
model is cached by Streamlit across reruns) and then classifies the title and
abstract typed in by the user.
"""

import ast
import json
import logging
import warnings

# Installed before the third-party imports so that warnings raised while
# importing them are suppressed too.
warnings.simplefilter('ignore')

import numpy as np
import pandas as pd
import streamlit as st
import torch
import transformers
from sklearn import metrics
from sklearn.model_selection import train_test_split
from torch import cuda
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import DistilBertTokenizer, DistilBertModel

logging.basicConfig(level=logging.ERROR)

device = 'cuda' if cuda.is_available() else 'cpu'

DATA_PATH = "./arxivData.json"
MAX_LEN = 512
TRAIN_BATCH_SIZE = 4
EPOCHS = 1
LEARNING_RATE = 1e-05

# Substrings identifying the arXiv archives whose tags are kept as labels.
_GOOD_TAG_MARKERS = ("stat.", "cs.", "math.", "ph.", "fin.", "bio.", "eess.", "econ.")

# st.cache_resource is Streamlit's cache for shared, unhashable objects such as
# models. Older releases only have st.cache, where allow_output_mutation=True
# stops it from hashing the returned closure/model on every rerun.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)

st.markdown("## arXiv classificator")
st.markdown("Please type the article's title and abstract below")
title = st.text_input("Title")
abstract = st.text_input("Abstract")


def is_good(tag: str) -> bool:
    """Return True if ``tag`` contains one of the kept archive markers.

    This is a substring test, not a prefix test: e.g. ``"astro-ph.CO"``
    matches via ``"ph."`` and ``"q-bio.NC"`` via ``"bio."``.
    """
    return any(marker in tag for marker in _GOOD_TAG_MARKERS)


def get_all_tags(tag_str: str) -> list:
    """Parse one ``tag`` cell and return its terms that pass :func:`is_good`.

    ``tag_str`` is a Python-literal list of dicts (single-quoted strings,
    ``None`` values), each with a ``'term'`` key. It is parsed with
    ``ast.literal_eval`` rather than by rewriting quotes into JSON, which broke
    on any value containing an apostrophe or the text ``None``.
    """
    entries = ast.literal_eval(tag_str)
    return [
        entry["term"]
        for entry in entries
        if isinstance(entry["term"], str) and is_good(entry["term"])
    ]


def join_title_and_summary(row) -> str:
    """Return the row's title and summary joined into one single-line string."""
    return row["title"].replace("\n", " ") + " " + row["summary"].replace("\n", " ")


class MultiLabelDataset(Dataset):
    """Torch dataset of tokenized texts and their target vectors.

    ``dataframe`` must have a ``"Text"`` column (str) and a ``"Labels"``
    column (list of floats, one per tag).
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = self.data["Text"]
        self.targets = self.data["Labels"]
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        # Positional lookup: DataLoader indices are 0..len-1 regardless of the
        # DataFrame's index labels.
        text = " ".join(str(self.text.iloc[index]).split())
        inputs = self.tokenizer.encode_plus(
            text,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",  # replaces deprecated pad_to_max_length=True
            return_token_type_ids=True,
        )
        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(inputs["token_type_ids"], dtype=torch.long),
            'targets': torch.tensor(self.targets.iloc[index], dtype=torch.float),
        }


class DistilBERTClass(torch.nn.Module):
    """DistilBERT encoder plus a small MLP head producing one logit per tag.

    ``num_labels`` defaults to the previously hard-coded 124 for compatibility.
    """

    def __init__(self, num_labels=124):
        super().__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        hidden = self.l1.config.dim  # 768 for distilbert-base-uncased
        self.pre_classifier = torch.nn.Linear(hidden, hidden)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # token_type_ids is accepted for call-site compatibility but is not
        # passed to the encoder.
        hidden_state = self.l1(input_ids=input_ids, attention_mask=attention_mask)[0]
        pooled = hidden_state[:, 0]  # first ([CLS]) token of each sequence
        pooled = torch.tanh(self.pre_classifier(pooled))
        pooled = self.dropout(pooled)
        return self.classifier(pooled)


def loss_fn(outputs, targets):
    """Binary cross-entropy on raw logits against the target vectors."""
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)


def _load_labelled_data(path):
    """Return ``(texts, tag_lists)`` for every paper in ``path`` with a kept tag."""
    with open(path, "r", encoding="utf-8") as fp:
        data = pd.DataFrame(json.load(fp))
    labels = data["tag"].map(get_all_tags)
    # A paper whose tags are all rejected by is_good() has no target; building
    # a target vector for it would fail the sum-to-one check.
    keep = labels.map(len) > 0
    data = data[keep]
    texts = data.apply(join_title_and_summary, axis=1)
    return texts, labels[keep]


def _build_tag_index(tag_lists):
    """Return ``(tag -> column, column -> tag)`` over all observed tags.

    Tags are sorted so the column order is the same on every run.
    """
    index_to_tag = sorted({tag for tags in tag_lists for tag in tags})
    tag_to_index = {tag: idx for idx, tag in enumerate(index_to_tag)}
    return tag_to_index, index_to_tag


def _tags_to_target_vector(tags, tag_to_index):
    """Spread weight uniformly over the distinct ``tags``; entries sum to 1."""
    unique_tags = set(tags)  # duplicates would otherwise make the sum < 1
    target = [0.0] * len(tag_to_index)
    for tag in unique_tags:
        target[tag_to_index[tag]] = 1.0 / len(unique_tags)
    assert np.allclose(np.sum(target), 1.0, 0.000001)
    return target


def _train_one_epoch(model, loader, optimizer, epoch):
    """Run one optimisation pass over ``loader``, printing the loss every 100 steps."""
    model.train()
    for step, batch in enumerate(tqdm(loader)):
        ids = batch['ids'].to(device, dtype=torch.long)
        mask = batch['mask'].to(device, dtype=torch.long)
        token_type_ids = batch['token_type_ids'].to(device, dtype=torch.long)
        targets = batch['targets'].to(device, dtype=torch.float)

        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)
        if step % 100 == 0:
            print(f'Epoch: {epoch}, Loss: {loss.item()}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def _predict_tag(model, tokenizer, index_to_tag, text, abstract):
    """Return the highest-scoring tag for ``text`` followed by ``abstract``."""
    text = " ".join(f"{text} {abstract}".split())
    inputs = tokenizer.encode_plus(
        text,
        truncation=True,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding="max_length",
        return_token_type_ids=True,
    )
    # The model indexes hidden_state[:, 0], so it needs (batch, seq_len)
    # tensors; wrap the single example in a batch of one.
    ids = torch.tensor([inputs['input_ids']], dtype=torch.long, device=device)
    mask = torch.tensor([inputs['attention_mask']], dtype=torch.long, device=device)
    token_type_ids = torch.tensor([inputs["token_type_ids"]], dtype=torch.long, device=device)
    with torch.no_grad():
        logits = model(ids, attention_mask=mask, token_type_ids=token_type_ids)
    return index_to_tag[int(logits[0].argmax())]


@_cache_model
def prepare_model():
    """Fine-tune DistilBERT on the arXiv dump and return a predictor.

    Returns:
        ``predict(text, abstract) -> str``, which returns the single
        highest-scoring arXiv tag for the given title and abstract.
    """
    texts, tag_lists = _load_labelled_data(DATA_PATH)
    tag_to_index, index_to_tag = _build_tag_index(tag_lists)
    preprocessed_data = pd.DataFrame({
        "Labels": tag_lists.map(lambda tags: _tags_to_target_vector(tags, tag_to_index)),
        "Text": texts,
    })

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', do_lower_case=True)

    # Only the 80% training split is used; the held-out 20% is not evaluated
    # anywhere yet.
    train_data, _ = train_test_split(preprocessed_data, train_size=0.8)
    train_data = train_data.reset_index(drop=True)
    training_loader = DataLoader(
        MultiLabelDataset(train_data, tokenizer, MAX_LEN),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=0,
    )

    # Size the head to the tags actually present so logits and target vectors
    # always have the same width.
    model = DistilBERTClass(num_labels=len(index_to_tag)).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
    for epoch in range(EPOCHS):
        _train_one_epoch(model, training_loader, optimizer, epoch)
    # Disable dropout so that predictions are deterministic.
    model.eval()

    def predict(text, abstract):
        return _predict_tag(model, tokenizer, index_to_tag, text, abstract)

    return predict


predict_function = prepare_model()

if not (title.strip() or abstract.strip()):
    st.markdown("Oops... your input is empty")
else:
    try:
        prediction = predict_function(title, abstract)
    except Exception:
        # Keep the UI alive, but record the traceback instead of hiding it.
        logging.exception("Prediction failed")
        st.markdown("Oops... something went wrong")
    else:
        st.markdown("The most likely arXiv tag for this article is:")
        st.markdown(f"* {prediction}")