# Laronix_Recording/local/wer_plot_report.py
# Committed by KevinGeng — commit a1fe393 ("push to HF"), 1.55 kB
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
import pdb
# WER cut-off: utterances scoring below this are drawn green, the rest red,
# and a red reference line is drawn at this value on both panels.
threshold = 0.3


def shorten_ids(ids):
    """Return each id as a string, truncated at its first '.'.

    e.g. ``"utt_001.wav" -> "utt_001"``. Presumably this strips a file
    extension from utterance ids — TODO confirm against the CSV producer.
    """
    return [str(x).split('.')[0] for x in ids]


def wer_colors(wers, cutoff=threshold):
    """Return a per-point colour: 'green' if WER < *cutoff*, else 'red'.

    NaN values compare False and therefore come out 'red'.
    """
    return ['green' if x < cutoff else 'red' for x in wers]


def plot_wer_report(df, cutoff=threshold):
    """Build the two-panel WER report figure and return it.

    Args:
        df: DataFrame with an ``id`` column and a numeric ``wer`` column.
        cutoff: WER threshold marked on both panels.

    Returns:
        The matplotlib Figure (not saved or closed here).
    """
    wers = df['wer']
    ids = shorten_ids(df['id'])
    fig, (hist_ax, point_ax) = plt.subplots(nrows=2, ncols=1, figsize=(25, 15))

    # Top panel: distribution of WER with the cut-off marked.
    hist_ax.set_xlabel("Word Error Rate")
    hist_ax.set_ylabel("Counts")
    # Keep the cut-off inside the visible range (wers.max() alone hides the
    # threshold line when every WER is below it) and avoid a zero-width or
    # NaN limit when all WERs are 0 or missing.
    right = wers.max()
    if pd.isna(right) or right < cutoff:
        right = cutoff
    hist_ax.set_xlim(left=0.0, right=right)
    # Histogram binning cannot auto-range NaNs, so leave missing WERs out here.
    hist_ax.hist(wers.dropna(), bins=50)
    hist_ax.axvline(x=cutoff, color="r")

    # Bottom panel: one dotted stem + marker per utterance, coloured by cut-off.
    point_ax.set_xlabel("IDs")
    point_ax.set_ylabel("Word Error Rate")
    point_ax.scatter(ids, wers, c=wer_colors(wers, cutoff), marker='o')
    point_ax.vlines(ids, ymin=0, ymax=wers, colors='grey', linestyle='dotted')
    # axhline's xmin/xmax are axes fractions in [0, 1]; the defaults already
    # span the full width of the axes.
    point_ax.axhline(y=cutoff, color='r')
    # Rotate the existing categorical tick labels rather than overwriting them
    # with set_xticklabels: that warns without a FixedLocator and mislabels
    # ticks when ids repeat (repeated categories share a single tick).
    point_ax.tick_params(axis='x', labelrotation=90, labelsize=10, width=20)

    fig.tight_layout()
    return fig


def main(argv=None):
    """Plot the WER report for the CSV given as the first CLI argument.

    The CSV must provide ``id`` and ``wer`` columns. The figure is written
    alongside the input as ``<csv path>.png`` (e.g. ``wer.csv.png``).
    Exits with a usage message when no CSV path is supplied.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        sys.exit("usage: wer_plot_report.py WER_CSV")
    wer_csv = argv[1]
    fig = plot_wer_report(pd.read_csv(wer_csv))
    fig.savefig(f"{wer_csv}.png", format='png')
    plt.close(fig)


if __name__ == "__main__":
    main()