Ptato committed
Commit 3bc73f1
1 Parent(s): 30ec224
Files changed (1)
  1. app.py +20 -15
app.py CHANGED
@@ -20,9 +20,15 @@ if 'id2label' not in st.session_state:
     st.session_state.id2label = {idx: label for idx, label in enumerate(st.session_state.labels)}
 if 'filled' not in st.session_state:
     st.session_state.filled = False
+if 'model' not in st.session_state:
+    st.session_state.model = AutoModelForSequenceClassification.from_pretrained("Ptato/Modified-Bert-Toxicity-Classification")
+    st.session_state.model.eval()
+if 'tokenizer' not in st.session_state:
+    st.session_state.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 form = st.form(key='Sentiment Analysis')
-st.session_state.options = ['bertweet-base-sentiment-analysis',
+st.session_state.options = [
+    'bertweet-base-sentiment-analysis',
     'distilbert-base-uncased-finetuned-sst-2-english',
     'twitter-roberta-base-sentiment',
     'Modified Bert Toxicity Classification'
@@ -43,6 +49,10 @@ if not st.session_state.filled:
         text = st.session_state.df["comment_text"].iloc[x][:128]
         for s in st.session_state.options:
             pline = None
+            predictions = None
+            encoding = None
+            logits = None
+            probs = None
             if s == 'bertweet-base-sentiment-analysis':
                 pline = pipeline(task="sentiment-analysis", model="finiteautomata/bertweet-base-sentiment-analysis")
             elif s == 'twitter-roberta-base-sentiment':
@@ -50,16 +60,13 @@ if not st.session_state.filled:
             elif s == 'distilbert-base-uncased-finetuned-sst-2-english':
                 pline = pipeline(task="sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
             else:
-                model = AutoModelForSequenceClassification.from_pretrained("Ptato/Modified-Bert-Toxicity-Classification")
-                model.eval()
-                tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-                encoding = tokenizer(tweet, return_tensors="pt")
-                encoding = {k: v.to(model.device) for k,v in encoding.items()}
-                predictions = model(**encoding)
+                encoding = st.session_state.tokenizer(text, return_tensors="pt")
+                encoding = {k: v.to(st.session_state.model.device) for k, v in encoding.items()}
+                predictions = st.session_state.model(**encoding)
                 logits = predictions.logits
                 sigmoid = torch.nn.Sigmoid()
                 probs = sigmoid(logits.squeeze().cpu())
-                predicts = np.zeros(probs.shape)
+                predictions = np.zeros(probs.shape)
                 predictions[np.where(probs >= 0.5)] = 1
                 predicted_labels = [st.session_state.id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
                 log = []
@@ -106,7 +113,7 @@ if not st.session_state.filled:
             else:
                 log = [0] * 6
                 log[1] = text
-                if max(predicts) == 0:
+                if max(predictions) == 0:
                     log[0] = 0
                     log[2] = ("NO TOXICITY")
                     log[3] = (f"{100 - round(probs[0].item() * 100, 1)}%")
@@ -116,7 +123,7 @@ if not st.session_state.filled:
                     log[0] = 1
                     _max = 0
                     _max2 = 2
-                    for i in range(1, len(predicts)):
+                    for i in range(1, len(predictions)):
                         if probs[i].item() > probs[_max].item():
                             _max = i
                         if i > 2 and probs[i].item() > probs[_max2].item():
@@ -144,11 +151,9 @@ if submit and tweet:
     elif box == 'distilbert-base-uncased-finetuned-sst-2-english':
         pline = pipeline(task="sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
     else:
-        model = AutoModelForSequenceClassification.from_pretrained("Ptato/Modified-Bert-Toxicity-Classification")
-        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-        encoding = tokenizer(tweet, return_tensors="pt")
-        encoding = {k: v.to(model.device) for k,v in encoding.items()}
-        predictions = model(**encoding)
+        encoding = st.session_state.tokenizer(tweet, return_tensors="pt")
+        encoding = {k: v.to(st.session_state.model.device) for k,v in encoding.items()}
+        predictions = st.session_state.model(**encoding)
         logits = predictions.logits
         sigmoid = torch.nn.Sigmoid()
         probs = sigmoid(logits.squeeze().cpu())
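The substance of this commit is that the toxicity model and tokenizer are now created once and cached in st.session_state instead of being re-instantiated for every comment, and the misspelled predicts array is renamed so the thresholding and the later lookups use the same variable. A minimal sketch of the caching pattern follows, using the checkpoints named in the diff; the classify() helper is hypothetical and only illustrates how the cached objects are reused on each Streamlit rerun.

import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the model and tokenizer once per session; Streamlit reruns the script
# on every interaction, so keeping them in session_state avoids repeated
# downloads and re-initialization.
if 'model' not in st.session_state:
    st.session_state.model = AutoModelForSequenceClassification.from_pretrained(
        "Ptato/Modified-Bert-Toxicity-Classification")
    st.session_state.model.eval()
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def classify(text: str) -> torch.Tensor:
    """Hypothetical helper: return per-label sigmoid probabilities for one string."""
    encoding = st.session_state.tokenizer(text, return_tensors="pt")
    encoding = {k: v.to(st.session_state.model.device) for k, v in encoding.items()}
    with torch.no_grad():  # inference only; no gradients needed
        logits = st.session_state.model(**encoding).logits
    return torch.sigmoid(logits.squeeze().cpu())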
 
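The predicts/predictions fix in the later hunks concerns the usual multi-label thresholding step: binarize the per-label probabilities at 0.5 and map the active indices through id2label. A small self-contained sketch of that step, with a hypothetical label set and probabilities standing in for st.session_state.id2label and probs from app.py:

import numpy as np

# Hypothetical labels and sigmoid outputs for illustration only.
id2label = {0: "toxic", 1: "severe_toxic", 2: "obscene",
            3: "threat", 4: "insult", 5: "identity_hate"}
probs = np.array([0.91, 0.08, 0.63, 0.02, 0.55, 0.10])

# Binarize at 0.5: one indicator per label (multi-label, not argmax).
predictions = np.zeros(probs.shape)
predictions[np.where(probs >= 0.5)] = 1

# Keep only the labels whose indicator is set.
predicted_labels = [id2label[idx] for idx, flag in enumerate(predictions) if flag == 1.0]
print(predicted_labels)  # ['toxic', 'obscene', 'insult']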