File size: 1,966 Bytes
273d5d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os.path

import torch
import torch.nn as nn
from transformers import RobertaTokenizerFast, RobertaForMaskedLM
import streamlit as st


class SimpleClassifier(nn.Module):
    """Small MLP head that maps a feature vector to sigmoid probabilities.

    Pipeline: BatchNorm1d -> Linear -> activation -> Linear -> activation
    -> Linear -> sigmoid.

    Args:
        in_features: Size of the input feature dimension.
        hidden_features: Width of both hidden Linear layers.
        out_features: Number of sigmoid outputs.
        activation: Module applied after each hidden Linear layer. When
            omitted (or None), a fresh ``nn.ReLU()`` is created for this
            instance.
    """

    def __init__(self, in_features: int, hidden_features: int,
                 out_features: int, activation=None):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.in2hid = nn.Linear(in_features, hidden_features)
        # Build the default per instance rather than sharing one module
        # object (created once at class-definition time) across classifiers.
        self.activation = nn.ReLU() if activation is None else activation
        self.hid2hid = nn.Linear(hidden_features, hidden_features)
        self.hid2out = nn.Linear(hidden_features, out_features)

        # Not used in forward(), but kept registered so the state_dict keys
        # (bn2.*) still match saved checkpoints under strict load_state_dict.
        self.bn2 = nn.BatchNorm1d(hidden_features)

    def forward(self, X):
        """Return sigmoid probabilities from the last Linear layer.

        Args:
            X: Input accepted by ``nn.BatchNorm1d(in_features)``; presumably
                (batch, in_features) -- TODO confirm against callers.

        Returns:
            Tensor whose last dimension is ``out_features``, values in (0, 1).
        """
        X = self.bn(X)
        X = self.activation(self.in2hid(X))
        X = self.activation(self.hid2hid(X))
        X = self.hid2out(X)
        return torch.sigmoid(X)


@st.cache_resource
def load_models():
    """Load the RoBERTa encoder, its tokenizer and the trained classifier head.

    Decorated with ``st.cache_resource`` so the large model objects are
    created once per Streamlit server and returned by reference, instead of
    being serialized and copied on every call as ``st.cache_data`` would do.

    Returns:
        dict with keys ``"tokenizer"`` (RobertaTokenizerFast), ``"model"``
        (RobertaForMaskedLM whose ``lm_head`` is replaced by ``nn.Identity``,
        so its ``logits`` are the final hidden states) and ``"classifier"``
        (SimpleClassifier(768, 768, 1) in eval mode).
    """
    model = RobertaForMaskedLM.from_pretrained("roberta-base")
    # Drop the vocabulary projection so the model emits hidden states.
    model.lm_head = nn.Identity()
    model.eval()
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    my_classifier = SimpleClassifier(768, 768, 1)
    # Resolve the weights next to this source file. Joining "..", onto
    # __file__ itself yields ".../this_file.py/../x", which POSIX path
    # resolution rejects (a regular file cannot be traversed as a directory).
    weights_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "twitter_model_91_5-.pth"
    )
    # `device` is the module-level global defined at the bottom of this file;
    # it is looked up when this function runs. load_state_dict copies the
    # loaded tensors into the classifier's existing parameters, so the
    # classifier itself is not moved to `device` here.
    my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
    my_classifier.eval()
    return {
        "tokenizer": tokenizer,
        "model": model,
        "classifier": my_classifier
    }


def classify_text(text: str) -> torch.Tensor:
    """Score ``text`` with the cached RoBERTa encoder and classifier head.

    The text is tokenized (truncated to 128 tokens). It is run through
    RoBERTa, whose ``lm_head`` is an ``nn.Identity`` as set in
    ``load_models``. The per-token outputs are summed over the sequence, and
    the resulting vector is passed to the classifier.

    Args:
        text: The input string to classify.

    Returns:
        The classifier output as a tensor of shape (1, 1) holding a sigmoid
        probability (use ``.item()`` for a Python float). It is computed
        under ``torch.no_grad()`` and carries no autograd graph.
    """
    models = load_models()
    tokenizer = models["tokenizer"]
    model = models["model"]
    classifier = models["classifier"]

    input_ids = tokenizer(
        text,
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )["input_ids"]

    # Inference only: avoid building an autograd graph through RoBERTa.
    with torch.no_grad():
        # Call the module (not .forward) so registered hooks run. With the
        # identity lm_head, `.logits` holds the final hidden states,
        # presumably (1, seq_len, 768) for roberta-base -- TODO confirm.
        hidden = model(input_ids).logits
        # Sum-pool over tokens of the single batch element, keep a batch dim.
        pooled = hidden[0].sum(dim=0)[None, :]
        return classifier(pooled)


# Preferred torch device: CUDA when a GPU is available, otherwise CPU.
# load_models() passes it to torch.load as map_location.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")