Chris Hoge commited on
Commit
36f0169
1 Parent(s): 8a1ffcd

Added sentiment analysis file

Browse files
Files changed (1) hide show
  1. sentiment_cnn.py +73 -0
sentiment_cnn.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SentimentCNN class based on Sentiment Analysis tutorial by Ben Trevett
2
+ # https://github.com/bentrevett/pytorch-sentiment-analysis
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchtext
7
+
8
+ class SentimentCNN(nn.Module):
9
+ def __init__(self, state_dict=None, vocab=None, tokenizer='basic_english'):
10
+ super().__init__()
11
+
12
+ # tokenizer setup
13
+ self.tokenizer = torchtext.data.utils.get_tokenizer(tokenizer)
14
+ self.state_dict_name = state_dict
15
+
16
+ if vocab:
17
+ self.load_vocab(vocab)
18
+
19
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+
21
+ def _setup_model(self):
22
+ # cnn parameters
23
+ n_filters=100
24
+ filter_sizes=[3,5,7]
25
+ dropout_rate=0.25
26
+ self.min_length = max(filter_sizes)
27
+
28
+ # language space parameters
29
+ embedding_dim=300
30
+ output_dim=2
31
+
32
+ # model setup
33
+ self.embedding = nn.Embedding(
34
+ len(self.vocab),
35
+ embedding_dim,
36
+ padding_idx=self.pad_index)
37
+ self.convs = nn.ModuleList([nn.Conv1d(embedding_dim,
38
+ n_filters,
39
+ filter_size)
40
+ for filter_size in filter_sizes])
41
+ self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
42
+ self.dropout = nn.Dropout(dropout_rate)
43
+
44
+ if self.state_dict_name:
45
+ self.load_state_dict(torch.load(self.state_dict_name))
46
+
47
+ def load_vocab(self, vocab):
48
+ # vocabulary parameters
49
+ self.vocab = torch.load(vocab)
50
+ self.pad_index = self.vocab['<pad>']
51
+ self._setup_model()
52
+
53
+ def forward(self, ids):
54
+ embedded = self.dropout(self.embedding(ids))
55
+ embedded = embedded.permute(0,2,1)
56
+ conved = [torch.relu(conv(embedded)) for conv in self.convs]
57
+ pooled = [conv.max(dim=-1).values for conv in conved]
58
+ cat = self.dropout(torch.cat(pooled, dim=-1))
59
+ prediction = self.fc(cat)
60
+ return prediction
61
+
62
+ def predict_sentiment(self, text):
63
+ tokens = self.tokenizer(text)
64
+ ids = [self.vocab[t] for t in tokens]
65
+ if len(ids) < self.min_length:
66
+ ids += [self.pad_index] * (self.min_length - len(ids))
67
+ tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(self.device)
68
+ prediction = self(tensor).squeeze(dim=0)
69
+ probability = torch.softmax(prediction, dim=-1)
70
+ predicted_class = prediction.argmax(dim=-1).item()
71
+ predicted_probability = probability[predicted_class].item()
72
+
73
+ return predicted_class, predicted_probability