# tanquangduong's picture
# :tada: add application files
# b93e9c1
"""
@author: Tan Quang Duong
"""
import torch
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report
# Custom two-colour map used by the confusion-matrix heatmap in
# plot_confusion_matric (passed there as ``cmap=custom_cmap``).
# Normalize(-1, 1) rescales values linearly so that -1 -> 0.0 and 1 -> 1.0,
# which places the two colour anchors below at the two ends of the colormap.
norm = matplotlib.colors.Normalize(-1, 1)
# (position, colour) anchors: light green at the low end, violet at the high end.
colors = [[norm(-1.0), "#DAF7A6"], [norm(1.0), "#673FEE"]]
# Unnamed colormap interpolating linearly between the two anchors.
custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", colors)
def create_classification_report(
    y, y_pred, target_names=("negative", "positive"), labels=None
):
    """Build a rounded metrics table from scikit-learn's classification report.

    Args:
        y: Ground-truth class labels.
        y_pred: Predicted class labels, aligned element-wise with ``y``.
        target_names: Display names for the classes, in the order of the
            (sorted) label values. Defaults to the binary sentiment names
            ``("negative", "positive")``.
        labels: Optional explicit list of label values to include in the
            report. When left as ``None`` scikit-learn infers the labels from
            the data; if a sample happens to contain fewer classes than
            ``target_names`` has entries, scikit-learn rejects the call.
            Passing the full label set (e.g. ``[0, 1]`` for integer class
            ids) keeps the report stable in that situation.

    Returns:
        pandas.DataFrame: the report transposed so that each report entry is
        a row and each metric a column, with values rounded to 2 decimals.
    """
    cls_report = classification_report(
        y,
        y_pred,
        labels=labels,
        target_names=list(target_names),
        output_dict=True,
    )
    # output_dict=True yields {entry: {metric: value}}; transpose so entries
    # become rows, which is the natural layout for display.
    df_report = pd.DataFrame(cls_report).transpose()
    return df_report.round(2)
def get_100_random_test_review(df_test):
    """Return a random contiguous window of up to 100 rows from ``df_test``.

    The window start is drawn with ``np.random.randint`` and the rows are
    taken positionally (``iloc``), so the original index values are kept.
    The ``label`` column, if present, is renamed to ``class_id``.

    Args:
        df_test: DataFrame of test reviews.

    Returns:
        pandas.DataFrame: 100 consecutive rows of ``df_test`` (or all rows
        when the frame has 100 rows or fewer) with ``label`` renamed to
        ``class_id``.
    """
    window = 100
    n_rows = len(df_test)
    if n_rows <= window:
        # Not enough rows to choose from: use the whole frame instead of
        # asking randint for an empty/negative range (which raises).
        start = 0
    else:
        # Valid start offsets are 0 .. n_rows - window inclusive; randint's
        # upper bound is exclusive, hence the + 1.
        start = np.random.randint(n_rows - window + 1)
    df_test_100 = df_test.iloc[start : start + window]
    return df_test_100.rename(columns={"label": "class_id"})
def inference_from_pytorch(text, tokenizer, model):
    """Classify ``text`` with ``model`` and return the winning class.

    Args:
        text: Input passed to ``tokenizer``.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True)``; its
            result is unpacked as keyword arguments for ``model``.
        model: Callable whose output exposes ``.logits``; must also provide
            ``model.config.id2label`` mapping class ids to label names.

    Returns:
        tuple: ``(predicted_class_id, predicted_label)`` where the id is the
        argmax over all logits and the label is looked up in ``id2label``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True)

    # Prediction only: no gradient tracking needed for the forward pass.
    with torch.no_grad():
        scores = model(**encoded).logits

    class_id = torch.argmax(scores).item()
    return class_id, model.config.id2label[class_id]
def plot_confusion_matric(confusion_matrix):
    """Draw ``confusion_matrix`` as an annotated heatmap and return the figure.

    Args:
        confusion_matrix: Matrix of counts handed to ``seaborn.heatmap``;
            both axes are labelled "Negative"/"Positive".

    Returns:
        The matplotlib figure (8x8 inches) containing the heatmap.
    """
    figure, axes = plt.subplots(figsize=(8, 8))

    # annot=True writes the value into each cell; fmt="g" avoids
    # scientific notation for the counts.
    heatmap_options = {
        "annot": True,
        "fmt": "g",
        "cmap": custom_cmap,
        "ax": axes,
    }
    sns.heatmap(confusion_matrix, **heatmap_options)

    # Axis captions and title.
    axis_caption_style = {"size": 12, "weight": "bold"}
    axes.set_xlabel("Predicted labels", **axis_caption_style)
    axes.set_ylabel("True labels", **axis_caption_style)
    axes.set_title("Confusion matrix", size=16, weight="bold")

    # Same class names on both axes (x first, then y).
    for axis in (axes.xaxis, axes.yaxis):
        axis.set_ticklabels(["Negative", "Positive"])

    return figure
def plot_donut_sentiment_percentage(df):
    """Draw a donut chart of sentiment shares and return the figure.

    Args:
        df: DataFrame with a ``count`` column (wedge sizes) and a ``label``
            column (wedge labels), one row per sentiment.

    Returns:
        The matplotlib figure containing the donut chart.
    """
    custom_colors = ["#673FEE", "#DAF7A6"]
    counts = df["count"]
    # One offset per wedge: pie() requires explode to match the number of
    # wedges, so size it from the data instead of assuming two rows.
    explode_val = (0.05,) * len(counts)

    fig_pie, ax_pie = plt.subplots()
    ax_pie.pie(
        counts,
        labels=df["label"],
        autopct="%1.1f%%",
        pctdistance=0.5,
        explode=explode_val,
        colors=custom_colors,
    )
    ax_pie.set_title("Sentiment analysis", size=12, weight="bold")

    # A white disc over the centre turns the pie into a donut. Attach it to
    # the axes created above rather than going through pyplot's global
    # "current figure", which may not be this figure if other plotting code
    # runs concurrently.
    centre_circle = plt.Circle((0, 0), 0.7, color="white")
    ax_pie.add_artist(centre_circle)
    return fig_pie