Aidan Phillips committed on
Commit
f5893dd
·
1 Parent(s): d2375b8

clean up fluency code

Browse files
Files changed (2) hide show
  1. categories/fluency.py +100 -46
  2. scorer.ipynb +11 -12
categories/fluency.py CHANGED
@@ -5,86 +5,126 @@ import numpy as np
5
  import spacy
6
  import wordfreq
7
 
 
 
 
 
8
  tool = language_tool_python.LanguageTool('en-US')
 
 
9
  model_name="distilbert-base-multilingual-cased"
10
- tokenizer = AutoTokenizer.from_pretrained(model_name)
11
  model = AutoModelForMaskedLM.from_pretrained(model_name)
12
  model.eval()
 
13
 
 
14
  nlp = spacy.load("en_core_web_sm")
15
 
16
- def __get_word_pr_score(word, lang="en") -> list[float]:
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  return -np.log(wordfreq.word_frequency(word, lang) + 1e-12)
18
 
19
- def pseudo_perplexity(text, threshold=20, max_len=128):
20
  """
21
- We want to return
22
- {
23
- "score": normalized value from 0 to 100,
24
- "errors": [
25
- {
26
- "start": word index,
27
- "end": word index,
28
- "message": "error message"
29
- }
30
- ]
31
- }
32
  """
33
- encoding = tokenizer(text, return_tensors="pt", return_offsets_mapping=True)
34
- input_ids = encoding["input_ids"][0]
35
- # print(input_ids)
36
- offset_mapping = encoding["offset_mapping"][0]
37
- # print(offset_mapping)
38
- tokens = tokenizer.convert_ids_to_tokens(input_ids)
39
-
40
- # Group token indices by word based on offset mapping
41
- word_groups = []
42
  current_group = []
43
-
44
  prev_end = None
45
-
46
  for i, (start, end) in enumerate(offset_mapping):
47
  if input_ids[i] in tokenizer.all_special_ids:
48
  continue # skip special tokens like [CLS] and [SEP]
49
-
50
  if prev_end is not None and start > prev_end:
51
  # Word boundary detected → start new group
52
- word_groups.append(current_group)
53
  current_group = [i]
54
  else:
55
  current_group.append(i)
56
-
57
  prev_end = end
58
-
59
  # Append final group
60
  if current_group:
61
- word_groups.append(current_group)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
 
 
 
 
 
 
 
63
  loss_values = []
64
  for group in word_groups:
 
65
  if group[0] == 0 or group[-1] == len(input_ids) - 1:
66
- continue # skip [CLS] and [SEP]
67
 
 
68
  masked = input_ids.clone()
69
  for i in group:
70
  masked[i] = tokenizer.mask_token_id
71
 
 
72
  with torch.no_grad():
73
  outputs = model(masked.unsqueeze(0))
74
  logits = outputs.logits[0]
75
 
76
  log_probs = []
77
  for i in group:
 
78
  probs = torch.softmax(logits[i], dim=-1)
79
  true_token_id = input_ids[i].item()
80
  prob = probs[true_token_id].item()
 
81
  log_probs.append(np.log(prob + 1e-12))
82
 
 
83
  word_loss = -np.sum(log_probs) / len(log_probs)
 
84
  word = tokenizer.decode(input_ids[group[0]])
85
- word_loss -= 0.6 * __get_word_pr_score(word)
86
  loss_values.append(word_loss)
87
 
 
 
 
88
  errors = []
89
  for i, l in enumerate(loss_values):
90
  if l < threshold:
@@ -92,36 +132,43 @@ def pseudo_perplexity(text, threshold=20, max_len=128):
92
  errors.append({
93
  "start": i,
94
  "end": i,
95
- "message": f"Perplexity {l} over threshold {threshold}"
96
  })
97
 
98
- error_rate = len(errors) / len(loss_values)
99
-
100
  res = {
101
- "score": __grammar_score_from_prob(error_rate),
102
  "errors": errors
103
  }
104
 
105
  return res
106
 
107
- def __fluency_score_from_ppl(ppl, midpoint=8, steepness=0.3):
108
  """
109
- Use a logistic function to map perplexity to 0–100.
110
- Midpoint is the PPL where score is 50.
111
- Steepness controls curve sharpness.
 
 
 
 
 
 
 
 
112
  """
113
- score = 100 / (1 + np.exp(steepness * (ppl - midpoint)))
114
  return round(score, 2)
115
 
116
  def grammar_errors(text) -> tuple[int, list[str]]:
117
  """
 
118
 
119
- Returns
120
- int: number of grammar errors
121
- list: grammar errors
122
- tuple: (start, end, error message)
 
123
  """
124
-
125
  matches = tool.check(text)
126
 
127
  r = []
@@ -221,3 +268,10 @@ def __check_structural_grammar(text):
221
  })
222
 
223
  return issues
 
 
 
 
 
 
 
 
5
  import spacy
6
  import wordfreq
7
 
8
+ # set up global variables on import (bad practice, but whatever)
9
+ #--------------------------------------------------------------
10
+
11
+ # grammar checker
12
  tool = language_tool_python.LanguageTool('en-US')
13
+
14
+ # masked language model and tokenizer from huggingface
15
  model_name="distilbert-base-multilingual-cased"
 
16
  model = AutoModelForMaskedLM.from_pretrained(model_name)
17
  model.eval()
18
+ tokenizer = AutoTokenizer.from_pretrained(model_name) # tokenizer
19
 
20
+ # spacy model for parsing
21
  nlp = spacy.load("en_core_web_sm")
22
 
23
+ def __get_rarity(word, lang="en") -> float:
24
+ """
25
+ Returns the rarity of a word in the given language. word_freq returns a value
26
+ between 0 and 1, where 1 is the most common word. Therefore, taking the log results
27
+ in a value between 0 (log 1 = 0) and -27.63 (log 1e-12). We then negate it so super
28
+ rare words have a high score and common words have a low score.
29
+
30
+ Parameters:
31
+ word (str): The word to check.
32
+ lang (str): The language to check. Default is "en".
33
+
34
+ Returns:
35
+ float: The rarity of the word.
36
+ """
37
  return -np.log(wordfreq.word_frequency(word, lang) + 1e-12)
38
 
39
+ def __produce_groupings(offset_mapping, input_ids):
40
  """
41
+ Produce groupings of tokens that are part of the same word.
42
+
43
+ Parameters:
44
+ offset_mapping (list): The offset mapping of the tokens.
45
+ input_ids (list): The input ids of the tokens.
46
+
47
+ Returns:
48
+ list: A list of groupings of tokens.
 
 
 
49
  """
50
+ # Produce groupings of tokens that are part of the same word
51
+ res = []
 
 
 
 
 
 
 
52
  current_group = []
 
53
  prev_end = None
 
54
  for i, (start, end) in enumerate(offset_mapping):
55
  if input_ids[i] in tokenizer.all_special_ids:
56
  continue # skip special tokens like [CLS] and [SEP]
 
57
  if prev_end is not None and start > prev_end:
58
  # Word boundary detected → start new group
59
+ res.append(current_group)
60
  current_group = [i]
61
  else:
62
  current_group.append(i)
 
63
  prev_end = end
 
64
  # Append final group
65
  if current_group:
66
+ res.append(current_group)
67
+
68
+ return res
69
+
70
+ def pseudo_perplexity(text, threshold=4, max_len=128):
71
+ """
72
+ Calculate the pseudo-perplexity of a text using a masked language model. Return all
73
+ words that exceed a threshold of "adjusted awkwardness". The threshold is a measure
74
+ in terms of log probability of the word.
75
+
76
+ Parameters:
77
+ text (str): The text to check.
78
+ threshold (float): The threshold for awkwardness. Default is 4.
79
+ max_len (int): The maximum length of the text. Default is 128.
80
+
81
+ Returns:
82
+ dict: A dictionary containing the score and errors.
83
+ """
84
 
85
+ # Tokenize the text and produce groupings
86
+ encoding = tokenizer(text, return_tensors="pt", return_offsets_mapping=True)
87
+ input_ids = encoding["input_ids"][0]
88
+ offset_mapping = encoding["offset_mapping"][0]
89
+ tokens = tokenizer.convert_ids_to_tokens(input_ids)
90
+ word_groups = __produce_groupings(offset_mapping, input_ids)
91
+
92
+ # Calculate the loss for each word group
93
  loss_values = []
94
  for group in word_groups:
95
+ # Skip special tokens (CLS and SEP)
96
  if group[0] == 0 or group[-1] == len(input_ids) - 1:
97
+ continue
98
 
99
+ # Mask the word group
100
  masked = input_ids.clone()
101
  for i in group:
102
  masked[i] = tokenizer.mask_token_id
103
 
104
+ # Get the model output distribution
105
  with torch.no_grad():
106
  outputs = model(masked.unsqueeze(0))
107
  logits = outputs.logits[0]
108
 
109
  log_probs = []
110
  for i in group:
111
+ # Get the probability of the true token
112
  probs = torch.softmax(logits[i], dim=-1)
113
  true_token_id = input_ids[i].item()
114
  prob = probs[true_token_id].item()
115
+ # Append the loss of the true token
116
  log_probs.append(np.log(prob + 1e-12))
117
 
118
+ # Calculate the loss for the entire word group
119
  word_loss = -np.sum(log_probs) / len(log_probs)
120
+ # Adjust the loss based on the rarity of the word
121
  word = tokenizer.decode(input_ids[group[0]])
122
+ word_loss -= 0.6 * __get_rarity(word) # subtract rarity (rare words reduce loss)
123
  loss_values.append(word_loss)
124
 
125
+ # Structure the results for output
126
+ average_loss = np.mean(loss_values)
127
+
128
  errors = []
129
  for i, l in enumerate(loss_values):
130
  if l < threshold:
 
132
  errors.append({
133
  "start": i,
134
  "end": i,
135
+ "message": f"Adjusted liklihood {l} over threshold {threshold}"
136
  })
137
 
 
 
138
  res = {
139
+ "score": __fluency_score(average_loss),
140
  "errors": errors
141
  }
142
 
143
  return res
144
 
145
+ def __fluency_score(loss, midpoint=5, steepness=0.3):
146
  """
147
+ Transform the loss into a score from 0 to 100. Steepness controls how quickly the
148
+ score drops as loss increases and midpoint controls the loss at which the score is
149
+ 50.
150
+
151
+ Parameters:
152
+ loss (float): The loss to transform.
153
+ midpoint (float): The loss at which the score is 50. Default is 5.
154
+ steepness (float): The steepness of the curve. Default is 0.3.
155
+
156
+ Returns:
157
+ float: The score from 0 to 100.
158
  """
159
+ score = 100 / (1 + np.exp(steepness * (loss - midpoint)))
160
  return round(score, 2)
161
 
162
  def grammar_errors(text) -> tuple[int, list[str]]:
163
  """
164
+ Check the grammar of a text using a grammar checker and a structural grammar check.
165
 
166
+ Parameters:
167
+ text (str): The text to check.
168
+
169
+ Returns:
170
+ dict: A dictionary containing the score and errors.
171
  """
 
172
  matches = tool.check(text)
173
 
174
  r = []
 
268
  })
269
 
270
  return issues
271
+
272
+
273
+ def main():
274
+ pass
275
+
276
+ if __name__ == "__main__":
277
+ main()
scorer.ipynb CHANGED
@@ -11,14 +11,14 @@
11
  },
12
  {
13
  "cell_type": "code",
14
- "execution_count": 20,
15
  "metadata": {},
16
  "outputs": [
17
  {
18
  "name": "stdout",
19
  "output_type": "stream",
20
  "text": [
21
- "Sentence: The cat sat the quickly up apples banana.\n"
22
  ]
23
  }
24
  ],
@@ -31,23 +31,22 @@
31
  "print(\"Sentence:\", s) # Print the input sentence\n",
32
  "\n",
33
  "err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
34
- "flu = pseudo_perplexity(s, threshold=3.5) # Call the function to execute the fluency checking"
35
  ]
36
  },
37
  {
38
  "cell_type": "code",
39
- "execution_count": 21,
40
  "metadata": {},
41
  "outputs": [
42
  {
43
  "name": "stdout",
44
  "output_type": "stream",
45
  "text": [
46
- "An apostrophe may be missing.: apples banana.\n",
47
- "Perplexity 4.8056646935577145 over threshold 3.5: sat\n",
48
- "Perplexity 4.473408069089179 over threshold 3.5: the\n",
49
- "Perplexity 4.732453441503642 over threshold 3.5: quickly\n",
50
- "Perplexity 5.1115574262487735 over threshold 3.5: apples\n"
51
  ]
52
  }
53
  ],
@@ -61,15 +60,15 @@
61
  },
62
  {
63
  "cell_type": "code",
64
- "execution_count": 22,
65
  "metadata": {},
66
  "outputs": [
67
  {
68
  "name": "stdout",
69
  "output_type": "stream",
70
  "text": [
71
- "87.5 50.0\n",
72
- "Fluency Score: 68.75\n"
73
  ]
74
  }
75
  ],
 
11
  },
12
  {
13
  "cell_type": "code",
14
+ "execution_count": 11,
15
  "metadata": {},
16
  "outputs": [
17
  {
18
  "name": "stdout",
19
  "output_type": "stream",
20
  "text": [
21
+ "Sentence: caveman speak weird few word good\n"
22
  ]
23
  }
24
  ],
 
31
  "print(\"Sentence:\", s) # Print the input sentence\n",
32
  "\n",
33
  "err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
34
+ "flu = pseudo_perplexity(s, threshold=3.25) # Call the function to execute the fluency checking"
35
  ]
36
  },
37
  {
38
  "cell_type": "code",
39
+ "execution_count": 12,
40
  "metadata": {},
41
  "outputs": [
42
  {
43
  "name": "stdout",
44
  "output_type": "stream",
45
  "text": [
46
+ "This sentence does not start with an uppercase letter.: caveman speak\n",
47
+ "Perplexity 4.2750282429106585 over threshold 3.25: caveman\n",
48
+ "Perplexity 5.191700905668536 over threshold 3.25: few\n",
49
+ "Perplexity 3.8370066187600944 over threshold 3.25: good\n"
 
50
  ]
51
  }
52
  ],
 
60
  },
61
  {
62
  "cell_type": "code",
63
+ "execution_count": 10,
64
  "metadata": {},
65
  "outputs": [
66
  {
67
  "name": "stdout",
68
  "output_type": "stream",
69
  "text": [
70
+ "100.0 80.14\n",
71
+ "Fluency Score: 90.07\n"
72
  ]
73
  }
74
  ],