Aidan Phillips
committed on
Commit
·
dc76b04
1
Parent(s):
b837a10
sussy math works with default sentence
Browse files- categories/fluency.py +74 -50
- requirements.txt +2 -1
- scorer.ipynb +33 -20
categories/fluency.py
CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer, AutoModelForMaskedLM
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
import spacy
|
|
|
6 |
|
7 |
tool = language_tool_python.LanguageTool('en-US')
|
8 |
model_name="distilbert-base-multilingual-cased"
|
@@ -12,7 +13,10 @@ model.eval()
|
|
12 |
|
13 |
nlp = spacy.load("en_core_web_sm")
|
14 |
|
15 |
-
def
|
|
|
|
|
|
|
16 |
"""
|
17 |
We want to return
|
18 |
{
|
@@ -26,67 +30,87 @@ def pseudo_perplexity(text, max_len=128):
|
|
26 |
]
|
27 |
}
|
28 |
"""
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
32 |
-
|
|
|
33 |
|
34 |
loss_values = []
|
|
|
|
|
|
|
|
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
|
40 |
with torch.no_grad():
|
41 |
-
outputs = model(
|
42 |
-
logits = outputs.logits[0
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
current_end = i + 1
|
62 |
-
curr_loss = loss
|
63 |
-
else:
|
64 |
-
if current_end - current_start > max_length:
|
65 |
-
longest_start, longest_end = current_start, current_end
|
66 |
-
max_length = current_end - current_start
|
67 |
-
current_start, current_end = 0, 0
|
68 |
-
|
69 |
-
if current_end - current_start > max_length: # Check the last sequence
|
70 |
-
longest_start, longest_end = current_start, current_end
|
71 |
-
|
72 |
-
longest_sequence = (longest_start, longest_end)
|
73 |
|
74 |
-
|
|
|
|
|
75 |
|
76 |
res = {
|
77 |
-
"score": __fluency_score_from_ppl(
|
78 |
-
"errors":
|
79 |
-
{
|
80 |
-
"start": longest_sequence[0],
|
81 |
-
"end": longest_sequence[1],
|
82 |
-
"message": f"Perplexity above threshold: {curr_loss}"
|
83 |
-
}
|
84 |
-
]
|
85 |
}
|
86 |
|
87 |
return res
|
88 |
|
89 |
-
def __fluency_score_from_ppl(ppl, midpoint=
|
90 |
"""
|
91 |
Use a logistic function to map perplexity to 0–100.
|
92 |
Midpoint is the PPL where score is 50.
|
@@ -135,12 +159,12 @@ def grammar_errors(text) -> tuple[int, list[str]]:
|
|
135 |
|
136 |
return res
|
137 |
|
138 |
-
def __grammar_score_from_prob(error_ratio
|
139 |
"""
|
140 |
Transform the number of errors divided by words into a score from 0 to 100.
|
141 |
Steepness controls how quickly the score drops as errors increase.
|
142 |
"""
|
143 |
-
score = 100
|
144 |
return round(score, 2)
|
145 |
|
146 |
|
|
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
import spacy
|
6 |
+
import wordfreq
|
7 |
|
8 |
tool = language_tool_python.LanguageTool('en-US')
|
9 |
model_name="distilbert-base-multilingual-cased"
|
|
|
13 |
|
14 |
nlp = spacy.load("en_core_web_sm")
|
15 |
|
16 |
+
def __get_word_pr_score(word, lang="en") -> float:
    """
    Return the negative natural log of a word's corpus frequency.

    Args:
        word: The word to look up.
        lang: Language code passed through to wordfreq (defaults to "en").

    Returns:
        A single scalar, -log(frequency + 1e-12). Rarer words give larger
        values; the result is one number, not a list.
    """
    # The small epsilon keeps the log finite if the reported frequency is 0.
    return -np.log(wordfreq.word_frequency(word, lang) + 1e-12)
|
18 |
+
|
19 |
+
def pseudo_perplexity(text, threshold=20, max_len=128):
|
20 |
"""
|
21 |
We want to return
|
22 |
{
|
|
|
30 |
]
|
31 |
}
|
32 |
"""
|
33 |
+
encoding = tokenizer(text, return_tensors="pt", return_offsets_mapping=True)
|
34 |
+
input_ids = encoding["input_ids"][0]
|
35 |
+
print(input_ids)
|
36 |
+
offset_mapping = encoding["offset_mapping"][0]
|
37 |
+
print(offset_mapping)
|
38 |
+
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
39 |
+
|
40 |
+
# Group token indices by word based on offset mapping
|
41 |
+
word_groups = []
|
42 |
+
current_group = []
|
43 |
+
|
44 |
+
prev_end = None
|
45 |
+
|
46 |
+
for i, (start, end) in enumerate(offset_mapping):
|
47 |
+
if input_ids[i] in tokenizer.all_special_ids:
|
48 |
+
continue # skip special tokens like [CLS] and [SEP]
|
49 |
+
|
50 |
+
if prev_end is not None and start > prev_end:
|
51 |
+
# Word boundary detected → start new group
|
52 |
+
word_groups.append(current_group)
|
53 |
+
current_group = [i]
|
54 |
+
else:
|
55 |
+
current_group.append(i)
|
56 |
+
|
57 |
+
prev_end = end
|
58 |
|
59 |
+
# Append final group
|
60 |
+
if current_group:
|
61 |
+
word_groups.append(current_group)
|
62 |
|
63 |
loss_values = []
|
64 |
+
tok_loss = []
|
65 |
+
for group in word_groups:
|
66 |
+
if group[0] == 0 or group[-1] == len(input_ids) - 1:
|
67 |
+
continue # skip [CLS] and [SEP]
|
68 |
|
69 |
+
masked = input_ids.clone()
|
70 |
+
for i in group:
|
71 |
+
masked[i] = tokenizer.mask_token_id
|
72 |
|
73 |
with torch.no_grad():
|
74 |
+
outputs = model(masked.unsqueeze(0))
|
75 |
+
logits = outputs.logits[0]
|
76 |
+
|
77 |
+
log_probs = []
|
78 |
+
for i in group:
|
79 |
+
probs = torch.softmax(logits[i], dim=-1)
|
80 |
+
true_token_id = input_ids[i].item()
|
81 |
+
prob = probs[true_token_id].item()
|
82 |
+
log_probs.append(np.log(prob + 1e-12))
|
83 |
+
tok_loss.append(-np.log(prob + 1e-12))
|
84 |
+
|
85 |
+
word_loss = -np.sum(log_probs) / len(log_probs)
|
86 |
+
word = tokenizer.decode(input_ids[group[0]])
|
87 |
+
word_loss -= 0.6 * __get_word_pr_score(word)
|
88 |
+
loss_values.append(word_loss)
|
89 |
|
90 |
+
print(loss_values)
|
91 |
+
|
92 |
+
errors = []
|
93 |
+
for i, l in enumerate(loss_values):
|
94 |
+
if l < threshold:
|
95 |
+
continue
|
96 |
+
errors.append({
|
97 |
+
"start": i,
|
98 |
+
"end": i,
|
99 |
+
"message": f"Perplexity {l} over threshold {threshold}"
|
100 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
+
print(tok_loss)
|
103 |
+
s_ppl = np.mean(tok_loss)
|
104 |
+
print(s_ppl)
|
105 |
|
106 |
res = {
|
107 |
+
"score": __fluency_score_from_ppl(s_ppl),
|
108 |
+
"errors": errors
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
}
|
110 |
|
111 |
return res
|
112 |
|
113 |
+
def __fluency_score_from_ppl(ppl, midpoint=8, steepness=0.3):
|
114 |
"""
|
115 |
Use a logistic function to map perplexity to 0–100.
|
116 |
Midpoint is the PPL where score is 50.
|
|
|
159 |
|
160 |
return res
|
161 |
|
162 |
+
def __grammar_score_from_prob(error_ratio):
    """
    Transform the number of errors divided by words into a score from 0 to 100.

    The mapping is linear: a ratio of 0 scores 100 and a ratio of 1 scores 0.

    Args:
        error_ratio: Number of errors divided by number of words.

    Returns:
        The score rounded to two decimal places, clamped to the range 0-100.
    """
    score = 100 * (1 - error_ratio)
    # A ratio above 1 (more errors than words) or below 0 would push the
    # linear formula outside the documented 0-100 range, so clamp it.
    score = max(0, min(100, score))
    return round(score, 2)
|
169 |
|
170 |
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
language_tool_python
|
2 |
transformers
|
3 |
-
torch
|
|
|
|
1 |
language_tool_python
|
2 |
transformers
|
3 |
+
torch
|
4 |
+
wordfreq
|
scorer.ipynb
CHANGED
@@ -4,16 +4,7 @@
|
|
4 |
"cell_type": "code",
|
5 |
"execution_count": 1,
|
6 |
"metadata": {},
|
7 |
-
"outputs": [
|
8 |
-
{
|
9 |
-
"name": "stderr",
|
10 |
-
"output_type": "stream",
|
11 |
-
"text": [
|
12 |
-
"/opt/anaconda3/envs/teach-bs/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
13 |
-
" from .autonotebook import tqdm as notebook_tqdm\n"
|
14 |
-
]
|
15 |
-
}
|
16 |
-
],
|
17 |
"source": [
|
18 |
"from categories.fluency import *"
|
19 |
]
|
@@ -27,7 +18,25 @@
|
|
27 |
"name": "stdout",
|
28 |
"output_type": "stream",
|
29 |
"text": [
|
30 |
-
"Sentence: The
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
]
|
32 |
}
|
33 |
],
|
@@ -40,7 +49,7 @@
|
|
40 |
"print(\"Sentence:\", s) # Print the input sentence\n",
|
41 |
"\n",
|
42 |
"err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
|
43 |
-
"flu = pseudo_perplexity(s) # Call the function to execute the fluency checking"
|
44 |
]
|
45 |
},
|
46 |
{
|
@@ -52,8 +61,12 @@
|
|
52 |
"name": "stdout",
|
53 |
"output_type": "stream",
|
54 |
"text": [
|
55 |
-
"
|
56 |
-
"
|
|
|
|
|
|
|
|
|
57 |
]
|
58 |
}
|
59 |
],
|
@@ -62,26 +75,26 @@
|
|
62 |
"\n",
|
63 |
"for e in combined_err:\n",
|
64 |
" substr = \" \".join(s.split(\" \")[e[\"start\"]:e[\"end\"]+1])\n",
|
65 |
-
" print(f\"{e['message']}: {substr}\") # Print the error messages\n"
|
66 |
-
"\n",
|
67 |
-
"print(combined_err)\n"
|
68 |
]
|
69 |
},
|
70 |
{
|
71 |
"cell_type": "code",
|
72 |
-
"execution_count":
|
73 |
"metadata": {},
|
74 |
"outputs": [
|
75 |
{
|
76 |
"name": "stdout",
|
77 |
"output_type": "stream",
|
78 |
"text": [
|
79 |
-
"
|
|
|
80 |
]
|
81 |
}
|
82 |
],
|
83 |
"source": [
|
84 |
-
"fluency_score = 0.
|
|
|
85 |
"print(\"Fluency Score:\", fluency_score) # Print the fluency score"
|
86 |
]
|
87 |
}
|
|
|
4 |
"cell_type": "code",
|
5 |
"execution_count": 1,
|
6 |
"metadata": {},
|
7 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
"source": [
|
9 |
"from categories.fluency import *"
|
10 |
]
|
|
|
18 |
"name": "stdout",
|
19 |
"output_type": "stream",
|
20 |
"text": [
|
21 |
+
"Sentence: The cat sat the quickly up apples banana.\n",
|
22 |
+
"tensor([ 101, 10117, 41163, 20694, 10105, 23590, 10741, 72894, 11268, 99304,\n",
|
23 |
+
" 10219, 119, 102])\n",
|
24 |
+
"tensor([[ 0, 0],\n",
|
25 |
+
" [ 0, 3],\n",
|
26 |
+
" [ 4, 7],\n",
|
27 |
+
" [ 8, 11],\n",
|
28 |
+
" [12, 15],\n",
|
29 |
+
" [16, 23],\n",
|
30 |
+
" [24, 26],\n",
|
31 |
+
" [27, 30],\n",
|
32 |
+
" [30, 33],\n",
|
33 |
+
" [34, 38],\n",
|
34 |
+
" [38, 40],\n",
|
35 |
+
" [40, 41],\n",
|
36 |
+
" [ 0, 0]])\n",
|
37 |
+
"[np.float64(0.00905743383887514), np.float64(1.1257066968185931), np.float64(4.8056646935577145), np.float64(4.473408069089179), np.float64(4.732453441503642), np.float64(3.028744414819041), np.float64(5.1115574262487735), np.float64(-0.6523823890571343)]\n",
|
38 |
+
"[np.float64(1.7636628003080927), np.float64(6.955413759407024), np.float64(10.828562153345375), np.float64(6.228013435558396), np.float64(10.258657658689351), np.float64(6.635744767229443), np.float64(11.163667119285972), np.float64(10.499412826924114), np.float64(11.96113847381264), np.float64(10.010973250156082), np.float64(2.470404176100153)]\n",
|
39 |
+
"0.5208035409471965\n"
|
40 |
]
|
41 |
}
|
42 |
],
|
|
|
49 |
"print(\"Sentence:\", s) # Print the input sentence\n",
|
50 |
"\n",
|
51 |
"err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
|
52 |
+
"flu = pseudo_perplexity(s, threshold=2.5) # Call the function to execute the fluency checking"
|
53 |
]
|
54 |
},
|
55 |
{
|
|
|
61 |
"name": "stdout",
|
62 |
"output_type": "stream",
|
63 |
"text": [
|
64 |
+
"An apostrophe may be missing.: apples banana.\n",
|
65 |
+
"Perplexity 4.8056646935577145 over threshold 2.5: sat\n",
|
66 |
+
"Perplexity 4.473408069089179 over threshold 2.5: the\n",
|
67 |
+
"Perplexity 4.732453441503642 over threshold 2.5: quickly\n",
|
68 |
+
"Perplexity 3.028744414819041 over threshold 2.5: up\n",
|
69 |
+
"Perplexity 5.1115574262487735 over threshold 2.5: apples\n"
|
70 |
]
|
71 |
}
|
72 |
],
|
|
|
75 |
"\n",
|
76 |
"for e in combined_err:\n",
|
77 |
" substr = \" \".join(s.split(\" \")[e[\"start\"]:e[\"end\"]+1])\n",
|
78 |
+
" print(f\"{e['message']}: {substr}\") # Print the error messages\n"
|
|
|
|
|
79 |
]
|
80 |
},
|
81 |
{
|
82 |
"cell_type": "code",
|
83 |
+
"execution_count": null,
|
84 |
"metadata": {},
|
85 |
"outputs": [
|
86 |
{
|
87 |
"name": "stdout",
|
88 |
"output_type": "stream",
|
89 |
"text": [
|
90 |
+
"87.5 99.71\n",
|
91 |
+
"Fluency Score: 92.384\n"
|
92 |
]
|
93 |
}
|
94 |
],
|
95 |
"source": [
|
96 |
+
"fluency_score = 0.7 * err[\"score\"] + 0.3 * flu[\"score\"] # Calculate the fluency score\n",
|
97 |
+
"print(err[\"score\"], flu[\"score\"]) # Print the individual scores\n",
|
98 |
"print(\"Fluency Score:\", fluency_score) # Print the fluency score"
|
99 |
]
|
100 |
}
|