Upload predictor.py

predictor.py (ADDED)
import csv
import json  # needed for json.dump / json.loads below
import os    # needed for os.remove below
import time
from pprint import pprint

import Pinpoint_Internal.FeatureExtraction
from Pinpoint_Internal.RandomForest import *


class predictor():
    """Wraps the Pinpoint random-forest model to classify a single message."""

    def __init__(self):
        # Load the pre-trained model; psychological and behavioural features are
        # disabled because they require LIWC markup that is not available here.
        self.model = random_forest()
        self.model.PSYCHOLOGICAL_SIGNALS_ENABLED = False  # Needs LIWC markup
        self.model.BEHAVIOURAL_FEATURES_ENABLED = False
        self.model.train_model(features_file=None, force_new_dataset=False,
                               model_location=r"far-right-radical-language.model")
        self.dict_of_users_all = {}
        self.feature_extractor = Pinpoint_Internal.FeatureExtraction.feature_extraction(
            violent_words_dataset_location="swears",
            baseline_training_dataset_location="LIWC2015 Results (Storm_Front_Posts).csv")

    def predict(self, string_to_predict):
        # Re-initialise so every prediction starts from a clean extractor state.
        self.__init__()

        # Remove intermediate files left over from a previous run, if any.
        for stale_file in ("./messages.json", "messages.json", "./all-messages.csv"):
            try:
                os.remove(stale_file)
            except OSError:
                pass

        # Wrap the input string as a single pseudo-post.
        users_posts = [{"username": "tmp", "timestamp": "tmp", "message": "{}".format(string_to_predict)}]

        # Write the post to CSV in the format the feature extractor expects.
        with open('all-messages.csv', 'w', encoding='utf8', newline='') as output_file:
            writer = csv.DictWriter(output_file, fieldnames=["username", "timestamp", "message"])
            for users_post in users_posts:
                writer.writerow(users_post)

        # Extract features from the CSV and dump them to messages.json.
        self.feature_extractor._get_standard_tweets("all-messages.csv")

        with open("./messages.json", 'w') as outfile:
            features = self.feature_extractor.completed_tweet_user_features
            json.dump(features, outfile, indent=4)

        # Load the features back as a dataframe and drop the label column.
        rows = self.model.get_features_as_df("./messages.json", True)
        rows.pop("is_extremist")

        # Collect the 200-dimensional message vector for each post.
        iter = 0
        message_vector_list = []

        for user_iter in range(0, len(users_posts)):
            rows_as_json = json.loads(rows.iloc[iter].to_json())

            tmp = []
            for i in range(1, 201):
                vect_str = "message_vector_{}".format(str(i))
                vector = rows_as_json[vect_str]
                tmp.append(vector)
            message_vector_list.append(tmp)

            iter = iter + 1

        for row in users_posts:
            user = row["username"]
            timestamp = row["timestamp"]
            message = row["message"]
            user_unique_id = str(self.feature_extractor._get_unique_id_from_username(user))

            # Find the feature entry belonging to this user.
            iter = 0
            user_found = False
            while not user_found:
                try:
                    user_features = self.feature_extractor.completed_tweet_user_features[iter][user_unique_id]
                    user_found = True
                    break
                except KeyError:
                    iter = iter + 1

            formated_vectors = [float('%.10f' % elem) for elem in user_features["message_vector"]]

            # Match the user's vector against the dataframe rows and classify the match.
            iter = 0
            for vector_list in message_vector_list:
                if vector_list == formated_vectors:
                    is_extremist = self.model.model.predict([rows.iloc[iter]])

                    if is_extremist == 1:
                        return True
                    else:
                        return False
                iter = iter + 1  # advance to the next candidate row
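A minimal usage sketch of the class above (an illustration, not part of this commit): it assumes the Pinpoint_Internal package, far-right-radical-language.model, the swears directory, and the LIWC2015 Results CSV referenced in __init__ are all present in the working directory.

# Hypothetical example: classify a single message with the predictor class.
from predictor import predictor

if __name__ == "__main__":
    clf = predictor()
    result = clf.predict("some message to screen")
    print("Flagged:", result)  # True, False, or None if no vector match was found

Note that predict() falls through and returns None when no matching message vector is found, so callers should treat None as "no result" rather than "not flagged".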