minko186 commited on
Commit
fcfb880
·
2 Parent(s): cf5bf4c 69b471b

Merge remote-tracking branch 'origin/main' into minko

Browse files
Files changed (5) hide show
  1. analysis.py +172 -72
  2. app.py +53 -0
  3. explainability.py +0 -119
  4. requirements.txt +4 -1
  5. writing_analysis.py +138 -65
analysis.py CHANGED
@@ -1,31 +1,42 @@
1
- import requests
2
- import httpx
3
- import torch
4
- import re
5
- from bs4 import BeautifulSoup
6
- import numpy as np
7
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
- import asyncio
9
- from scipy.special import softmax
10
- from evaluate import load
11
- from datetime import date
12
- import nltk
13
- import fitz
14
- from transformers import GPT2LMHeadModel, GPT2TokenizerFast
15
- import nltk, spacy, subprocess, torch
16
- import plotly.graph_objects as go
17
- import torch.nn.functional as F
18
- import nltk
19
- from unidecode import unidecode
20
- import time
21
  import yaml
22
- import nltk
23
- import os
24
- from explainability import *
25
  import subprocess
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
27
  nltk.download("punkt")
28
  nltk.download("stopwords")
 
 
 
 
 
 
 
29
  with open("config.yaml", "r") as file:
30
  params = yaml.safe_load(file)
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -33,64 +44,153 @@ readability_model_id = params["READABILITY_MODEL_ID"]
33
  gpt2_model = GPT2LMHeadModel.from_pretrained(readability_model_id).to(device)
34
  gpt2_tokenizer = GPT2TokenizerFast.from_pretrained(readability_model_id)
35
 
36
- command = ["python3", "-m", "spacy", "download", "en_core_web_sm"]
37
- subprocess.run(command)
38
- nlp = spacy.load("en_core_web_sm")
 
39
 
40
 
41
  def depth_analysis(input_text):
42
- processed_words = preprocess_text1(input_text)
43
- ttr_value = vocabulary_richness_ttr(processed_words)
44
- gunning_fog = calculate_gunning_fog(input_text)
45
- gunning_fog_norm = normalize(gunning_fog, min_value=0, max_value=20)
46
- words, sentences = preprocess_text2(input_text)
47
- average_sentence_length = calculate_average_sentence_length(sentences)
48
- average_word_length = calculate_average_word_length(words)
49
- average_sentence_length_norm = normalize(
50
- average_sentence_length, min_value=0, max_value=40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  )
52
- average_word_length_norm = normalize(
53
- average_word_length, min_value=0, max_value=8
54
  )
55
- average_tree_depth = calculate_syntactic_tree_depth(nlp, input_text)
56
- average_tree_depth_norm = normalize(
57
- average_tree_depth, min_value=0, max_value=10
58
  )
59
- perplexity = calculate_perplexity(
60
- input_text, gpt2_model, gpt2_tokenizer, device
 
 
61
  )
62
- perplexity_norm = normalize(perplexity, min_value=0, max_value=30)
63
 
64
  features = {
65
- "readability": gunning_fog_norm,
66
- "syntactic tree depth": average_tree_depth_norm,
67
- "vocabulary richness": ttr_value,
68
- "perplexity": perplexity_norm,
69
- "average sentence length": average_sentence_length_norm,
70
- "average word length": average_word_length_norm,
 
 
 
71
  }
72
- fig = go.Figure()
73
- fig.add_trace(
74
- go.Scatterpolar(
75
- r=list(features.values()),
76
- theta=list(features.keys()),
77
- fill="toself",
78
- name="Radar Plot",
79
- )
80
- )
81
- fig.update_layout(
82
- polar=dict(
83
- radialaxis=dict(
84
- visible=True,
85
- range=[0, 100],
86
- )
87
- ),
88
- showlegend=False,
89
- margin=dict(
90
- l=10,
91
- r=20,
92
- b=10,
93
- t=10,
94
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  )
 
 
 
 
 
 
 
 
 
 
 
 
96
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import yaml
 
 
 
2
  import subprocess
3
+ import nltk
4
+ from nltk import word_tokenize
5
+ from nltk.corpus import cmudict, stopwords
6
+ import spacy
7
+ import torch
8
+ from transformers import GPT2LMHeadModel, GPT2TokenizerFast
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+
12
+ from matplotlib.patches import Circle, RegularPolygon
13
+ from matplotlib.path import Path
14
+ from matplotlib.projections import register_projection
15
+ from matplotlib.projections.polar import PolarAxes
16
+ from matplotlib.spines import Spine
17
+ from matplotlib.transforms import Affine2D
18
+ from writing_analysis import (
19
+ estimated_slightly_difficult_words_ratio,
20
+ entity_density,
21
+ determiners_frequency,
22
+ punctuation_diversity,
23
+ type_token_ratio,
24
+ calculate_perplexity,
25
+ calculate_syntactic_tree_depth,
26
+ hapax_legomena_ratio,
27
+ mtld,
28
+ )
29
 
30
+ nltk.download("cmudict")
31
  nltk.download("punkt")
32
  nltk.download("stopwords")
33
+ nltk.download("wordnet")
34
+ d = cmudict.dict()
35
+ command = ["python3", "-m", "spacy", "download", "en_core_web_sm"]
36
+ subprocess.run(command)
37
+ nlp = spacy.load("en_core_web_sm")
38
+
39
+
40
  with open("config.yaml", "r") as file:
41
  params = yaml.safe_load(file)
42
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
44
  gpt2_model = GPT2LMHeadModel.from_pretrained(readability_model_id).to(device)
45
  gpt2_tokenizer = GPT2TokenizerFast.from_pretrained(readability_model_id)
46
 
47
+
48
+ def normalize(value, min_value, max_value):
49
+ normalized_value = ((value - min_value) * 100) / (max_value - min_value)
50
+ return max(0, min(100, normalized_value))
51
 
52
 
53
  def depth_analysis(input_text):
54
+
55
+ usual_ranges = {
56
+ "estimated_slightly_difficult_words_ratio": (
57
+ 0.2273693623058005,
58
+ 0.557383692351033,
59
+ ),
60
+ "entity_density": (-0.07940776754145815, 0.23491038179986615),
61
+ "determiners_frequency": (0.012461059190031154, 0.15700934579439252),
62
+ "punctuation_diversity": (-0.21875, 0.53125),
63
+ "type_token_ratio": (0.33002482852189063, 1.0894414982357028),
64
+ "calculate_perplexity": (-25.110544681549072, 82.4620680809021),
65
+ "calculate_syntactic_tree_depth": (1.8380681818181812, 10.997159090909092),
66
+ "hapax_legomena_ratio": (0.0830971690138207, 1.0302715687215778),
67
+ "mtld": (-84.03125000000001, 248.81875000000002),
68
+ }
69
+
70
+ vocabulary_level = estimated_slightly_difficult_words_ratio(input_text, d)
71
+ entity_ratio = entity_density(input_text, nlp)
72
+ determiner_use = determiners_frequency(input_text, nlp)
73
+ punctuation_variety = punctuation_diversity(input_text)
74
+ sentence_depth = calculate_syntactic_tree_depth(input_text, nlp)
75
+ perplexity = calculate_perplexity(input_text, gpt2_model, gpt2_tokenizer, device)
76
+ lexical_diversity = type_token_ratio(input_text)
77
+ unique_words = hapax_legomena_ratio(input_text)
78
+ vocabulary_stability = mtld(input_text)
79
+
80
+ # normalize between 0 and 100
81
+ vocabulary_level_norm = normalize(
82
+ vocabulary_level, *usual_ranges["estimated_slightly_difficult_words_ratio"]
83
+ )
84
+ entity_ratio_norm = normalize(entity_ratio, *usual_ranges["entity_density"])
85
+ determiner_use_norm = normalize(
86
+ determiner_use, *usual_ranges["determiners_frequency"]
87
  )
88
+ punctuation_variety_norm = normalize(
89
+ punctuation_variety, *usual_ranges["punctuation_diversity"]
90
  )
91
+ lexical_diversity_norm = normalize(
92
+ lexical_diversity, *usual_ranges["type_token_ratio"]
 
93
  )
94
+ unique_words_norm = normalize(unique_words, *usual_ranges["hapax_legomena_ratio"])
95
+ vocabulary_stability_norm = normalize(vocabulary_stability, *usual_ranges["mtld"])
96
+ sentence_depth_norm = normalize(
97
+ sentence_depth, *usual_ranges["calculate_syntactic_tree_depth"]
98
  )
99
+ perplexity_norm = normalize(perplexity, *usual_ranges["calculate_perplexity"])
100
 
101
  features = {
102
+ "Lexical Diversity": lexical_diversity_norm,
103
+ "Vocabulary Level": vocabulary_level_norm,
104
+ "Unique Words": unique_words_norm,
105
+ "Determiner Use": determiner_use_norm,
106
+ "Punctuation Variety": punctuation_variety_norm,
107
+ "Sentence Depth": sentence_depth_norm,
108
+ "Vocabulary Stability": vocabulary_stability_norm,
109
+ "Entity Ratio": entity_ratio_norm,
110
+ "Perplexity": perplexity_norm,
111
  }
112
+
113
+ def radar_factory(num_vars, frame="circle"):
114
+ theta = np.linspace(0, 2 * np.pi, num_vars, endpoint=False)
115
+
116
+ class RadarTransform(PolarAxes.PolarTransform):
117
+ def transform_path_non_affine(self, path):
118
+ if path._interpolation_steps > 1:
119
+ path = path.interpolated(num_vars)
120
+ return Path(self.transform(path.vertices), path.codes)
121
+
122
+ class RadarAxes(PolarAxes):
123
+ name = "radar"
124
+ PolarTransform = RadarTransform
125
+
126
+ def __init__(self, *args, **kwargs):
127
+ super().__init__(*args, **kwargs)
128
+ self.set_theta_zero_location("N")
129
+
130
+ def fill(self, *args, closed=True, **kwargs):
131
+ return super().fill(closed=closed, *args, **kwargs)
132
+
133
+ def plot(self, *args, **kwargs):
134
+ lines = super().plot(*args, **kwargs)
135
+ for line in lines:
136
+ self._close_line(line)
137
+
138
+ def _close_line(self, line):
139
+ x, y = line.get_data()
140
+ if x[0] != x[-1]:
141
+ x = np.append(x, x[0])
142
+ y = np.append(y, y[0])
143
+ line.set_data(x, y)
144
+
145
+ def set_varlabels(self, labels):
146
+ self.set_thetagrids(np.degrees(theta), labels)
147
+
148
+ def _gen_axes_patch(self):
149
+ if frame == "circle":
150
+ return Circle((0.5, 0.5), 0.5)
151
+ elif frame == "polygon":
152
+ return RegularPolygon(
153
+ (0.5, 0.5), num_vars, radius=0.5, edgecolor="k"
154
+ )
155
+
156
+ def _gen_axes_spines(self):
157
+ if frame == "polygon":
158
+ spine = Spine(
159
+ axes=self,
160
+ spine_type="circle",
161
+ path=Path.unit_regular_polygon(num_vars),
162
+ )
163
+ spine.set_transform(
164
+ Affine2D().scale(0.5).translate(0.5, 0.5) + self.transAxes
165
+ )
166
+ return {"polar": spine}
167
+
168
+ register_projection(RadarAxes)
169
+ return theta
170
+
171
+ N = 9
172
+ theta = radar_factory(N, frame="polygon")
173
+ data = features.values()
174
+ labels = features.keys()
175
+ fig, ax = plt.subplots(subplot_kw=dict(projection="radar"), figsize=(7.5, 5))
176
+ ax.plot(theta, data)
177
+ ax.fill(theta, data, alpha=0.4)
178
+ ax.set_varlabels(labels)
179
+
180
+ rgrids = np.linspace(0, 100, num=6)
181
+ ax.set_rgrids(
182
+ rgrids, labels=[f"{round(r)}%" for r in rgrids], fontsize=8, color="black"
183
  )
184
+ ax.grid(True, color="black", linestyle="-", linewidth=0.5, alpha=0.5)
185
+
186
+ for dd, (label, value) in enumerate(zip(labels, data)):
187
+ ax.text(
188
+ theta[dd] + 0.1,
189
+ value + 5,
190
+ f"{value:.0f}",
191
+ horizontalalignment="left",
192
+ verticalalignment="bottom",
193
+ fontsize=8,
194
+ )
195
+
196
  return fig
app.py CHANGED
@@ -232,6 +232,59 @@ with gr.Blocks() as demo:
232
  with gr.Row():
233
  with gr.Column():
234
  writing_analysis_plot = gr.Plot(label="Writing Analysis Plot")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  full_check_btn.click(
237
  fn=main,
 
232
  with gr.Row():
233
  with gr.Column():
234
  writing_analysis_plot = gr.Plot(label="Writing Analysis Plot")
235
+ with gr.Column():
236
+ interpretation = """
237
+ <h2>Writing Analysis Interpretation</h2>
238
+ <ul>
239
+ <li><b>Lexical Diversity</b>: This feature measures the range of unique words used in a text.
240
+ <ul>
241
+ <li>🤖 Higher tends to be AI.</li>
242
+ </ul>
243
+ </li>
244
+ <li><b>Vocabulary Level</b>: This feature assesses the complexity of the words used in a text.
245
+ <ul>
246
+ <li>🤖 Higher tends to be AI.</li>
247
+ </ul>
248
+ </li>
249
+ <li><b>Unique Words</b>: This feature counts the number of words that appear only once within the text.
250
+ <ul>
251
+ <li>🤖 Higher tends to be AI.</li>
252
+ </ul>
253
+ </li>
254
+ <li><b>Determiner Use</b>: This feature tracks the frequency of articles and quantifiers in the text.
255
+ <ul>
256
+ <li>🤖 Higher tends to be AI.</li>
257
+ </ul>
258
+ </li>
259
+ <li><b>Punctuation Variety</b>: This feature indicates the diversity of punctuation marks used in the text.
260
+ <ul>
261
+ <li>👤 Higher tends to be Human.</li>
262
+ </ul>
263
+ </li>
264
+ <li><b>Sentence Depth</b>: This feature evaluates the complexity of the sentence structures used in the text.
265
+ <ul>
266
+ <li>🤖 Higher tends to be AI.</li>
267
+ </ul>
268
+ </li>
269
+ <li><b>Vocabulary Stability</b>: This feature measures the consistency of vocabulary use throughout the text.
270
+ <ul>
271
+ <li>🤖 Higher tends to be AI.</li>
272
+ </ul>
273
+ </li>
274
+ <li><b>Entity Ratio</b>: This feature calculates the proportion of named entities, such as names and places, within the text.
275
+ <ul>
276
+ <li>👤 Higher tends to be Human.</li>
277
+ </ul>
278
+ </li>
279
+ <li><b>Perplexity</b>: This feature assesses the predictability of the text based on the sequence of words.
280
+ <ul>
281
+ <li>👤 Higher tends to be Human.</li>
282
+ </ul>
283
+ </li>
284
+ </ul>
285
+
286
+ """
287
+ gr.HTML(interpretation, label="Interpretation of Writing Analysis")
288
 
289
  full_check_btn.click(
290
  fn=main,
explainability.py DELETED
@@ -1,119 +0,0 @@
1
- import re, textstat
2
- from nltk import FreqDist
3
- from nltk.corpus import stopwords
4
- from nltk.tokenize import word_tokenize, sent_tokenize
5
- import torch
6
- import nltk
7
- from tqdm import tqdm
8
-
9
- nltk.download("punkt")
10
-
11
-
12
- def normalize(value, min_value, max_value):
13
- normalized_value = ((value - min_value) * 100) / (max_value - min_value)
14
- return max(0, min(100, normalized_value))
15
-
16
-
17
- def preprocess_text1(text):
18
- text = text.lower()
19
- text = re.sub(r"[^\w\s]", "", text) # remove punctuation
20
- stop_words = set(stopwords.words("english")) # remove stopwords
21
- words = [word for word in text.split() if word not in stop_words]
22
- words = [word for word in words if not word.isdigit()] # remove numbers
23
- return words
24
-
25
-
26
- def vocabulary_richness_ttr(words):
27
- unique_words = set(words)
28
- ttr = len(unique_words) / len(words) * 100
29
- return ttr
30
-
31
-
32
- def calculate_gunning_fog(text):
33
- """range 0-20"""
34
- gunning_fog = textstat.gunning_fog(text)
35
- return gunning_fog
36
-
37
-
38
- def calculate_automated_readability_index(text):
39
- """range 1-20"""
40
- ari = textstat.automated_readability_index(text)
41
- return ari
42
-
43
-
44
- def calculate_flesch_reading_ease(text):
45
- """range 0-100"""
46
- fre = textstat.flesch_reading_ease(text)
47
- return fre
48
-
49
-
50
- def preprocess_text2(text):
51
- sentences = sent_tokenize(text)
52
- words = [
53
- word.lower()
54
- for sent in sentences
55
- for word in word_tokenize(sent)
56
- if word.isalnum()
57
- ]
58
- stop_words = set(stopwords.words("english"))
59
- words = [word for word in words if word not in stop_words]
60
- return words, sentences
61
-
62
-
63
- def calculate_average_sentence_length(sentences):
64
- """range 0-40 or 50 based on the histogram"""
65
- total_words = sum(len(word_tokenize(sent)) for sent in sentences)
66
- average_sentence_length = total_words / (len(sentences) + 0.0000001)
67
- return average_sentence_length
68
-
69
-
70
- def calculate_average_word_length(words):
71
- """range 0-8 based on the histogram"""
72
- total_characters = sum(len(word) for word in words)
73
- average_word_length = total_characters / (len(words) + 0.0000001)
74
- return average_word_length
75
-
76
-
77
- def calculate_max_depth(sent):
78
- return max(len(list(token.ancestors)) for token in sent)
79
-
80
-
81
- def calculate_syntactic_tree_depth(nlp, text):
82
- """0-10 based on the histogram"""
83
- doc = nlp(text)
84
- sentence_depths = [calculate_max_depth(sent) for sent in doc.sents]
85
- average_depth = (
86
- sum(sentence_depths) / len(sentence_depths) if sentence_depths else 0
87
- )
88
- return average_depth
89
-
90
-
91
- def calculate_perplexity(text, model, tokenizer, device, stride=512):
92
- """range 0-30 based on the histogram"""
93
- encodings = tokenizer(text, return_tensors="pt")
94
- max_length = model.config.n_positions
95
- seq_len = encodings.input_ids.size(1)
96
-
97
- nlls = []
98
- prev_end_loc = 0
99
- for begin_loc in tqdm(range(0, seq_len, stride)):
100
- end_loc = min(begin_loc + max_length, seq_len)
101
- trg_len = (
102
- end_loc - prev_end_loc
103
- ) # may be different from stride on last loop
104
- input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
105
- target_ids = input_ids.clone()
106
- target_ids[:, :-trg_len] = -100
107
-
108
- with torch.no_grad():
109
- outputs = model(input_ids, labels=target_ids)
110
- neg_log_likelihood = outputs.loss
111
-
112
- nlls.append(neg_log_likelihood)
113
-
114
- prev_end_loc = end_loc
115
- if end_loc == seq_len:
116
- break
117
-
118
- ppl = torch.exp(torch.stack(nlls).mean())
119
- return ppl.item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -25,4 +25,7 @@ sentence-transformers
25
  Unidecode
26
  python-dotenv
27
  lime
28
- joblib
 
 
 
 
25
  Unidecode
26
  python-dotenv
27
  lime
28
+ joblib
29
+ emoji==1.6.1
30
+ matplotlib
31
+ seaborn
writing_analysis.py CHANGED
@@ -1,85 +1,153 @@
1
- import re, textstat
2
- from nltk import FreqDist
 
3
  from nltk.corpus import stopwords
4
- from nltk.tokenize import word_tokenize, sent_tokenize
 
5
  import torch
6
- from tqdm import tqdm
7
-
8
-
9
- def normalize(value, min_value, max_value):
10
- normalized_value = ((value - min_value) * 100) / (max_value - min_value)
11
- return max(0, min(100, normalized_value))
12
-
13
- # vocabulary richness
14
- def preprocess_text1(text):
15
- text = text.lower()
16
- text = re.sub(r'[^\w\s]', '', text) # remove punctuation
17
- stop_words = set(stopwords.words('english')) # remove stopwords
18
- words = [word for word in text.split() if word not in stop_words]
19
- words = [word for word in words if not word.isdigit()] # remove numbers
20
- return words
21
-
22
- def vocabulary_richness_ttr(words):
23
- unique_words = set(words)
24
- ttr = len(unique_words) / len(words) * 100
25
- return ttr
26
-
27
- def calculate_gunning_fog(text):
28
- """range 0-20"""
29
- gunning_fog = textstat.gunning_fog(text)
30
- return gunning_fog
31
-
32
- def calculate_automated_readability_index(text):
33
- """range 1-20"""
34
- ari = textstat.automated_readability_index(text)
35
- return ari
36
-
37
- def calculate_flesch_reading_ease(text):
38
- """range 0-100"""
39
- fre = textstat.flesch_reading_ease(text)
40
- return fre
41
-
42
- def preprocess_text2(text):
43
- # tokenize into words and remove punctuation
44
- sentences = sent_tokenize(text)
45
- words = [word.lower() for sent in sentences for word in word_tokenize(sent) if word.isalnum()]
46
- # remove stopwords
47
- stop_words = set(stopwords.words('english'))
48
- words = [word for word in words if word not in stop_words]
49
- return words, sentences
50
-
51
- def calculate_average_sentence_length(sentences):
52
- """range 0-40 or 50 based on the histogram"""
53
- total_words = sum(len(word_tokenize(sent)) for sent in sentences)
54
- average_sentence_length = total_words / (len(sentences) + 0.0000001)
55
- return average_sentence_length
56
-
57
- def calculate_average_word_length(words):
58
- """range 0-8 based on the histogram"""
59
- total_characters = sum(len(word) for word in words)
60
- average_word_length = total_characters / (len(words) + 0.0000001)
61
- return average_word_length
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def calculate_max_depth(sent):
64
  return max(len(list(token.ancestors)) for token in sent)
65
 
66
- def calculate_syntactic_tree_depth(nlp, text):
67
- """0-10 based on the histogram"""
68
  doc = nlp(text)
69
  sentence_depths = [calculate_max_depth(sent) for sent in doc.sents]
70
- average_depth = sum(sentence_depths) / len(sentence_depths) if sentence_depths else 0
 
 
71
  return average_depth
72
 
73
- # reference: https://huggingface.co/docs/transformers/perplexity
 
74
  def calculate_perplexity(text, model, tokenizer, device, stride=512):
75
- """range 0-30 based on the histogram"""
76
  encodings = tokenizer(text, return_tensors="pt")
77
  max_length = model.config.n_positions
78
  seq_len = encodings.input_ids.size(1)
79
 
80
  nlls = []
81
  prev_end_loc = 0
82
- for begin_loc in tqdm(range(0, seq_len, stride)):
83
  end_loc = min(begin_loc + max_length, seq_len)
84
  trg_len = end_loc - prev_end_loc # may be different from stride on last loop
85
  input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
@@ -88,6 +156,10 @@ def calculate_perplexity(text, model, tokenizer, device, stride=512):
88
 
89
  with torch.no_grad():
90
  outputs = model(input_ids, labels=target_ids)
 
 
 
 
91
  neg_log_likelihood = outputs.loss
92
 
93
  nlls.append(neg_log_likelihood)
@@ -98,3 +170,4 @@ def calculate_perplexity(text, model, tokenizer, device, stride=512):
98
 
99
  ppl = torch.exp(torch.stack(nlls).mean())
100
  return ppl.item()
 
 
1
+ import string
2
+ from collections import Counter
3
+ from nltk import word_tokenize
4
  from nltk.corpus import stopwords
5
+ from nltk.stem import WordNetLemmatizer
6
+ from nltk.probability import FreqDist
7
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+
10
+ def preprocess_text(text, remove_stopwords=True, use_lemmatization=True):
11
+ tokens = word_tokenize(text.lower())
12
+ tokens = [token for token in tokens if token.isalpha()]
13
+ if remove_stopwords:
14
+ stop_words = set(stopwords.words("english"))
15
+ tokens = [token for token in tokens if token not in stop_words]
16
+ if use_lemmatization:
17
+ lemmatizer = WordNetLemmatizer()
18
+ tokens = [lemmatizer.lemmatize(token) for token in tokens]
19
+ return tokens
20
+
21
+
22
+ def get_special_chars():
23
+ import emoji # Use version emoji==1.6.1, otherwise it won't have UNICODE_EMOJI
24
+
25
+ main_special_characters = string.punctuation + string.digits + string.whitespace
26
+ other_special_characters = (
27
+ "’ “— ™ – •‘œ    ˜ ‚ƒ„’“”–ー一▬…✦�­£​•€«»°·═"
28
+ "×士^˘⇓↓↑←→()§″′´¿−±∈¢ø‚„½¼¾¹²³―⁃,ˌ¸‹›ʺˈʻ¦‐⠀‰……‑≤≥‖"
29
+ "◆●■►▼▲▴∆▻¡★☆✱ːº。¯˜¥ɪ≈†上ン:∼⁄・♡✓⊕․.⋅÷1‟;،、¨ाাी्े◦˚"
30
+ "゜ʼ≖ʼ¤ッツシ℃√!【】‿∞➤~πه۩☛₨➩☻๑٪♥ıॽ《‘©﴿٬?▷Г♫∟™ª₪®「—❖"
31
+ "」﴾》"
32
+ )
33
+ emoji = list(emoji.UNICODE_EMOJI["en"].keys())
34
+ special_characters_default = set(main_special_characters + other_special_characters)
35
+ special_characters_default.update(emoji)
36
+ return special_characters_default
37
+
38
+ special_characters_default = get_special_chars()
39
+
40
+
41
+ # -------------------- Features --------------------
42
+ def syllable_count(word, d):
43
+ return [len(list(y for y in x if y[-1].isdigit())) for x in d.get(word, [])]
44
+
45
+
46
+ def estimated_slightly_difficult_words_ratio(text, d):
47
+ words = word_tokenize(text.lower())
48
+ total_words = len(words)
49
+ # Considering words with 3 or more syllables as difficult
50
+ difficult_count = sum(
51
+ 1 for word in words if sum(1 for _ in syllable_count(word, d)) >= 2
52
+ )
53
+ return difficult_count / total_words if total_words > 0 else 0
54
+
55
+
56
+ # -------------------- Features --------------------
57
+ def entity_density(text, nlp):
58
+ doc = nlp(text)
59
+ return len(doc.ents) / len(doc)
60
+
61
+
62
+ # -------------------- Features --------------------
63
+ def determiners_frequency(text, nlp):
64
+ doc = nlp(text)
65
+ determiners = sum(1 for token in doc if token.pos_ == "DET")
66
+ total_words = len(doc)
67
+ return determiners / total_words if total_words else 0
68
+
69
+
70
+ # -------------------- Features --------------------
71
+ def punctuation_diversity(text):
72
+ punctuation_counts = Counter(
73
+ char for char in text if char in special_characters_default
74
+ )
75
+ diversity_score = (
76
+ len(punctuation_counts) / len(special_characters_default)
77
+ if special_characters_default
78
+ else 0
79
+ )
80
+ return diversity_score
81
+
82
+
83
+ # -------------------- Features --------------------
84
+ def type_token_ratio(text, remove_stopwords=True, use_lemmatization=True):
85
+ tokens = preprocess_text(text, remove_stopwords, use_lemmatization)
86
+ unique_words = set(tokens)
87
+ return len(unique_words) / len(tokens) if tokens else 0
88
+
89
+
90
+ # -------------------- Features --------------------
91
+ def hapax_legomena_ratio(text, remove_stopwords=True, use_lemmatization=True):
92
+ tokens = word_tokenize(text.lower())
93
+ tokens = [token for token in tokens if token.isalpha()]
94
+
95
+ if remove_stopwords:
96
+ stop_words = set(stopwords.words("english"))
97
+ tokens = [token for token in tokens if token not in stop_words]
98
+
99
+ if use_lemmatization:
100
+ lemmatizer = WordNetLemmatizer()
101
+ tokens = [lemmatizer.lemmatize(token) for token in tokens]
102
+
103
+ freq_dist = FreqDist(tokens)
104
+ hapaxes = freq_dist.hapaxes()
105
+ return len(hapaxes) / len(tokens) if tokens else 0
106
+
107
+
108
+ # -------------------- Features --------------------
109
+ def mtld(text, threshold=0.72, remove_stopwords=True, use_lemmatization=True):
110
+ tokens = preprocess_text(text, remove_stopwords, use_lemmatization)
111
+
112
+ def mtld_calc(direction):
113
+ token_length, factor_count = 0, 0
114
+ types = set()
115
+ for token in tokens if direction == "forward" else reversed(tokens):
116
+ types.add(token)
117
+ token_length += 1
118
+ if len(types) / token_length < threshold:
119
+ factor_count += 1
120
+ types = set()
121
+ token_length = 0
122
+ factor_count += 1 # For the last segment, even if it didn't reach the threshold
123
+ return len(tokens) / factor_count if factor_count != 0 else 0
124
+
125
+ return (mtld_calc("forward") + mtld_calc("backward")) / 2
126
+
127
+
128
+ # -------------------- Features --------------------
129
  def calculate_max_depth(sent):
130
  return max(len(list(token.ancestors)) for token in sent)
131
 
132
+
133
+ def calculate_syntactic_tree_depth(text, nlp):
134
  doc = nlp(text)
135
  sentence_depths = [calculate_max_depth(sent) for sent in doc.sents]
136
+ average_depth = (
137
+ sum(sentence_depths) / len(sentence_depths) if sentence_depths else 0
138
+ )
139
  return average_depth
140
 
141
+
142
+ # -------------------- Features --------------------
143
  def calculate_perplexity(text, model, tokenizer, device, stride=512):
 
144
  encodings = tokenizer(text, return_tensors="pt")
145
  max_length = model.config.n_positions
146
  seq_len = encodings.input_ids.size(1)
147
 
148
  nlls = []
149
  prev_end_loc = 0
150
+ for begin_loc in range(0, seq_len, stride):
151
  end_loc = min(begin_loc + max_length, seq_len)
152
  trg_len = end_loc - prev_end_loc # may be different from stride on last loop
153
  input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
 
156
 
157
  with torch.no_grad():
158
  outputs = model(input_ids, labels=target_ids)
159
+
160
+ # loss is calculated using CrossEntropyLoss which averages over valid labels
161
+ # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
162
+ # to the left by 1.
163
  neg_log_likelihood = outputs.loss
164
 
165
  nlls.append(neg_log_likelihood)
 
170
 
171
  ppl = torch.exp(torch.stack(nlls).mean())
172
  return ppl.item()
173
+