# Pre-trained BERT on Twitter US Election 2020 for Stance Detection towards Joe Biden (KE-MLM)

Pre-trained weights for the **KE-MLM model** in [Knowledge Enhanced Masked Language Model for Stance Detection](https://2021.naacl.org/program/accepted/), NAACL 2021.

# Training Data

This model is pre-trained on over 5 million English tweets about the 2020 US Presidential Election and then fine-tuned on our [stance-labeled data](https://github.com/GU-DataLab/stance-detection-KE-MLM) for stance detection towards Joe Biden.
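
As a rough sketch of that fine-tuning step using the Hugging Face `Trainer` (the model path, CSV file name, column names, and hyperparameters below are placeholders, not the exact setup from the paper or repository):

```python
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          Trainer, TrainingArguments)
from datasets import load_dataset

# placeholder path: the LM after continued pre-training on election tweets
base_path = "path/to/election-pretrained-lm"
tokenizer = AutoTokenizer.from_pretrained(base_path)
model = AutoModelForSequenceClassification.from_pretrained(base_path, num_labels=3)

# placeholder file: stance-labeled tweets with "text" and "label" columns,
# where 0 = AGAINST, 1 = FAVOR, 2 = NONE (matching id2label in the Usage section)
data = load_dataset("csv", data_files="biden_stance_train.csv")["train"]
data = data.map(lambda batch: tokenizer([t.lower() for t in batch["text"]],
                                        truncation=True, max_length=128),
                batched=True)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="stance-biden",
                           per_device_train_batch_size=16,
                           num_train_epochs=3),
    train_dataset=data,
)
trainer.train()
```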

# Training Objective

This model is initialized with BERT-base and trained with the normal MLM objective, with a classification layer fine-tuned for stance detection towards Joe Biden.
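
For concreteness, here is a minimal sketch of that MLM objective (random 15% token masking with cross-entropy loss on the masked positions), using the standard Hugging Face masking collator rather than the paper's exact training code; the example tweets are placeholders:

```python
from transformers import (AutoTokenizer, AutoModelForMaskedLM,
                          DataCollatorForLanguageModeling)

# initialize from BERT-base, as described above
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

# standard MLM collator: masks 15% of input tokens at random
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

tweets = ["example election tweet one", "example election tweet two"]  # placeholders
batch = collator([tokenizer(t) for t in tweets])  # masked input_ids + MLM labels

outputs = model(**batch)
print("MLM loss:", outputs.loss.item())  # a pre-training step would backprop this
```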

# Usage

This pre-trained language model is fine-tuned for the stance detection task specifically towards Joe Biden.

Please see the [official repository](https://github.com/GU-DataLab/stance-detection-KE-MLM) for more details.

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np

# choose GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# select model path here
pretrained_LM_path = "kornosk/bert-election2020-twitter-stance-biden-KE-MLM"

# load model
tokenizer = AutoTokenizer.from_pretrained(pretrained_LM_path)
model = AutoModelForSequenceClassification.from_pretrained(pretrained_LM_path)
model.to(device)  # move the model to GPU when one is available

id2label = {
    0: "AGAINST",
    1: "FAVOR",
    2: "NONE"
}

##### Prediction Neutral #####
sentence = "Hello World."
inputs = tokenizer(sentence.lower(), return_tensors="pt").to(device)
outputs = model(**inputs)
predicted_probability = torch.softmax(outputs.logits, dim=1)[0].tolist()

print("Sentence:", sentence)
print("Prediction:", id2label[np.argmax(predicted_probability)])
print("Against:", predicted_probability[0])
print("Favor:", predicted_probability[1])
print("Neutral:", predicted_probability[2])

##### Prediction Favor #####
sentence = "Go Go Biden!!!"
inputs = tokenizer(sentence.lower(), return_tensors="pt").to(device)
outputs = model(**inputs)
predicted_probability = torch.softmax(outputs.logits, dim=1)[0].tolist()

print("Sentence:", sentence)
print("Prediction:", id2label[np.argmax(predicted_probability)])
print("Against:", predicted_probability[0])
print("Favor:", predicted_probability[1])
print("Neutral:", predicted_probability[2])

##### Prediction Against #####
sentence = "Biden is the worst."
inputs = tokenizer(sentence.lower(), return_tensors="pt").to(device)
outputs = model(**inputs)
predicted_probability = torch.softmax(outputs.logits, dim=1)[0].tolist()

print("Sentence:", sentence)
print("Prediction:", id2label[np.argmax(predicted_probability)])
print("Against:", predicted_probability[0])
print("Favor:", predicted_probability[1])
print("Neutral:", predicted_probability[2])

# please consider citing our paper if you feel this is useful :)
```
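
To score many tweets at once, a small batched helper can be layered on top of the same `tokenizer`, `model`, `device`, and `id2label` objects loaded above; the `predict_stance` name and the batch size are illustrative, not part of the official repository:

```python
def predict_stance(sentences, batch_size=32):
    """Return a stance label for each sentence (hypothetical helper)."""
    labels = []
    for i in range(0, len(sentences), batch_size):
        batch = [s.lower() for s in sentences[i:i + batch_size]]
        inputs = tokenizer(batch, return_tensors="pt",
                           padding=True, truncation=True).to(device)
        with torch.no_grad():  # inference only, no gradients needed
            logits = model(**inputs).logits
        labels.extend(id2label[p] for p in torch.argmax(logits, dim=1).tolist())
    return labels

print(predict_stance(["Go Go Biden!!!", "Biden is the worst."]))
```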

# Reference

- [Knowledge Enhanced Masked Language Model for Stance Detection](https://2021.naacl.org/program/accepted/), NAACL 2021.

# Citation
```bibtex
@inproceedings{kawintiranon2021knowledge,
    title={Knowledge Enhanced Masked Language Model for Stance Detection},
    author={Kawintiranon, Kornraphop and Singh, Lisa},
    booktitle={Proceedings of the 2021 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL)},
    year={2021},
    url={#}
}
```