Hong committed on
Commit 7e9a171
1 Parent(s): b1026c7

Upload BART_utils.py

Files changed (1)
  1. BART_utils.py +101 -0
BART_utils.py ADDED
@@ -0,0 +1,101 @@
import numpy as np
from load_data import *  # provides keywords_en: keyword lists per MBTI dimension
import matplotlib.pyplot as plt
import streamlit as st
import torch

from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification


device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Zero-shot classifier: BART fine-tuned on MNLI, used as an entailment scorer.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
nli_model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli"
).to(device)


def get_prob(sequence, label):
    """Return the probability that `label` applies to `sequence` via zero-shot NLI."""
    premise = sequence
    hypothesis = f"This example is {label}."

    # run through the model pre-trained on MNLI
    x = tokenizer.encode(
        premise, hypothesis, return_tensors="pt", truncation="only_first"
    )
    logits = nli_model(x.to(device))[0]

    # throw away "neutral" (dim 1) and take the probability of
    # "entailment" (dim 2) as the probability of the label being true
    entail_contradiction_logits = logits[:, [0, 2]]
    probs = entail_contradiction_logits.softmax(dim=1)
    prob_label_is_true = probs[:, 1]
    return prob_label_is_true[0].item()


def judge_mbti(sequence, labels):
    """Score every candidate label against the sequence, highest probability first."""
    out = []
    for label in labels:
        out.append((label, get_prob(sequence, label)))
    out = sorted(out, key=lambda x: x[1], reverse=True)
    return out


def compute_score(text, dimension):
    """Pick the stronger pole of one MBTI dimension (e.g. "E_I") and return scaled scores."""
    x, y = dimension.split("_")
    x_score = np.sum([i[1] for i in judge_mbti(text, keywords_en[dimension][x])])
    y_score = np.sum([i[1] for i in judge_mbti(text, keywords_en[dimension][y])])

    if x_score > y_score:
        choice = x
        score = x_score
    else:
        choice = y
        score = y_score

    x_score_scaled = (x_score / (x_score + y_score)) * 100
    y_score_scaled = (y_score / (x_score + y_score)) * 100

    stat = {x: x_score_scaled, y: y_score_scaled}

    return choice, stat


def mbti_translator(text):
    """Combine the four dimension choices into an MBTI type plus per-dimension stats."""
    E_I = compute_score(text, "E_I")
    N_S = compute_score(text, "N_S")
    T_F = compute_score(text, "T_F")
    P_J = compute_score(text, "P_J")

    return (E_I[0] + N_S[0] + T_F[0] + P_J[0]), (E_I[1], N_S[1], T_F[1], P_J[1])


def plot_mbti(result):
    """Render one dimension's percentage split as a horizontal bar in Streamlit."""
    fig, ax = plt.subplots(figsize=(10, 5))

    start = 0
    x, y = result.values()
    x_type, y_type = result.keys()

    # two segments: width x for the first pole, width y for the second (x + y = 100)
    ax.broken_barh([(start, x), (x, y)], [10, 9],
                   facecolors=("#FFC5BF", "#D4F0F0"))
    ax.set_ylim(5, 15)
    ax.set_xlim(0, 100)
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    ax.set_yticks([15, 25])
    ax.set_xticks([0, 25, 50, 75, 100])

    ax.text(x - 6, 14.5, x_type + " :" + str(int(x)) + "%", fontsize=15)
    ax.text((x + y) - 6, 14.5, y_type + " :" + str(int(y)) + "%", fontsize=15)

    st.pyplot(fig)
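
For context, a minimal usage sketch of the helpers above follows. It assumes load_data.py exposes keywords_en as a dict keyed by the four dimension pairs ("E_I", "N_S", "T_F", "P_J"), each mapping a pole letter to a keyword list; the example sentence and the import path are illustrative and not part of this commit.

# Hypothetical usage sketch (not part of this commit). Assumes keywords_en is
# shaped roughly like:
#   {"E_I": {"E": [...], "I": [...]}, "N_S": {...}, "T_F": {...}, "P_J": {...}}
from BART_utils import mbti_translator, plot_mbti

text = "I enjoy planning every detail of a trip well in advance."
mbti_type, stats = mbti_translator(text)
print(mbti_type)                 # e.g. "ISTJ"
for dimension_stat in stats:
    plot_mbti(dimension_stat)    # draws one bar per MBTI dimension in the Streamlit app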