| |
| """Visualize LAS F1 scores across time periods for three parsing models.""" |
|
|
| import re |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
# Header that opens each score section. Group 1 holds the "YYYY-YYYY" period
# for per-period sections and is None for the overall section.
_SECTION_HEADER_RE = re.compile(
    r'Average scores for time period (\d{4}-\d{4})'
    r'|Overall scores across all time periods'
)

# Metric table inside a single section: a "Metric" header line followed, in
# order, by UPOS, UAS and LAS rows. Captures the third numeric column of the
# LAS row, which this script treats as the F1 score.
_LAS_F1_RE = re.compile(
    r'\nMetric\t.*?\nUPOS\t.*?\nUAS\t.*?\nLAS\t[\d.]+\t[\d.]+\t([\d.]+)',
    re.DOTALL,
)


def parse_scores_file(filepath):
    """Extract average LAS F1 scores per time period from a scores file.

    The file is split at its section headers ("Average scores for time period
    YYYY-YYYY" and "Overall scores across all time periods"), and the LAS
    table is looked up only inside each section's own text. A section without
    a LAS row is therefore skipped instead of borrowing the next section's
    score.

    Args:
        filepath: Path of the scores text file to read.

    Returns:
        dict mapping each period string (e.g. ``'1700-1750'``) to its LAS F1
        as a float, plus an ``'Overall'`` key when an overall section with a
        LAS row is present. If a period appears more than once, the last
        occurrence wins; for the overall section, the first one with a LAS row
        is used.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    scores = {}
    overall = None

    headers = list(_SECTION_HEADER_RE.finditer(content))
    for idx, header in enumerate(headers):
        # A section runs from the end of its header to the start of the next one.
        section_end = headers[idx + 1].start() if idx + 1 < len(headers) else len(content)
        table = _LAS_F1_RE.search(content, header.end(), section_end)
        if table is None:
            continue
        las_f1 = float(table.group(1))
        period = header.group(1)
        if period is not None:
            scores[period] = las_f1
        elif overall is None:
            overall = las_f1

    # 'Overall' is inserted last so the per-period keys come first.
    if overall is not None:
        scores['Overall'] = overall

    return scores
|
|
def main():
    """Plot LAS F1 per time period for three parsing models and save the chart.

    Reads the three score files from ``eval/scores`` via ``parse_scores_file``
    and draws a grouped bar chart with one group per period plus ``Overall``.
    Each bar is labelled with its value, in bold when it is the best score in
    its group; tied models are all bolded. The chart is written to
    ``plot/las_f1_scores_comparison.png`` (400 dpi) and ``.pdf``, and the
    figure is then shown interactively.

    Periods missing from a score file are plotted as 0 and reported on stdout.
    """
    from pathlib import Path

    scores_dir = 'eval/scores'
    out_dir = Path('plot')

    # (legend label, score-file name, bar colour), in left-to-right bar order.
    models = [
        ('Talbanken', 'scores_talbanken.txt', '#2ecc71'),
        ('Transformer Silver', 'scores_transformer_silver.txt', '#3498db'),
        ('Transformer No Silver', 'scores_transformer_no_silver.txt', '#e74c3c'),
    ]

    # Plotting order; these keys must match the ones parse_scores_file returns.
    time_periods = ['1700-1750', '1750-1800', '1800-1850', '1850-1900', '1900-1950', 'Overall']

    all_vals = []
    for label, filename, _ in models:
        scores = parse_scores_file(f'{scores_dir}/{filename}')
        missing = [p for p in time_periods if p not in scores]
        if missing:
            # Otherwise a missing period would silently show as a 0 bar below
            # the visible y-range.
            print(f"Warning: no LAS F1 for {label} in {', '.join(missing)}; plotting 0")
        all_vals.append([scores.get(p, 0) for p in time_periods])

    x = np.arange(len(time_periods))
    width = 0.25
    # Centre each group of bars on its tick (offsets -width, 0, +width for 3 models).
    offsets = (np.arange(len(models)) - (len(models) - 1) / 2) * width

    fig, ax = plt.subplots(figsize=(12, 6))
    bar_groups = [
        ax.bar(x + offset, vals, width, label=label, color=color)
        for offset, vals, (label, _, color) in zip(offsets, all_vals, models)
    ]

    ax.set_xlabel('Time Period', fontsize=12)
    ax.set_ylabel('LAS F1 Score', fontsize=12)
    ax.set_title('LAS F1 Scores by Time Period and Model', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(time_periods, rotation=45, ha='right')
    ax.legend(loc='upper left')
    # NOTE: this fixed range assumes scores are fractions (not percentages);
    # adjust it if the score files use a different scale.
    ax.set_ylim(0.4, 0.9)

    # Best score per period across all models; every model that reaches it
    # gets a bold label, so ties are all highlighted.
    best = [max(column) for column in zip(*all_vals)]
    for bars, vals in zip(bar_groups, all_vals):
        for bar, val, top in zip(bars, vals, best):
            height = bar.get_height()
            ax.annotate(f'{height:.3f}',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=8, rotation=90,
                        fontweight='bold' if val == top else 'normal')

    fig.tight_layout()
    # savefig does not create missing directories.
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / 'las_f1_scores_comparison.png', dpi=400)
    fig.savefig(out_dir / 'las_f1_scores_comparison.pdf')
    print("Saved plot to plot/las_f1_scores_comparison.png and .pdf")
    plt.show()
|
|
# Build and save the plot only when this file is executed directly, not when
# it is imported.
if __name__ == '__main__':
    main()
|
|