|
import pandas as pd |
|
import os |
|
import re |
|
import csv |
|
|
|
def extract_paren(annotation):
    """Extract every bracketed mention from an annotated string.

    Mentions look like ``[text: N]``, where ``N`` is a 1-3 digit cluster
    tag, and may be nested, e.g. ``[the wife of [John: 1]: 2]``.

    Every ``[`` in *annotation* (outer and nested alike) starts a mention
    that runs to its matching ``]``. Inside the collected text:

    * the brackets of nested mentions are dropped, and
    * each nested mention's trailing ``: N`` tag is removed.

    So the outer mention above becomes ``[the wife of John: 2]``.

    When the collected mention contains a ``: N`` tag, ``|K`` is appended.
    ``K`` is the number of spaces before the opening ``[`` minus the number
    of ``: N`` tags before it. Presumably this is the word offset of the
    mention in the tag-free text; that holds only if words are separated
    by single spaces.

    Mentions whose ``[`` is never closed are skipped.

    Args:
        annotation: the annotated model output.

    Returns:
        Mention strings such as ``"[John: 1]|5"`` (or ``"[big]"`` when no
        tag is present), in order of their opening bracket.
    """
    tag_re = re.compile(r": [0-9]{1,3}")
    trailing_tag_re = re.compile(r": [0-9]{1,3}$")

    ents = []
    for start, char in enumerate(annotation):
        if char != "[":
            continue

        ent = "["
        depth = 0  # how many nested mentions are currently open
        for inner in annotation[start + 1:]:
            if inner == "[":
                depth += 1
            elif inner == "]":
                if depth > 0:
                    depth -= 1
                    # A nested mention just closed: drop its cluster tag.
                    # A regex is used because ids have 1-3 digits (and the
                    # nested mention may have no tag at all), so a
                    # fixed-length chop would corrupt the text.
                    ent = trailing_tag_re.sub("", ent)
                else:
                    ent += "]"
                    if tag_re.search(ent):
                        prefix = annotation[:start]
                        # Each ": N" tag before this mention contributes one
                        # space that is not a word separator.
                        offset = prefix.count(" ") - len(tag_re.findall(prefix))
                        ent += "|" + str(offset)
                    ents.append(ent)
                    break
            else:
                ent += inner
    return ents
|
|
|
def create_clusters(ents):
    """Group entity mentions by their cluster id.

    Args:
        ents: mention strings in the form produced by ``extract_paren``,
            e.g. ``"[the wife of John: 13]|2"``.

    Returns:
        A plain ``dict`` mapping each integer cluster id to the list of its
        mentions, in input order. Square brackets and the cluster tag are
        removed from each mention, e.g. ``"the wife of John|2"``. A plain
        dict is kept because callers write its ``repr`` to a file.

    Mentions without a ``: N`` tag are printed to stdout and skipped.
    """
    clusters = {}
    for e in ents:
        # The cluster tag is the *last* ": N" in the mention (it sits just
        # before the closing bracket). Earlier matches are part of the
        # mention text itself, e.g. "[score: 10 to 5: 3]".
        tags = list(re.finditer(r": ([0-9]{1,3})", e))
        if not tags:
            print("OH NO:", e)
            print()
            continue

        tag = tags[-1]
        # Remove only this occurrence of the tag, not every identical substring.
        untagged = e[:tag.start()] + e[tag.end():]
        clean_e = untagged.replace("[", "").replace("]", "")
        clusters.setdefault(int(tag.group(1)), []).append(clean_e)
    return clusters
|
|
|
# Read model outputs from results.csv, cluster the annotated mentions of each
# output, and write input / output / clusters triples to cluster_results.csv.
headers = ["input", "model_output", "model_output_clusters"]

df = pd.read_csv("results.csv")

rows = []
for text, annotation in zip(df["input"], df["model_output"]):
    # pandas reads empty cells as NaN (a float), so only string outputs can
    # be parsed; other rows are not written to the output file.
    if not isinstance(annotation, str):
        continue
    # extract_paren yields [] when there are no mentions and
    # create_clusters([]) is {}, so no special case is needed.
    ann_clusters = create_clusters(extract_paren(annotation))
    rows.append([text, annotation, str(ann_clusters)])

# newline="" lets the csv module control line endings itself (otherwise extra
# blank lines appear on Windows); the `with` block closes the file even on error.
with open("cluster_results.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(headers)
    writer.writerows(rows)