# Page banner captured together with the script (not code): "Spaces: Runtime error" (status shown twice).
import pdb
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# WER cut-off: drawn as the red reference line in both plots; points whose
# WER is at or above it are coloured red, the rest green.
threshold = 0.3


def wer_colors(wers, limit):
    """Return one colour name per WER value.

    A value strictly below ``limit`` maps to ``'green'``; anything else
    (including a value equal to ``limit``) maps to ``'red'``.
    """
    return ['green' if wer < limit else 'red' for wer in wers]


def short_ids(ids):
    """Return each id as a string cut at its first ``'.'``.

    Every id is passed through ``str()`` first, so non-string ids are
    accepted as well.
    """
    return [str(item).split('.')[0] for item in ids]


def plot_wer(wer_csv, limit=threshold):
    """Plot the WER values stored in ``wer_csv`` and save the figure as PNG.

    The figure has two rows: a 50-bin histogram of the ``wer`` column, and
    a per-id scatter plot of the same values with dotted stems down to 0.
    Both rows show ``limit`` as a red reference line.

    Args:
        wer_csv: Path of a CSV file that has an ``id`` and a ``wer`` column.
        limit: WER cut-off used for the reference lines and point colours.

    Returns:
        The path of the written image, i.e. ``wer_csv`` with ``.png``
        appended.

    Raises:
        ValueError: If a required column is missing or the file has no rows.
    """
    df = pd.read_csv(wer_csv)
    missing = [col for col in ('id', 'wer') if col not in df.columns]
    if missing:
        raise ValueError(
            "%s is missing required column(s): %s" % (wer_csv, ", ".join(missing))
        )
    if df.empty:
        raise ValueError("%s contains no rows to plot" % wer_csv)

    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(25, 15))

    # Histogram of the WER distribution.
    ax[0].set_xlabel("Word Error Rate")
    ax[0].set_ylabel("Counts")
    ax[0].set_xlim(left=0.0, right=df['wer'].max())
    ax[0].hist(df['wer'], bins=50)
    ax[0].axvline(x=limit, color="r")

    # One point per sentence, with a dotted stem from 0 up to its WER.
    colors = wer_colors(df['wer'], limit)
    new_ids = short_ids(df['id'])
    ax[1].set_xlabel("IDs")
    ax[1].set_ylabel("Word Error Rate")
    ax[1].scatter(new_ids, df['wer'], c=colors, marker='o')
    ax[1].vlines(new_ids, ymin=0, ymax=df['wer'], colors='grey',
                 linestyle='dotted', label='Vertical Lines')
    # axhline's xmin/xmax are axes fractions in [0, 1], not data coordinates,
    # so the defaults (0 and 1) already span the full width.
    ax[1].axhline(y=limit, color='r')
    # Style the tick labels the axis already generates from the ids instead
    # of overwriting them with set_xticklabels, which would mislabel ticks
    # when truncated ids repeat.
    ax[1].tick_params(axis='x', width=20, labelrotation=90, labelsize=10)

    plt.tight_layout()
    out_path = "%s.png" % wer_csv
    fig.savefig(out_path, format='png')
    plt.close(fig)  # release the figure once it is on disk
    return out_path


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit("usage: %s WER_CSV" % sys.argv[0])
    plot_wer(sys.argv[1])