""" @author: Tan Quang Duong """ import torch import matplotlib import matplotlib.pyplot as plt import seaborn as sns import numpy as np import pandas as pd from sklearn.metrics import classification_report # custom color map norm = matplotlib.colors.Normalize(-1, 1) colors = [[norm(-1.0), "#DAF7A6"], [norm(1.0), "#673FEE"]] custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", colors) def create_classification_report(y, y_pred): target_class = ["negative", "positive"] cls_report = classification_report( y, y_pred, target_names=target_class, output_dict=True ) df_report = pd.DataFrame(cls_report).transpose() return df_report.round(2) def get_100_random_test_review(df_test): # get random 100 reviews n_random = np.random.randint(len(df_test) - 101) # get dataframe of 100 reviews df_test_100 = df_test.iloc[n_random : n_random + 100] # column rename df_test_100 = df_test_100.rename(columns={"label": "class_id"}) return df_test_100 def inference_from_pytorch(text, tokenizer, model): inputs = tokenizer(text, return_tensors="pt", truncation=True) # do inference with torch.no_grad(): logits = model(**inputs).logits # get label predicted_class_id = logits.argmax().item() predicted_label = model.config.id2label[predicted_class_id] return predicted_class_id, predicted_label def plot_confusion_matric(confusion_matrix): # annot=True to annotate cells, ftm='g' to disable scientific notation sentiment_labels = ["Negative", "Positive"] fig_cm, ax = plt.subplots(figsize=(8, 8)) sns.heatmap( confusion_matrix, annot=True, fmt="g", cmap=custom_cmap, ax=ax, ) # labels, title and ticks ax.set_xlabel("Predicted labels", size=12, weight="bold") ax.set_ylabel("True labels", size=12, weight="bold") ax.set_title("Confusion matrix", size=16, weight="bold") ax.xaxis.set_ticklabels(sentiment_labels) ax.yaxis.set_ticklabels(sentiment_labels) return fig_cm def plot_donut_sentiment_percentage(df): # explosion explode_val = (0.05, 0.05) custom_colors = ["#673FEE", "#DAF7A6"] # Give color names fig_pie, ax_pie = plt.subplots() ax_pie.pie( df["count"], labels=df["label"], autopct="%1.1f%%", pctdistance=0.5, explode=explode_val, colors=custom_colors, ) ax_pie.set_title("Sentiment analysis", size=12, weight="bold") # Create a circle at the center of the plot my_circle = plt.Circle((0, 0), 0.7, color="white") p = plt.gcf() p.gca().add_artist(my_circle) return fig_pie