MishaD committed on
Commit
273d5d2
1 Parent(s): 3d0203f

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +9 -0
  2. model.py +68 -0
  3. twitter_model_91_5-.pth +3 -0
app.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from model import classify_text
3
+
4
+
5
# Streamlit page: read a piece of text and show the model's positive-class score.
st.markdown("### Sentiment classification (negative vs. positive)")
title = st.text_area("Your sentiment for classification", "I would not use it")
if st.button("Classify!"):
    # classify_text ends in a sigmoid, so the score is a fraction in [0, 1].
    prob = float(classify_text(title))
    # Scale to a percentage before printing it with a "%" suffix; three
    # decimals of a percent carry the same precision as five of a fraction.
    st.markdown(f"**Model**: {round(prob * 100, 3)}% of being positive")
model.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import RobertaTokenizerFast, RobertaForMaskedLM
6
+ import streamlit as st
7
+
8
+
9
class SimpleClassifier(nn.Module):
    """Small feed-forward head that ends in a sigmoid.

    Layer order in ``forward``:
    BatchNorm1d -> Linear -> activation -> Linear -> activation -> Linear -> sigmoid.

    Args:
        in_features: Size of each input feature vector.
        hidden_features: Width of both hidden linear layers.
        out_features: Size of the output vector.
        activation: Module applied after each of the first two linear layers.
    """

    def __init__(self, in_features: int, hidden_features: int,
                 out_features: int, activation=nn.ReLU()):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.in2hid = nn.Linear(in_features, hidden_features)
        self.activation = activation
        self.hid2hid = nn.Linear(hidden_features, hidden_features)
        self.hid2out = nn.Linear(hidden_features, out_features)

        # Never applied in forward(); kept registered so the module's
        # state_dict keys stay the same.
        self.bn2 = nn.BatchNorm1d(hidden_features)

    def forward(self, X):
        """Return the element-wise sigmoid of the final linear layer's output."""
        hidden = self.activation(self.in2hid(self.bn(X)))
        hidden = self.activation(self.hid2hid(hidden))
        return torch.sigmoid(self.hid2out(hidden))
35
+
36
+
37
# cache_resource keeps one shared instance of the returned objects across
# reruns; cache_data would pickle and copy the whole return value on each call.
@st.cache_resource()
def load_models():
    """Load the tokenizer, the RoBERTa feature extractor and the classifier head.

    Returns:
        A dict with keys ``"tokenizer"``, ``"model"`` and ``"classifier"``.
    """
    model = RobertaForMaskedLM.from_pretrained("roberta-base")
    # Replace the LM head with a pass-through so the model's output is whatever
    # was fed into the head instead of vocabulary logits.
    model.lm_head = nn.Identity()
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

    my_classifier = SimpleClassifier(768, 768, 1)
    # Resolve the weights file next to this module. Joining "<file>/.." goes
    # through a regular file, which fails on POSIX path resolution.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    weights_path = os.path.join(module_dir, "twitter_model_91_5-.pth")
    my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
    my_classifier.eval()

    return {
        "tokenizer": tokenizer,
        "model": model,
        "classifier": my_classifier,
    }
51
+
52
+
53
def classify_text(text: str) -> float:
    """Score ``text`` with the RoBERTa features and the classifier head.

    Args:
        text: Input text; it is tokenized with truncation at 128 tokens.

    Returns:
        The classifier's sigmoid output as a Python float.
    """
    models = load_models()
    tokenizer, model, classifier = models["tokenizer"], models["model"], models["classifier"]

    input_ids = tokenizer(
        text,
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )["input_ids"]

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        # Last element of the model output, first (only) item of the batch.
        # Presumably per-token features of shape (seq_len, 768) - TODO confirm.
        token_features = model(input_ids)[-1][0]
        # Sum over the token axis and restore a leading batch dimension of 1.
        pooled = token_features.sum(dim=0)[None, :]
        # Single-element tensor -> plain float, matching the annotation.
        return float(classifier(pooled))
66
+
67
+
68
# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
twitter_model_91_5-.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3f12e3a20193609be0f8c8d6e2f5ca06e8a43e907b7d74dd108c3d64d77c5ad
3
+ size 4757645