rmmhicke committed on
Commit 0d1a7d5
1 Parent(s): 9b906a6

Upload 2 files

Files changed (2)
  1. get_annotations.py +51 -0
  2. get_ent_clusters.py +82 -0
get_annotations.py ADDED
@@ -0,0 +1,51 @@
import pandas as pd
import csv

from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = 't5-literary-coreference'
device = 'cuda'

print("Loading in data")

# The input CSV is expected to provide "input" and "output" columns
df = pd.read_csv('example_input.csv')
df = df.sample(frac=1)  # Shuffle dataframe contents

to_annotate = Dataset.from_pandas(df)

speech_excerpts = DatasetDict({"annotate": to_annotate})

print("Loading models")
# Change model_max_length to fit your data
tokenizer = AutoTokenizer.from_pretrained("t5-3b", model_max_length=500)

def preprocess_function(examples, input_text="input", output_text="output"):
    # Tokenize inputs and targets, truncating both to the model's maximum length
    model_inputs = tokenizer(examples[input_text], max_length=500, truncation=True)
    targets = tokenizer(examples[output_text], max_length=500, truncation=True)
    model_inputs["labels"] = targets["input_ids"]
    return model_inputs

# Tokenized copy of the dataset (not used below; generation re-tokenizes each excerpt)
tokenized_speech_excerpts = speech_excerpts.map(preprocess_function, batched=True)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device=device)

print("Begin creating annotations")
header = ["input", "model_output"]
rows = []

# Generate a coreference annotation for each excerpt
for item in speech_excerpts["annotate"]:
    input_ids = tokenizer(item["input"], return_tensors="pt").input_ids
    result = model.generate(input_ids.to(device=device), max_length=500)
    rows.append([item["input"], tokenizer.decode(result[0], skip_special_tokens=True)])

# Write the inputs and model outputs to results.csv
with open("results.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerows(rows)

print("Finished")
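
As written, get_annotations.py expects example_input.csv to contain both an "input" and an "output" column, because preprocess_function tokenizes both; only "input" is actually used at generation time, so "output" can hold a placeholder (it should be non-empty, or pandas will read it as NaN and the tokenizer call will fail). A minimal, hypothetical sketch of building such a file; the excerpt text and placeholder output are invented for illustration:

import pandas as pd

# Hypothetical sketch: the "input"/"output" column names come from get_annotations.py;
# the excerpt text and placeholder output below are invented for illustration.
examples = pd.DataFrame({
    "input": [
        '"I cannot believe it," said Elizabeth, turning to her sister.',
    ],
    # Only "input" is used for generation, but preprocess_function also tokenizes
    # "output", so the column must exist and be non-empty.
    "output": ["n/a"],
})
examples.to_csv("example_input.csv", index=False)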
get_ent_clusters.py ADDED
@@ -0,0 +1,82 @@
import pandas as pd
import os
import re
import csv

def extract_paren(annotation):
    # Pull each bracketed entity span out of an annotated string
    ents = []
    for i in range(len(annotation)):
        if annotation[i] == "[":
            ent = "["
            open_paren = 0

            for j in range(i + 1, len(annotation)):
                if annotation[j] == "[":
                    open_paren += 1
                elif annotation[j] == "]":
                    if open_paren > 0:
                        open_paren -= 1
                        ent = ent[:len(ent) - 3]
                    else:
                        ent += "]"
                        digit = re.search(r": [0-9]{1,3}", ent)

                        if digit:
                            # Approximate word index of the entity: spaces seen so far,
                            # minus the spaces introduced by earlier ": N" annotations
                            matches = re.findall(r": [0-9]{1,3}", annotation[:i])
                            str_index = annotation[:i].count(" ") - len(matches)
                            ent += "|" + str(str_index)
                            ents.append(ent)
                        break
                else:
                    ent += annotation[j]
    return ents

def create_clusters(ents):
    # Group extracted entities by their coreference cluster number
    clusters = {}

    for e in ents:
        digit_ann = re.search(r": [0-9]{1,3}", e)
        if digit_ann:
            clean_e = e.replace("[", "").replace("]", "").replace(digit_ann.group(), "")

            digit = re.search(r"[0-9]{1,3}", digit_ann.group())
            digit = int(digit.group())

            if digit not in clusters:
                clusters[digit] = []

            clusters[digit].append(clean_e)
        else:
            print("OH NO:", e)
            print()

    return clusters

headers = ["input", "model_output", "model_output_clusters"]

df = pd.read_csv("results.csv")

rows = []
for index, row in df.iterrows():
    annotation = row["model_output"]

    # Skip rows where the model produced no usable (string) output
    if isinstance(annotation, str):
        ann_ents = extract_paren(annotation)

        if ann_ents:
            ann_clusters = create_clusters(ann_ents)
        else:
            ann_clusters = {}

        rows.append([row["input"], annotation, str(ann_clusters)])

with open("cluster_results.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(headers)
    writer.writerows(rows)
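
To make the bracket-parsing logic concrete, here is a small, hypothetical usage sketch. The "[mention: N]" annotation format (with N as a coreference cluster id) is inferred from the regular expressions in get_ent_clusters.py, the example sentence is invented, and extract_paren and create_clusters are assumed to be in scope (e.g., run in the same module):

# Hypothetical annotation in the format implied by the regexes above:
# each mention is wrapped as "[text: N]", where N is its cluster id.
annotation = "[Elizabeth: 0] walked with [her sister: 1]."

ents = extract_paren(annotation)
# ents == ['[Elizabeth: 0]|0', '[her sister: 1]|3']
# (the "|N" suffix is the approximate word index added by extract_paren)

clusters = create_clusters(ents)
# clusters == {0: ['Elizabeth|0'], 1: ['her sister|3']}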