MishaD's picture
Upload 3 files
273d5d2
raw
history blame contribute delete
No virus
1.97 kB
import os.path
import torch
import torch.nn as nn
from transformers import RobertaTokenizerFast, RobertaForMaskedLM
import streamlit as st
class SimpleClassifier(nn.Module):
    """Small feed-forward head that maps feature vectors to sigmoid scores.

    Layer order in ``forward``: BatchNorm1d -> Linear -> activation ->
    Linear -> activation -> Linear -> sigmoid.

    Args:
        in_features: Width of each input feature vector.
        hidden_features: Width of both hidden layers.
        out_features: Number of output scores per sample.
        activation: Callable applied after each of the two hidden linear
            layers. The default ``nn.ReLU()`` is created once, when the
            class is defined, and shared by every instance that does not
            pass its own activation.
    """

    def __init__(self, in_features: int, hidden_features: int,
                 out_features: int, activation=nn.ReLU()):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.in2hid = nn.Linear(in_features, hidden_features)
        self.activation = activation
        self.hid2hid = nn.Linear(hidden_features, hidden_features)
        self.hid2out = nn.Linear(hidden_features, out_features)
        # Not used by forward(). It is kept on purpose: the layer contributes
        # ``bn2.*`` entries to state_dict(), so dropping it would change the
        # keys that load_state_dict() expects for this class.
        self.bn2 = nn.BatchNorm1d(hidden_features)

    def forward(self, X):
        """Score a batch of feature vectors.

        Args:
            X: Tensor whose last dimension is ``in_features``; it is fed
                straight into ``nn.BatchNorm1d``.

        Returns:
            Tensor with ``out_features`` values per sample, each passed
            through a sigmoid and therefore in the open interval (0, 1).
        """
        X = self.bn(X)
        X = self.activation(self.in2hid(X))
        X = self.activation(self.hid2hid(X))
        return torch.sigmoid(self.hid2out(X))
# Cached as a shared resource: the same tokenizer/model objects are handed
# back on every call instead of being serialised and copied each time.
@st.cache_resource()
def load_models():
    """Load the tokenizer, the RoBERTa encoder and the trained classifier.

    The "roberta-base" masked-LM model is loaded and its ``lm_head`` is
    replaced with ``nn.Identity()``. The classifier weights are read from
    ``twitter_model_91_5-.pth``, expected to sit in the same directory as
    this source file.

    Returns:
        dict with three entries: ``"tokenizer"`` (RobertaTokenizerFast),
        ``"model"`` (the RoBERTa model with the identity head) and
        ``"classifier"`` (a ``SimpleClassifier(768, 768, 1)`` in eval mode).
    """
    model = RobertaForMaskedLM.from_pretrained("roberta-base")
    model.lm_head = nn.Identity()
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    my_classifier = SimpleClassifier(768, 768, 1)
    # Resolve the weights next to this file. Joining "<file>.py/.." would walk
    # through a regular file as if it were a directory, which POSIX systems
    # refuse to resolve.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    weights_path = os.path.join(module_dir, "twitter_model_91_5-.pth")
    my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
    my_classifier.eval()

    return {
        "tokenizer": tokenizer,
        "model": model,
        "classifier": my_classifier
    }
def classify_text(text: str) -> torch.Tensor:
    """Score a piece of text with the RoBERTa encoder and the classifier.

    The text is tokenized (truncated to at most 128 tokens), run through the
    model returned by ``load_models()``, and the per-token output vectors of
    the first (only) sequence are summed into a single feature row that is
    fed to the classifier.

    Args:
        text: The text to score.

    Returns:
        The classifier output as a tensor. With the one-output classifier
        built in ``load_models()`` this is a ``(1, 1)`` tensor holding a
        sigmoid score; use ``.item()`` to obtain a Python float.
    """
    models = load_models()
    tokenizer, model, classifier = models["tokenizer"], models["model"], models["classifier"]

    input_ids = tokenizer(
        text,
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )["input_ids"]

    # Pure inference: skip autograd bookkeeping so no graph is built and the
    # returned tensor does not track gradients.
    with torch.no_grad():
        outputs = model(input_ids)
        # Last element of the model output, first sequence in the batch;
        # summing over dim 0 collapses the token axis, and [None, :] restores
        # a leading batch dimension of size 1 for the classifier.
        features = outputs[-1][0].sum(dim=0)[None, :]
        return classifier(features)
# Device used as map_location when loading weights: CUDA if available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")