Spaces:
Runtime error
Runtime error
"""
@author: Tan Quang Duong
"""
import torch | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import numpy as np | |
import pandas as pd | |
from sklearn.metrics import classification_report | |
# Custom colormap: anchor values are normalised from [-1, 1] onto [0, 1], with
# "#DAF7A6" at the low end and "#673FEE" at the high end.
norm = matplotlib.colors.Normalize(vmin=-1, vmax=1)
colors = [
    [norm(anchor), hex_code]
    for anchor, hex_code in ((-1.0, "#DAF7A6"), (1.0, "#673FEE"))
]
custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", colors)
def create_classification_report(y, y_pred):
    """Build a metrics table for the negative/positive sentiment classes.

    Args:
        y: Ground-truth class labels.
        y_pred: Predicted class labels.

    Returns:
        A DataFrame with one row per entry of the report dictionary produced
        by ``classification_report``, with values rounded to 2 decimals.
    """
    report_dict = classification_report(
        y,
        y_pred,
        target_names=["negative", "positive"],
        output_dict=True,
    )
    # The dict is keyed by report entry; transpose so entries become rows.
    return pd.DataFrame(report_dict).T.round(2)
def get_100_random_test_review(df_test):
    """Return a random contiguous window of up to 100 rows from *df_test*.

    A random start position is drawn and the 100 consecutive rows beginning
    there are returned (a sliding window, not 100 independently sampled
    rows). The ``label`` column is renamed to ``class_id`` in the returned
    frame; *df_test* itself is not modified.

    Args:
        df_test: DataFrame of test reviews.

    Returns:
        A DataFrame holding 100 consecutive rows, or every row when
        *df_test* has fewer than 100.
    """
    window = 100
    # Last valid start position, clamped at 0 so that frames with 100 rows or
    # fewer no longer make ``randint`` fail on a non-positive upper bound.
    last_start = max(len(df_test) - window, 0)
    # ``randint`` excludes its upper bound, hence the +1: the window ending on
    # the final row must be reachable too.
    start = np.random.randint(last_start + 1)
    df_test_100 = df_test.iloc[start : start + window]
    return df_test_100.rename(columns={"label": "class_id"})
def inference_from_pytorch(text, tokenizer, model):
    """Classify *text* with a PyTorch model and return the predicted class.

    Args:
        text: Input handed to the tokenizer.
        tokenizer: Callable returning keyword inputs for *model*; it is
            invoked with ``return_tensors="pt"`` and ``truncation=True``.
        model: Callable whose output exposes ``logits`` and which carries a
            ``config.id2label`` mapping.

    Returns:
        A ``(predicted_class_id, predicted_label)`` tuple.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True)

    # Forward pass only; gradient tracking is not needed for prediction.
    with torch.no_grad():
        output = model(**encoded)

    # argmax without a dim operates on the flattened logits.
    class_id = torch.argmax(output.logits).item()
    return class_id, model.config.id2label[class_id]
def plot_confusion_matric(confusion_matrix):
    """Draw *confusion_matrix* as an annotated heatmap.

    Args:
        confusion_matrix: Matrix of counts passed to ``seaborn.heatmap``.

    Returns:
        The Matplotlib figure containing the heatmap.
    """
    class_names = ["Negative", "Positive"]
    axis_label_style = {"size": 12, "weight": "bold"}

    fig_cm, ax = plt.subplots(figsize=(8, 8))
    # annot=True writes the value into each cell; fmt="g" keeps those
    # annotations out of scientific notation.
    sns.heatmap(confusion_matrix, annot=True, fmt="g", cmap=custom_cmap, ax=ax)

    # Title, axis labels and tick labels.
    ax.set_title("Confusion matrix", size=16, weight="bold")
    ax.set_xlabel("Predicted labels", **axis_label_style)
    ax.set_ylabel("True labels", **axis_label_style)
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_ticklabels(class_names)

    return fig_cm
def plot_donut_sentiment_percentage(df):
    """Draw a donut chart showing the share of each sentiment label.

    Args:
        df: DataFrame with a ``count`` column (wedge sizes) and a ``label``
            column (wedge labels), one row per wedge.

    Returns:
        The Matplotlib figure containing the donut chart.
    """
    custom_colors = ["#673FEE", "#DAF7A6"]
    # One offset per wedge: ``pie`` rejects an ``explode`` whose length differs
    # from the data, so a fixed 2-tuple fails whenever ``df`` does not have
    # exactly two rows (e.g. every review received the same sentiment).
    explode_val = [0.05] * len(df)

    fig_pie, ax_pie = plt.subplots()
    ax_pie.pie(
        df["count"],
        labels=df["label"],
        autopct="%1.1f%%",
        pctdistance=0.5,
        explode=explode_val,
        colors=custom_colors,
    )
    ax_pie.set_title("Sentiment analysis", size=12, weight="bold")

    # A white disc over the centre turns the pie into a donut. It is attached
    # to ``ax_pie`` directly rather than via pyplot's current-figure state, so
    # it always lands on the axes created above.
    centre_circle = plt.Circle((0, 0), 0.7, color="white")
    ax_pie.add_artist(centre_circle)
    return fig_pie