MBTI_translator / BART_utils.py
Hong
Upload BART_utils.py
7e9a171
raw history blame
No virus
2.91 kB
import numpy as np
from load_data import *
import matplotlib.pyplot as plt
import streamlit as st
import torch
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
# Run on the first GPU when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Zero-shot classification via natural-language inference: a BART model
# fine-tuned on MNLI scores whether a premise entails a hypothesis.
_NLI_MODEL_NAME = "facebook/bart-large-mnli"
tokenizer = AutoTokenizer.from_pretrained(_NLI_MODEL_NAME)
# Move the model to exactly the device get_prob() sends its inputs to.
# `.cuda()` would pick the *current* CUDA device, which need not be "cuda:0".
nli_model = AutoModelForSequenceClassification.from_pretrained(_NLI_MODEL_NAME).to(device)
def get_prob(sequence, label):
    """Return the zero-shot probability that ``label`` describes ``sequence``.

    The text is used as the NLI premise and ``"This example is {label}."`` as
    the hypothesis. The "neutral" logit is discarded. The remaining
    contradiction/entailment logits are softmax-normalised, and the
    entailment share is returned.

    Args:
        sequence: Input text (the premise).
        label: Candidate label inserted into the hypothesis template.

    Returns:
        float: Probability in [0, 1] that the hypothesis is entailed.
    """
    premise = sequence
    hypothesis = f"This example is {label}."
    # If the pair is too long, truncate only the premise so the short
    # hypothesis template is always kept intact.
    x = tokenizer.encode(
        premise, hypothesis, return_tensors="pt", truncation="only_first"
    )
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = nli_model(x.to(device))[0]
    # We throw away "neutral" (dim 1) and take the probability of
    # "entailment" (dim 2) versus "contradiction" (dim 0) as the
    # probability of the label being true.
    entail_contradiction_logits = logits[:, [0, 2]]
    probs = entail_contradiction_logits.softmax(dim=1)
    prob_label_is_true = probs[:, 1]
    return prob_label_is_true[0].item()
def judge_mbti(sequence, labels):
    """Score every candidate label against ``sequence``, best first.

    Args:
        sequence: Input text passed to ``get_prob`` as the premise.
        labels: Iterable of candidate label strings.

    Returns:
        list[tuple]: ``(label, probability)`` pairs sorted by probability in
        descending order. Labels with equal probability keep their input order.
    """
    ranked = [(candidate, get_prob(sequence, candidate)) for candidate in labels]
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
def compute_score(text, type):
    """Score ``text`` on one MBTI axis and pick the dominant pole.

    Args:
        text: Input text to classify.
        type: Axis key such as ``"E_I"``. It is split on ``"_"`` into the two
            pole letters and also used to look up ``keywords_en[type]``.
            (The name shadows the ``type`` builtin but is kept for callers
            that pass it by keyword.)

    Returns:
        tuple: ``(choice, stat)``.
        ``choice`` is the pole letter with the larger summed keyword
        probability; ties go to the second pole.
        ``stat`` maps each pole letter to its share of the total, in percent.
        The split is 50/50 when neither pole has any score.
    """
    x, y = type.split("_")
    # Each pole's score is the sum of the entailment probabilities of its
    # keywords. keywords_en comes from load_data, presumably shaped
    # {axis: {pole: [keyword, ...]}} -- confirm against load_data.
    x_score = np.sum([prob for _, prob in judge_mbti(text, keywords_en[type][x])])
    y_score = np.sum([prob for _, prob in judge_mbti(text, keywords_en[type][y])])
    choice = x if x_score > y_score else y
    total = x_score + y_score
    if total > 0:
        x_score_scaled = (x_score / total) * 100
        y_score_scaled = (y_score / total) * 100
    else:
        # No evidence for either pole (e.g. empty keyword lists): report an
        # even split instead of 0/0 = NaN, which plot_mbti() cannot draw.
        x_score_scaled = y_score_scaled = 50.0
    stat = {x: x_score_scaled, y: y_score_scaled}
    return choice, stat
def mbti_translator(text):
    """Classify ``text`` into a four-letter MBTI type.

    Args:
        text: Input text to classify.

    Returns:
        tuple: ``(mbti_type, stats)``.
        ``mbti_type`` concatenates the chosen pole of each axis in
        E_I, N_S, T_F, P_J order.
        ``stats`` is a 4-tuple of the per-axis percentage dicts returned by
        ``compute_score``, in the same order.
    """
    axes = ("E_I", "N_S", "T_F", "P_J")
    scored = [compute_score(text, axis) for axis in axes]
    mbti_type = "".join(choice for choice, _ in scored)
    stats = tuple(stat for _, stat in scored)
    return mbti_type, stats
def plot_mbti(result):
    """Render one MBTI axis as a two-segment percentage bar in Streamlit.

    Args:
        result: Two-entry mapping ``{pole_letter: percent}``, such as the
            ``stat`` dict returned by ``compute_score``. The first entry is
            drawn from the left edge; the second starts where the first ends.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    start = 0
    x, y = result.values()
    x_type, y_type = result.keys()
    # broken_barh takes (xmin, width) pairs. The second segment starts at the
    # end of the first and is y wide, so together they span exactly x + y.
    ax.broken_barh([(start, x), (start + x, y)], [10, 9],
                   facecolors=("#FFC5BF", "#D4F0F0"))
    ax.set_ylim(5, 15)
    ax.set_xlim(0, 100)
    for side in ("left", "bottom", "top", "right"):
        ax.spines[side].set_visible(False)
    # The vertical axis carries no information, so hide its ticks entirely.
    ax.set_yticks([])
    ax.set_xticks([0, 25, 50, 75, 100])
    # Label each segment near its right end (int() truncates the percentage).
    ax.text(x - 6, 14.5, f"{x_type} :{int(x)}%", fontsize=15)
    ax.text((x + y) - 6, 14.5, f"{y_type} :{int(y)}%", fontsize=15)
    st.pyplot(fig)
    # Release the pyplot-managed figure once Streamlit has rendered it.
    # Otherwise every app rerun leaks one figure, and matplotlib warns
    # after 20 open figures.
    plt.close(fig)