"""Streamlit helper: score text with a frozen RoBERTa encoder + small MLP head."""

import os.path

import streamlit as st
import torch
import torch.nn as nn
from transformers import RobertaForMaskedLM, RobertaTokenizerFast

# Resolved once at import time, BEFORE load_models() can run (the original
# assigned this at the bottom of the module, after its first textual use).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class SimpleClassifier(nn.Module):
    """Two-hidden-layer MLP: input batch-norm -> linear/ReLU x2 -> sigmoid.

    Maps a (batch, in_features) tensor to (batch, out_features) values in (0, 1).
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int,
                 activation: nn.Module = nn.ReLU()):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.in2hid = nn.Linear(in_features, hidden_features)
        self.activation = activation
        self.hid2hid = nn.Linear(hidden_features, hidden_features)
        self.hid2out = nn.Linear(hidden_features, out_features)
        # Unused in forward(), but kept deliberately: the saved checkpoint
        # contains bn2.* parameters, and load_state_dict() is strict by
        # default — removing this layer would break load_models().
        self.bn2 = nn.BatchNorm1d(hidden_features)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        X = self.bn(X)
        X = self.activation(self.in2hid(X))
        # The original wrapped X in torch.concat((X,), 1) before each of the
        # next two layers — concatenating a 1-tuple is a no-op, so dropped.
        X = self.activation(self.hid2hid(X))
        # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
        return torch.sigmoid(self.hid2out(X))


# cache_resource (not cache_data): the return value holds live torch modules
# and a tokenizer, which should be shared as a global resource, not
# serialized/copied per access as cache_data does.
@st.cache_resource
def load_models() -> dict:
    """Load and cache the RoBERTa backbone, its tokenizer and the classifier head.

    Returns:
        dict with keys "tokenizer", "model", "classifier".
    """
    model = RobertaForMaskedLM.from_pretrained("roberta-base")
    model.lm_head = nn.Identity()  # expose encoder features instead of vocab logits
    model.eval()  # inference only — disable dropout in the backbone
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    my_classifier = SimpleClassifier(768, 768, 1)
    # Resolve the checkpoint next to this file; the original joined onto
    # __file__ itself ("file/../name"), which only worked by accident of
    # path normalization. dirname() is the correct form.
    weights_path = os.path.join(os.path.dirname(__file__), "twitter_model_91_5-.pth")
    my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
    my_classifier.eval()

    return {
        "tokenizer": tokenizer,
        "model": model,
        "classifier": my_classifier,
    }


def classify_text(text: str) -> float:
    """Return the classifier's score for *text* as a plain Python float.

    The text is tokenized (truncated to 128 tokens), encoded by the RoBERTa
    backbone, sum-pooled over the sequence axis, and scored by the MLP head.
    """
    models = load_models()
    tokenizer, model, classifier = (
        models["tokenizer"], models["model"], models["classifier"]
    )

    X = tokenizer(
        text,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )["input_ids"]

    with torch.no_grad():  # inference only — don't build an autograd graph
        # With lm_head replaced by Identity, indexing [-1] picks the last
        # populated output field (the Identity-mapped hidden states,
        # shape (batch, seq, 768)); [0] takes the single batch element and
        # sum(axis=0) pools over the sequence to one 768-vector.
        X = model(X)[-1][0].sum(axis=0)[None, :]
        score = classifier(X)

    # The original returned a (1, 1) tensor despite the -> float annotation;
    # .item() honors the declared contract.
    return score.item()