import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# load_data is expected to expose `keywords_en`, the per-dimension keyword lists
# used by compute_score below.
from load_data import *

device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
nli_model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli"
).to(device)


def get_prob(sequence, label):
    """Return the probability that `label` applies to `sequence`, via zero-shot NLI."""
    premise = sequence
    hypothesis = f"This example is {label}."

    # run through model pre-trained on MNLI
    x = tokenizer.encode(
        premise, hypothesis, return_tensors="pt", truncation="only_first"
    )
    with torch.no_grad():
        logits = nli_model(x.to(device))[0]

    # we throw away "neutral" (dim 1) and take the probability of
    # "entailment" (2) as the probability of the label being true
    entail_contradiction_logits = logits[:, [0, 2]]
    probs = entail_contradiction_logits.softmax(dim=1)
    prob_label_is_true = probs[:, 1]
    return prob_label_is_true[0].item()
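

# Example (hypothetical sentence and label, not part of the original app): the
# call below would return a float in [0, 1], expected to be high because the
# sentence clearly entails the hypothesis "This example is extroverted."
# get_prob("I love going to parties and meeting new people.", "extroverted")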


def judge_mbti(sequence, labels):
    out = []
    for l in labels:
        temp = get_prob(sequence, l)
        out.append((l, temp))
    out = sorted(out, key=lambda x: x[1], reverse=True)
    return out


def compute_score(text, type):
    # `type` names a dimension pair such as "E_I"; each pole is scored by summing
    # the zero-shot probabilities of its keywords.
    x, y = type.split("_")
    x_score = np.sum([i[1] for i in judge_mbti(text, keywords_en[type][x])])
    y_score = np.sum([i[1] for i in judge_mbti(text, keywords_en[type][y])])

    choice = x if x_score > y_score else y
    x_score_scaled = (x_score / (x_score + y_score)) * 100
    y_score_scaled = (y_score / (x_score + y_score)) * 100
    stat = {x: x_score_scaled, y: y_score_scaled}
    return choice, stat
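

# Illustrative shape of the data compute_score relies on (an assumption for
# clarity; the real lists live in load_data and are not shown here):
# keywords_en = {
#     "E_I": {"E": ["outgoing", "talkative", ...], "I": ["quiet", "reserved", ...]},
#     "N_S": {...}, "T_F": {...}, "P_J": {...},
# }
# A call such as compute_score("I prefer a quiet evening alone.", "E_I") then
# returns a (letter, percentages) pair like ("I", {"E": <pct>, "I": <pct>}).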


def mbti_translator(text):
    # score all four dimensions and return the four-letter type plus the percentages
    E_I = compute_score(text, "E_I")
    N_S = compute_score(text, "N_S")
    T_F = compute_score(text, "T_F")
    P_J = compute_score(text, "P_J")
    return (E_I[0] + N_S[0] + T_F[0] + P_J[0]), (E_I[1], N_S[1], T_F[1], P_J[1])


def plot_mbti(result):
    fig, ax = plt.subplots(figsize=(10, 5))
    start = 0
    x, y = result.values()
    x_type, y_type = result.keys()

    # two horizontal segments, each given as (start, width); the widths sum to 100
    ax.broken_barh([(start, x), (x, y)], (10, 9), facecolors=("#FFC5BF", "#D4F0F0"))
    ax.set_ylim(5, 15)
    ax.set_xlim(0, 100)
    for side in ("left", "bottom", "top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_yticks([15, 25])
    ax.set_xticks([0, 25, 50, 75, 100])
    ax.text(x - 6, 14.5, x_type + " :" + str(int(x)) + "%", fontsize=15)
    ax.text((x + y) - 6, 14.5, y_type + " :" + str(int(y)) + "%", fontsize=15)
    st.pyplot(fig)
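

# Minimal Streamlit entry-point sketch. The Space's actual page layout is not
# included above, so the widget names and copy below are assumptions, not the
# author's UI code; they only show how mbti_translator and plot_mbti fit together.
st.title("MBTI Translator")
user_text = st.text_area("Write a few sentences about yourself:")
if st.button("Analyze") and user_text.strip():
    mbti_type, stats = mbti_translator(user_text)
    st.subheader(f"Predicted type: {mbti_type}")
    # one bar per dimension: E/I, N/S, T/F, and P/J percentages
    for stat in stats:
        plot_mbti(stat)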