Aidan Phillips commited on
Commit
dc76b04
·
1 Parent(s): b837a10

sussy math works with default sentence

Browse files
Files changed (3) hide show
  1. categories/fluency.py +74 -50
  2. requirements.txt +2 -1
  3. scorer.ipynb +33 -20
categories/fluency.py CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer, AutoModelForMaskedLM
3
  import torch
4
  import numpy as np
5
  import spacy
 
6
 
7
  tool = language_tool_python.LanguageTool('en-US')
8
  model_name="distilbert-base-multilingual-cased"
@@ -12,7 +13,10 @@ model.eval()
12
 
13
  nlp = spacy.load("en_core_web_sm")
14
 
15
- def pseudo_perplexity(text, max_len=128):
 
 
 
16
  """
17
  We want to return
18
  {
@@ -26,67 +30,87 @@ def pseudo_perplexity(text, max_len=128):
26
  ]
27
  }
28
  """
29
- input_ids = tokenizer.encode(text, return_tensors="pt")[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- if len(input_ids) > max_len:
32
- raise ValueError(f"Input too long for model (>{max_len} tokens).")
 
33
 
34
  loss_values = []
 
 
 
 
35
 
36
- for i in range(1, len(input_ids) - 1): # skip [CLS] and [SEP]
37
- masked_input = input_ids.clone()
38
- masked_input[i] = tokenizer.mask_token_id
39
 
40
  with torch.no_grad():
41
- outputs = model(masked_input.unsqueeze(0))
42
- logits = outputs.logits[0, i]
43
- probs = torch.softmax(logits, dim=-1)
44
-
45
- true_token_id = input_ids[i].item()
46
- prob_true_token = probs[true_token_id].item()
47
- log_prob = np.log(prob_true_token + 1e-12)
48
- loss_values.append(-log_prob)
 
 
 
 
 
 
 
49
 
50
- # get longest sequence of tokens with perplexity over some threshold
51
- threshold = 12 # Define a perplexity threshold
52
- longest_start, longest_end = 0, 0
53
- current_start, current_end = 0, 0
54
- max_length = 0
55
- curr_loss = 0
56
-
57
- for i, loss in enumerate(loss_values):
58
- if loss > threshold:
59
- if current_start == current_end: # Start a new sequence
60
- current_start = i
61
- current_end = i + 1
62
- curr_loss = loss
63
- else:
64
- if current_end - current_start > max_length:
65
- longest_start, longest_end = current_start, current_end
66
- max_length = current_end - current_start
67
- current_start, current_end = 0, 0
68
-
69
- if current_end - current_start > max_length: # Check the last sequence
70
- longest_start, longest_end = current_start, current_end
71
-
72
- longest_sequence = (longest_start, longest_end)
73
 
74
- ppl = np.exp(np.mean(loss_values))
 
 
75
 
76
  res = {
77
- "score": __fluency_score_from_ppl(ppl),
78
- "errors": [
79
- {
80
- "start": longest_sequence[0],
81
- "end": longest_sequence[1],
82
- "message": f"Perplexity above threshold: {curr_loss}"
83
- }
84
- ]
85
  }
86
 
87
  return res
88
 
89
- def __fluency_score_from_ppl(ppl, midpoint=20, steepness=0.3):
90
  """
91
  Use a logistic function to map perplexity to 0–100.
92
  Midpoint is the PPL where score is 50.
@@ -135,12 +159,12 @@ def grammar_errors(text) -> tuple[int, list[str]]:
135
 
136
  return res
137
 
138
- def __grammar_score_from_prob(error_ratio, steepness=10):
139
  """
140
  Transform the number of errors divided by words into a score from 0 to 100.
141
  Steepness controls how quickly the score drops as errors increase.
142
  """
143
- score = 100 / (1 + np.exp(steepness * error_ratio))
144
  return round(score, 2)
145
 
146
 
 
3
  import torch
4
  import numpy as np
5
  import spacy
6
+ import wordfreq
7
 
8
  tool = language_tool_python.LanguageTool('en-US')
9
  model_name="distilbert-base-multilingual-cased"
 
13
 
14
  nlp = spacy.load("en_core_web_sm")
15
 
16
+ def __get_word_pr_score(word, lang="en") -> list[float]:
17
+ return -np.log(wordfreq.word_frequency(word, lang) + 1e-12)
18
+
19
+ def pseudo_perplexity(text, threshold=20, max_len=128):
20
  """
21
  We want to return
22
  {
 
30
  ]
31
  }
32
  """
33
+ encoding = tokenizer(text, return_tensors="pt", return_offsets_mapping=True)
34
+ input_ids = encoding["input_ids"][0]
35
+ print(input_ids)
36
+ offset_mapping = encoding["offset_mapping"][0]
37
+ print(offset_mapping)
38
+ tokens = tokenizer.convert_ids_to_tokens(input_ids)
39
+
40
+ # Group token indices by word based on offset mapping
41
+ word_groups = []
42
+ current_group = []
43
+
44
+ prev_end = None
45
+
46
+ for i, (start, end) in enumerate(offset_mapping):
47
+ if input_ids[i] in tokenizer.all_special_ids:
48
+ continue # skip special tokens like [CLS] and [SEP]
49
+
50
+ if prev_end is not None and start > prev_end:
51
+ # Word boundary detected → start new group
52
+ word_groups.append(current_group)
53
+ current_group = [i]
54
+ else:
55
+ current_group.append(i)
56
+
57
+ prev_end = end
58
 
59
+ # Append final group
60
+ if current_group:
61
+ word_groups.append(current_group)
62
 
63
  loss_values = []
64
+ tok_loss = []
65
+ for group in word_groups:
66
+ if group[0] == 0 or group[-1] == len(input_ids) - 1:
67
+ continue # skip [CLS] and [SEP]
68
 
69
+ masked = input_ids.clone()
70
+ for i in group:
71
+ masked[i] = tokenizer.mask_token_id
72
 
73
  with torch.no_grad():
74
+ outputs = model(masked.unsqueeze(0))
75
+ logits = outputs.logits[0]
76
+
77
+ log_probs = []
78
+ for i in group:
79
+ probs = torch.softmax(logits[i], dim=-1)
80
+ true_token_id = input_ids[i].item()
81
+ prob = probs[true_token_id].item()
82
+ log_probs.append(np.log(prob + 1e-12))
83
+ tok_loss.append(-np.log(prob + 1e-12))
84
+
85
+ word_loss = -np.sum(log_probs) / len(log_probs)
86
+ word = tokenizer.decode(input_ids[group[0]])
87
+ word_loss -= 0.6 * __get_word_pr_score(word)
88
+ loss_values.append(word_loss)
89
 
90
+ print(loss_values)
91
+
92
+ errors = []
93
+ for i, l in enumerate(loss_values):
94
+ if l < threshold:
95
+ continue
96
+ errors.append({
97
+ "start": i,
98
+ "end": i,
99
+ "message": f"Perplexity {l} over threshold {threshold}"
100
+ })
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ print(tok_loss)
103
+ s_ppl = np.mean(tok_loss)
104
+ print(s_ppl)
105
 
106
  res = {
107
+ "score": __fluency_score_from_ppl(s_ppl),
108
+ "errors": errors
 
 
 
 
 
 
109
  }
110
 
111
  return res
112
 
113
+ def __fluency_score_from_ppl(ppl, midpoint=8, steepness=0.3):
114
  """
115
  Use a logistic function to map perplexity to 0–100.
116
  Midpoint is the PPL where score is 50.
 
159
 
160
  return res
161
 
162
+ def __grammar_score_from_prob(error_ratio):
163
  """
164
  Transform the number of errors divided by words into a score from 0 to 100.
165
  Steepness controls how quickly the score drops as errors increase.
166
  """
167
+ score = 100*(1-error_ratio)
168
  return round(score, 2)
169
 
170
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  language_tool_python
2
  transformers
3
- torch
 
 
1
  language_tool_python
2
  transformers
3
+ torch
4
+ wordfreq
scorer.ipynb CHANGED
@@ -4,16 +4,7 @@
4
  "cell_type": "code",
5
  "execution_count": 1,
6
  "metadata": {},
7
- "outputs": [
8
- {
9
- "name": "stderr",
10
- "output_type": "stream",
11
- "text": [
12
- "/opt/anaconda3/envs/teach-bs/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
- " from .autonotebook import tqdm as notebook_tqdm\n"
14
- ]
15
- }
16
- ],
17
  "source": [
18
  "from categories.fluency import *"
19
  ]
@@ -27,7 +18,25 @@
27
  "name": "stdout",
28
  "output_type": "stream",
29
  "text": [
30
- "Sentence: The car hit the cone.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  ]
32
  }
33
  ],
@@ -40,7 +49,7 @@
40
  "print(\"Sentence:\", s) # Print the input sentence\n",
41
  "\n",
42
  "err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
43
- "flu = pseudo_perplexity(s) # Call the function to execute the fluency checking"
44
  ]
45
  },
46
  {
@@ -52,8 +61,12 @@
52
  "name": "stdout",
53
  "output_type": "stream",
54
  "text": [
55
- "Perplexity above threshold: 0: The\n",
56
- "[{'start': 0, 'end': 0, 'message': 'Perplexity above threshold: 0'}]\n"
 
 
 
 
57
  ]
58
  }
59
  ],
@@ -62,26 +75,26 @@
62
  "\n",
63
  "for e in combined_err:\n",
64
  " substr = \" \".join(s.split(\" \")[e[\"start\"]:e[\"end\"]+1])\n",
65
- " print(f\"{e['message']}: {substr}\") # Print the error messages\n",
66
- "\n",
67
- "print(combined_err)\n"
68
  ]
69
  },
70
  {
71
  "cell_type": "code",
72
- "execution_count": 4,
73
  "metadata": {},
74
  "outputs": [
75
  {
76
  "name": "stdout",
77
  "output_type": "stream",
78
  "text": [
79
- "Fluency Score: 30.0\n"
 
80
  ]
81
  }
82
  ],
83
  "source": [
84
- "fluency_score = 0.6 * err[\"score\"] + 0.4 * flu[\"score\"] # Calculate the fluency score\n",
 
85
  "print(\"Fluency Score:\", fluency_score) # Print the fluency score"
86
  ]
87
  }
 
4
  "cell_type": "code",
5
  "execution_count": 1,
6
  "metadata": {},
7
+ "outputs": [],
 
 
 
 
 
 
 
 
 
8
  "source": [
9
  "from categories.fluency import *"
10
  ]
 
18
  "name": "stdout",
19
  "output_type": "stream",
20
  "text": [
21
+ "Sentence: The cat sat the quickly up apples banana.\n",
22
+ "tensor([ 101, 10117, 41163, 20694, 10105, 23590, 10741, 72894, 11268, 99304,\n",
23
+ " 10219, 119, 102])\n",
24
+ "tensor([[ 0, 0],\n",
25
+ " [ 0, 3],\n",
26
+ " [ 4, 7],\n",
27
+ " [ 8, 11],\n",
28
+ " [12, 15],\n",
29
+ " [16, 23],\n",
30
+ " [24, 26],\n",
31
+ " [27, 30],\n",
32
+ " [30, 33],\n",
33
+ " [34, 38],\n",
34
+ " [38, 40],\n",
35
+ " [40, 41],\n",
36
+ " [ 0, 0]])\n",
37
+ "[np.float64(0.00905743383887514), np.float64(1.1257066968185931), np.float64(4.8056646935577145), np.float64(4.473408069089179), np.float64(4.732453441503642), np.float64(3.028744414819041), np.float64(5.1115574262487735), np.float64(-0.6523823890571343)]\n",
38
+ "[np.float64(1.7636628003080927), np.float64(6.955413759407024), np.float64(10.828562153345375), np.float64(6.228013435558396), np.float64(10.258657658689351), np.float64(6.635744767229443), np.float64(11.163667119285972), np.float64(10.499412826924114), np.float64(11.96113847381264), np.float64(10.010973250156082), np.float64(2.470404176100153)]\n",
39
+ "0.5208035409471965\n"
40
  ]
41
  }
42
  ],
 
49
  "print(\"Sentence:\", s) # Print the input sentence\n",
50
  "\n",
51
  "err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
52
+ "flu = pseudo_perplexity(s, threshold=2.5) # Call the function to execute the fluency checking"
53
  ]
54
  },
55
  {
 
61
  "name": "stdout",
62
  "output_type": "stream",
63
  "text": [
64
+ "An apostrophe may be missing.: apples banana.\n",
65
+ "Perplexity 4.8056646935577145 over threshold 2.5: sat\n",
66
+ "Perplexity 4.473408069089179 over threshold 2.5: the\n",
67
+ "Perplexity 4.732453441503642 over threshold 2.5: quickly\n",
68
+ "Perplexity 3.028744414819041 over threshold 2.5: up\n",
69
+ "Perplexity 5.1115574262487735 over threshold 2.5: apples\n"
70
  ]
71
  }
72
  ],
 
75
  "\n",
76
  "for e in combined_err:\n",
77
  " substr = \" \".join(s.split(\" \")[e[\"start\"]:e[\"end\"]+1])\n",
78
+ " print(f\"{e['message']}: {substr}\") # Print the error messages\n"
 
 
79
  ]
80
  },
81
  {
82
  "cell_type": "code",
83
+ "execution_count": null,
84
  "metadata": {},
85
  "outputs": [
86
  {
87
  "name": "stdout",
88
  "output_type": "stream",
89
  "text": [
90
+ "87.5 99.71\n",
91
+ "Fluency Score: 92.384\n"
92
  ]
93
  }
94
  ],
95
  "source": [
96
+ "fluency_score = 0.7 * err[\"score\"] + 0.3 * flu[\"score\"] # Calculate the fluency score\n",
97
+ "print(err[\"score\"], flu[\"score\"]) # Print the individual scores\n",
98
  "print(\"Fluency Score:\", fluency_score) # Print the fluency score"
99
  ]
100
  }