aapot commited on
Commit
4974c8b
1 Parent(s): c902247
config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "text_types": ["title", "description", "transcript"],
3
+ "scalar_features": ["channel_sim"],
4
+ "label_col": "label",
5
+ "cross_encoder_model_name_or_path": "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
6
+ }
huggingface_model_wrapper.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import PyTorchModelHubMixin
2
+ from huggingface_hub.constants import PYTORCH_WEIGHTS_NAME
3
+ from huggingface_hub.file_download import hf_hub_download
4
+ from unifiedmodel import RRUM
5
+ import os
6
+ import torch
7
+
8
+
9
+ class YoutubeVideoSimilarityModel(RRUM, PyTorchModelHubMixin):
10
+ """
11
+ Hugging Face `PyTorchModelHubMixin` wrapper for RegretsReporter `RRUM` model.
12
+ This allows loading, using, and saving the model from Hugging Face model hub
13
+ with default Hugging Face methods `from_pretrained` and `save_pretrained`.
14
+ """
15
+ @classmethod
16
+ def _from_pretrained(
17
+ cls,
18
+ model_id,
19
+ revision,
20
+ cache_dir,
21
+ force_download,
22
+ proxies,
23
+ resume_download,
24
+ local_files_only,
25
+ use_auth_token,
26
+ map_location="cpu",
27
+ strict=False,
28
+ **model_kwargs,
29
+ ):
30
+ map_location = torch.device(map_location)
31
+
32
+ if os.path.isdir(model_id):
33
+ print("Loading weights from local directory")
34
+ model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
35
+ else:
36
+ model_file = hf_hub_download(
37
+ repo_id=model_id,
38
+ filename=PYTORCH_WEIGHTS_NAME,
39
+ revision=revision,
40
+ cache_dir=cache_dir,
41
+ force_download=force_download,
42
+ proxies=proxies,
43
+ resume_download=resume_download,
44
+ use_auth_token=use_auth_token,
45
+ local_files_only=local_files_only,
46
+ )
47
+ # convert Huggingface config to RRUM acceptable input parameters
48
+ if "config" in model_kwargs:
49
+ model_kwargs = {**model_kwargs["config"], **model_kwargs}
50
+ del model_kwargs["config"]
51
+ model = cls(**model_kwargs)
52
+
53
+ state_dict = torch.load(model_file, map_location=map_location)
54
+ model.load_state_dict(state_dict, strict=strict)
55
+ model.eval()
56
+
57
+ return model
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0af7ed336a1f40960118bef071cbf01ed5ec165ec71483ec0e2c00abb0f0f098
3
+ size 1411978305
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ datasets==2.4.0
2
+ fastcore==1.5.27
3
+ huggingface_hub==0.9.1
4
+ pandas==1.4.3
5
+ pyarrow==9.0.0
6
+ pytorch_lightning==1.7.6
7
+ torch==1.12.1
8
+ torchmetrics==0.9.3
9
+ transformers==4.22.1
unifiedmodel.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
2
+ import datasets
3
+ import pandas as pd
4
+ import pyarrow
5
+ import pytorch_lightning as pl
6
+ import torchmetrics
7
+ import torch.nn as nn
8
+ import torch
9
+ import types
10
+ import multiprocessing
11
+ from utils.text_cleaning import clean_text_funcs
12
+
13
+
14
+ class RRUMDataset():
15
+ scalar_features = ['channel_sim']
16
+ _image_features = ['regret_thumbnail',
17
+ 'recommendation_thumbnail'] # not used atm
18
+
19
+ def __init__(self, data, with_transcript, cross_encoder_model_name_or_path, label_col="label", label_map=None, balance_label_counts=False, max_length=128, do_train_test_split=False, test_size=0.25, seed=42, keep_video_ids_for_predictions=False, encode_on_the_fly=False, clean_text=False, processing_batch_size=1000, processing_num_proc=1):
20
+ self._with_transcript = with_transcript
21
+ self.tokenizer = AutoTokenizer.from_pretrained(
22
+ cross_encoder_model_name_or_path)
23
+ self.label_col = label_col
24
+ self.label_map = label_map
25
+ self.balance_label_counts = balance_label_counts
26
+ self.max_length = max_length
27
+ self.seed = seed
28
+ self.keep_video_ids_for_predictions = keep_video_ids_for_predictions
29
+ self.clean_text = clean_text
30
+ self.processing_batch_size = processing_batch_size
31
+ self.processing_num_proc = multiprocessing.cpu_count(
32
+ ) if not processing_num_proc else processing_num_proc
33
+
34
+ self.text_types = ['title', 'description'] + \
35
+ (['transcript'] if self._with_transcript else [])
36
+ self._text_features = [
37
+ 'regret_title', 'recommendation_title', 'regret_description',
38
+ 'recommendation_description'] + (['regret_transcript', 'recommendation_transcript'] if self._with_transcript else [])
39
+
40
+ # LOAD DATA INTO DATASET
41
+ self.streaming_dataset = False
42
+ if isinstance(data, pd.DataFrame):
43
+ self.dataset = datasets.Dataset.from_pandas(data)
44
+ elif isinstance(data, types.GeneratorType):
45
+ examples_iterable = datasets.iterable_dataset.ExamplesIterable(
46
+ self._streaming_generate_examples, {"iterable": data})
47
+ self.dataset = datasets.IterableDataset(examples_iterable)
48
+ self._stream_dataset_example = next(iter(self.dataset))
49
+ self._stream_dataset_column_names = list(
50
+ self._stream_dataset_example.keys())
51
+ self.streaming_dataset = True
52
+ elif isinstance(data, pyarrow.Table):
53
+ self.dataset = datasets.Dataset(data)
54
+ else:
55
+ raise ValueError(
56
+ f'Type of data is {type(data)} when pd.DataFrame, pyarrow.Table, or generator of pyarrow.RecordBatch is allowed')
57
+
58
+ # PREPROCESS DATASET
59
+ self._preprocess()
60
+
61
+ # ENCODE DATASET
62
+ self.train_dataset = None
63
+ self.test_dataset = None
64
+ if self.streaming_dataset:
65
+ # IterableDataset doesn't have train_test_split method
66
+ if self.label_col:
67
+ self.train_dataset = self._encode_streaming(self.dataset)
68
+ print('Streaming dataset available in .train_dataset')
69
+ else:
70
+ self.test_dataset = self._encode_streaming(self.dataset)
71
+ print(
72
+ 'Streaming dataset available in .test_dataset because label_col=None')
73
+ else:
74
+ # dataset into train_dataset and/or test_dataset
75
+ if do_train_test_split:
76
+ ds = self.dataset.train_test_split(
77
+ test_size=test_size, shuffle=True, seed=self.seed, stratify_by_column=self.label_col)
78
+ self.train_dataset = ds['train']
79
+ self.test_dataset = ds['test']
80
+ print(
81
+ f'Dataset was splitted into train and test with test_size={test_size}')
82
+ else:
83
+ if self.label_col:
84
+ self.train_dataset = self.dataset
85
+ else:
86
+ self.test_dataset = self.dataset
87
+
88
+ if encode_on_the_fly:
89
+ if self.train_dataset:
90
+ self.train_dataset.set_transform(self._encode_on_the_fly)
91
+ print('On-the-fly encoded dataset available in .train_dataset')
92
+ if self.test_dataset:
93
+ self.test_dataset.set_transform(self._encode_on_the_fly)
94
+ print('On-the-fly encoded dataset available in .test_dataset')
95
+ else:
96
+ if self.train_dataset:
97
+ self.train_dataset = self._encode(self.train_dataset)
98
+ print('Pre-encoded dataset available in .train_dataset')
99
+ if self.test_dataset:
100
+ self.test_dataset = self._encode(self.test_dataset)
101
+ print('Pre-encoded dataset available in .test_dataset')
102
+
103
+ def __len__(self):
104
+ if self.streaming_dataset:
105
+ raise ValueError(
106
+ f'Streaming dataset does not support len() method')
107
+ return len(self.dataset)
108
+
109
+ def __getitem__(self, index):
110
+ if self.streaming_dataset:
111
+ return next(iter(self.dataset))
112
+ return self.dataset[index]
113
+
114
+ def _streaming_generate_examples(self, iterable):
115
+ id_ = 0
116
+ # TODO: make sure GeneratorType is pyarrow.RecordBatch
117
+ if isinstance(iterable, types.GeneratorType):
118
+ for examples in iterable:
119
+ for ex in examples.to_pylist():
120
+ yield id_, ex
121
+ id_ += 1
122
+
123
+ def _preprocess(self):
124
+ if self._with_transcript:
125
+ self.dataset = self.dataset.filter(
126
+ lambda example: example['regret_transcript'] is not None and example['recommendation_transcript'] is not None)
127
+ else:
128
+ self.dataset = self.dataset.filter(
129
+ lambda example: example['regret_transcript'] is None or example['recommendation_transcript'] is None)
130
+ if self.label_col:
131
+ if self.streaming_dataset:
132
+ if self.label_col in self._stream_dataset_column_names and isinstance(self._stream_dataset_example[self.label_col], str):
133
+ if not self.label_map:
134
+ raise ValueError(
135
+ f'"label_map" dict was not provided and is needed to encode string labels for streaming datasets')
136
+ # cast_column method had issues with streaming dataset
137
+ self.dataset = self.dataset.map(
138
+ self._streaming_rename_labels)
139
+ else:
140
+ if self.dataset.features[self.label_col].dtype == 'string':
141
+ if not self.label_map:
142
+ self.label_map = {k: v for v, k in enumerate(
143
+ self.dataset.unique(self.label_col))}
144
+ self.dataset = self.dataset.filter(
145
+ lambda example: example[self.label_col] in self.label_map.keys())
146
+ self.dataset = self.dataset.cast_column(self.label_col, datasets.ClassLabel(
147
+ num_classes=len(self.label_map), names=list(self.label_map.keys())))
148
+
149
+ self.dataset = self.dataset.filter(lambda example: not any(x in [None, ""] for x in [
150
+ example[key] for key in self._text_features + self.scalar_features + ([self.label_col] if self.label_col else [])])) # dropna
151
+
152
+ if self.balance_label_counts and self.label_col and not self.streaming_dataset:
153
+ label_datasets = {}
154
+ for label in list(self.label_map.values()):
155
+ label_dataset = self.dataset.filter(
156
+ lambda example: example[self.label_col] == label)
157
+ label_datasets[len(label_dataset)] = label_dataset
158
+ min_label_count = min(label_datasets)
159
+ sampled_datasets = [dataset.train_test_split(train_size=min_label_count, shuffle=True, seed=self.seed)[
160
+ 'train'] if len(dataset) != min_label_count else dataset for dataset in label_datasets.values()]
161
+ self.dataset = datasets.concatenate_datasets(sampled_datasets)
162
+
163
+ if self.clean_text:
164
+ self.dataset = self.dataset.map(self._clean_text, batched=not self.streaming_dataset,
165
+ batch_size=self.processing_batch_size)
166
+ self.dataset = self.dataset.map(self._truncate_and_strip_text, batched=not self.streaming_dataset,
167
+ batch_size=self.processing_batch_size)
168
+
169
+ def _streaming_rename_labels(self, example):
170
+ # rename labels according to label_map if not already correct labels
171
+ if isinstance(example[self.label_col], list):
172
+ example[self.label_col] = [self.label_map.get(
173
+ ex, None) for ex in example[self.label_col] if ex not in self.label_map.values()]
174
+ elif isinstance(example[self.label_col], str) and example[self.label_col] not in self.label_map.values():
175
+ example[self.label_col] = self.label_map.get(
176
+ example[self.label_col], None)
177
+ else:
178
+ raise ValueError(
179
+ f'Type of example label is {type(example[self.label_col])} when list or string is allowed')
180
+ return example
181
+
182
+ def _clean_text(self, example):
183
+ for feat in self._text_features:
184
+ example[feat] = clean_text_funcs(example[feat])[0] if isinstance(
185
+ example[feat], str) else clean_text_funcs(example[feat])
186
+ return example
187
+
188
+ def _truncate_and_strip_text(self, example):
189
+ # tokenizer will truncate to max_length tokens anyway so to save RAM let's truncate to max_length words already beforehand
190
+ # one word is usually one or more tokens so should be safe to truncate this way without losing information
191
+ for feat in self._text_features:
192
+ if isinstance(example[feat], list):
193
+ example[feat] = [
194
+ ' '.join(text.split()[:self.max_length]).strip() for text in example[feat] if text]
195
+ elif isinstance(example[feat], str):
196
+ example[feat] = ' '.join(example[feat].split()[
197
+ :self.max_length]).strip()
198
+ elif example[feat] is None:
199
+ return None
200
+ else:
201
+ raise ValueError(
202
+ f'Type of example is {type(example[feat])} when list or string is allowed')
203
+ return example
204
+
205
+ def _encode(self, dataset):
206
+ encoded_dataset = None
207
+ for text_type in self.text_types:
208
+ encoded_text_type = dataset.map(lambda regret, recommendation: self.tokenizer(regret, recommendation, padding="max_length", truncation=True, max_length=self.max_length), batched=True,
209
+ batch_size=self.processing_batch_size, num_proc=self.processing_num_proc, input_columns=[f'regret_{text_type}', f'recommendation_{text_type}'], remove_columns=dataset.column_names)
210
+ encoded_text_type = encoded_text_type.rename_columns(
211
+ {col: f'{text_type}_{col}' for col in encoded_text_type.column_names}) # e.g. input_ids -> title_input_ids so we have separate input_ids for each text_type
212
+ if encoded_dataset:
213
+ encoded_dataset = datasets.concatenate_datasets(
214
+ [encoded_dataset, encoded_text_type], axis=1)
215
+ else:
216
+ encoded_dataset = encoded_text_type
217
+
218
+ # copy scalar features and label from original dataset to the encoded dataset
219
+ for scalar_feat in self.scalar_features:
220
+ encoded_dataset = encoded_dataset.add_column(
221
+ name=scalar_feat, column=dataset[scalar_feat])
222
+ if self.label_col:
223
+ encoded_dataset = encoded_dataset.add_column(
224
+ name=self.label_col, column=dataset[self.label_col])
225
+ if self.keep_video_ids_for_predictions:
226
+ for id in ['regret_id', "recommendation_id"]:
227
+ encoded_dataset = encoded_dataset.add_column(
228
+ name=id, column=dataset[id])
229
+
230
+ encoded_dataset.set_format(
231
+ type='torch', columns=encoded_dataset.column_names)
232
+ return encoded_dataset
233
+
234
+ def _encode_streaming(self, dataset):
235
+ encoded_dataset = dataset.map(self._encode_on_the_fly, batched=True,
236
+ batch_size=self.processing_batch_size, remove_columns=list(set(self._stream_dataset_column_names)-set(self.scalar_features + (
237
+ [self.label_col] if self.label_col else []) + (['regret_id', "recommendation_id"] if self.keep_video_ids_for_predictions else [])))) # IterableDataset doesn't have column_names attribute as normal Dataset
238
+ encoded_dataset = encoded_dataset.with_format("torch")
239
+ return encoded_dataset
240
+
241
+ def _encode_on_the_fly(self, batch):
242
+ for text_type in self.text_types:
243
+ encoded_text_type = dict(self.tokenizer(
244
+ batch[f'regret_{text_type}'], batch[f'recommendation_{text_type}'], padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt"))
245
+ for encoded_key in encoded_text_type.copy():
246
+ encoded_text_type[f"{text_type}_{encoded_key}"] = encoded_text_type.pop(encoded_key) if not self.streaming_dataset else encoded_text_type.pop(
247
+ encoded_key).squeeze(0) # e.g. input_ids -> title_input_ids so we have separate input_ids for each text_type
248
+ del batch[f'regret_{text_type}']
249
+ del batch[f'recommendation_{text_type}']
250
+ batch.update(encoded_text_type)
251
+ for scalar_feat in self.scalar_features:
252
+ batch[scalar_feat] = torch.as_tensor(
253
+ batch[scalar_feat]) if not self.streaming_dataset else torch.as_tensor(batch[scalar_feat]).squeeze(0)
254
+ if self.label_col:
255
+ batch[self.label_col] = torch.as_tensor(
256
+ batch[self.label_col]) if not self.streaming_dataset else torch.as_tensor(batch[self.label_col]).squeeze(0)
257
+ return batch
258
+
259
+
260
+ class RRUM(pl.LightningModule):
261
+ def __init__(self, text_types, scalar_features, label_col, cross_encoder_model_name_or_path, optimizer_config=None, freeze_policy=None, pos_weight=None):
262
+ super().__init__()
263
+ self.save_hyperparameters()
264
+ self.text_types = text_types
265
+ self.scalar_features = scalar_features
266
+ self.label_col = label_col
267
+ self.optimizer_config = optimizer_config
268
+ self.cross_encoder_model_name_or_path = cross_encoder_model_name_or_path
269
+ self.cross_encoders = nn.ModuleDict({})
270
+ for t in self.text_types:
271
+ self.cross_encoders[t] = AutoModelForSequenceClassification.from_pretrained(
272
+ self.cross_encoder_model_name_or_path)
273
+ if freeze_policy is not None:
274
+ for xe in self.cross_encoders.values():
275
+ for name, param in xe.named_parameters():
276
+ if freeze_policy(name):
277
+ param.requires_grad = False
278
+ cross_encoder_out_features = list(self.cross_encoders.values())[0](
279
+ torch.randint(1, 2, (1, 2))).logits.size(dim=1)
280
+ self.lin1 = nn.Linear(len(self.cross_encoders) * cross_encoder_out_features +
281
+ len(self.scalar_features), 1)
282
+ self.ac_metric = torchmetrics.Accuracy()
283
+ self.pr_metric = torchmetrics.Precision()
284
+ self.re_metric = torchmetrics.Recall()
285
+ self.auc_metric = torchmetrics.AUROC()
286
+
287
+ if pos_weight:
288
+ self.loss = nn.BCEWithLogitsLoss(
289
+ pos_weight=torch.Tensor([pos_weight]))
290
+ else:
291
+ self.loss = nn.BCEWithLogitsLoss()
292
+
293
+ def forward(self, x):
294
+ cross_logits = {}
295
+ for f in self.text_types:
296
+ inputs = {key.split(f'{f}_')[1]: x[key]
297
+ for key in x if f in key} # e.g. title_input_ids -> input_ids since we have separate input_ids for each text_type
298
+ cross_logits[f] = self.cross_encoders[f](**inputs).logits
299
+ x = torch.cat([*cross_logits.values()] +
300
+ [x[scalar][:, None] for scalar in self.scalar_features],
301
+ 1
302
+ )
303
+ del cross_logits
304
+
305
+ x = self.lin1(x)
306
+ return x
307
+
308
+ def configure_optimizers(self):
309
+ if self.optimizer_config:
310
+ return self.optimizer_config(self)
311
+
312
+ optimizer = torch.optim.AdamW(self.parameters(), lr=5e-5)
313
+ scheduler = get_linear_schedule_with_warmup(
314
+ optimizer,
315
+ num_warmup_steps=int(
316
+ self.trainer.estimated_stepping_batches * 0.05),
317
+ num_training_steps=self.trainer.estimated_stepping_batches,
318
+ )
319
+ scheduler = {'scheduler': scheduler,
320
+ 'interval': 'step', 'frequency': 1}
321
+ return [optimizer], [scheduler]
322
+
323
+ def training_step(self, train_batch, batch_idx):
324
+ y = train_batch[self.label_col].unsqueeze(1).float()
325
+ logits = self(train_batch)
326
+ loss = self.loss(logits, y)
327
+ self.log('train_loss', loss)
328
+ return loss
329
+
330
+ def validation_step(self, val_batch, batch_idx):
331
+ y = val_batch[self.label_col].unsqueeze(1).float()
332
+ logits = self(val_batch)
333
+ loss = self.loss(logits, y)
334
+ self.ac_metric(logits, y.int())
335
+ self.pr_metric(logits, y.int())
336
+ self.re_metric(logits, y.int())
337
+ self.auc_metric(logits, y.int())
338
+ self.log('validation_accuracy', self.ac_metric)
339
+ self.log('validation_precision', self.pr_metric)
340
+ self.log('validation_recall', self.re_metric)
341
+ self.log('validation_auc', self.auc_metric)
342
+ self.log('val_loss', loss, prog_bar=True)
343
+
344
+ def validation_epoch_end(self, outputs):
345
+ self.log('validation_accuracy_ep', self.ac_metric)
346
+ self.log('validation_precision_ep', self.pr_metric)
347
+ self.log('validation_recall_ep', self.re_metric)
348
+ self.log('validation_auc_ep', self.auc_metric)
utils/text_cleaning.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastcore.basics import listify
2
+ from fastcore.utils import compose
3
+ import unicodedata
4
+ from string import punctuation
5
+ import html
6
+ from itertools import groupby
7
+ import re
8
+
9
+ control_char_regex = re.compile(r'[\r\n\t]+')
10
+ url_regex = re.compile(
11
+ r'((http|https)\:\/\/)?[a-zA-Z0-9\.\/\?\:@\-_=#]+\.([a-zA-Z]){2,6}([a-zA-Z0-9\.\&\/\?\:@\-_=#])*')
12
+ username_regex = re.compile(r'(^|[^@\w])@(\w{1,15})\b')
13
+
14
+
15
+ def fix_html(text):
16
+ tmp_ls = []
17
+ for e in listify(text):
18
+ e = e.replace('#39;', "'").replace('amp;', '&').replace('#146;', "'").replace('nbsp;', ' ').replace(
19
+ '#36;', '$').replace('\\n', "\n").replace('quot;', "'").replace('<br />', "\n").replace(
20
+ '\\"', '"').replace('<unk>', ' ').replace(' @.@ ', '.').replace(' @-@ ', '-').replace('...', ' …')
21
+ tmp_ls.append(html.unescape(e))
22
+
23
+ text = tmp_ls
24
+ return text
25
+
26
+
27
+ def remove_control_char(text):
28
+ tmp_ls = []
29
+ for e in listify(text):
30
+ tmp_ls.append(re.sub(control_char_regex, '.', e))
31
+
32
+ text = tmp_ls
33
+ return text
34
+
35
+
36
+ def remove_remaining_control_chars(text):
37
+ tmp_ls = []
38
+ for e in listify(text):
39
+ tmp_ls.append(
40
+ ''.join(ch for ch in e if unicodedata.category(ch)[0] != 'C'))
41
+
42
+ text = tmp_ls
43
+ return text
44
+
45
+
46
+ def remove_unicode_symbols(text):
47
+ tmp_ls = []
48
+ for e in listify(text):
49
+ tmp_ls.append(
50
+ ''.join(ch for ch in e if unicodedata.category(ch)[0] != 'So'))
51
+
52
+ text = tmp_ls
53
+ return text
54
+
55
+
56
+ def standardise_punc(text):
57
+ transl_table = dict([(ord(x), ord(y))
58
+ for x, y in zip(u"‘’´“”–-", u"'''\"\"--")])
59
+ tmp_ls = []
60
+ for e in listify(text):
61
+ e = e.translate(transl_table)
62
+ tmp_ls.append(e)
63
+
64
+ text = tmp_ls
65
+ return text
66
+
67
+
68
+ def remove_news_tags(text):
69
+ tmp_ls = []
70
+ for e in listify(text):
71
+ e = re.sub(r"(<[A-Z].+?>)|(</[A-Z].+?>)", "", e)
72
+ tmp_ls.append(e)
73
+
74
+ text = tmp_ls
75
+ return text
76
+
77
+
78
+ def replace_urls(text):
79
+ filler, tmp_ls = '', []
80
+ for e in listify(text):
81
+ e = re.sub(r"(<a.+?>)|(</a>)|(<ref.+?>)", "", e)
82
+ e = re.sub(url_regex, filler, e)
83
+ tmp_ls.append(e)
84
+
85
+ text = tmp_ls
86
+ return text
87
+
88
+
89
+ def replace_usernames(text):
90
+ filler, tmp_ls = '', []
91
+ for e in listify(text):
92
+ occ = e.count('@')
93
+ for _ in range(occ):
94
+ e = e.replace('@<user>', f'{filler}')
95
+ # replace other user handles by filler
96
+ e = re.sub(username_regex, filler, e)
97
+ tmp_ls.append(e)
98
+
99
+ text = tmp_ls
100
+ return text
101
+
102
+
103
+ def remove_duplicate_punctuation(text):
104
+ tmp_ls = []
105
+ for e in listify(text):
106
+ e = re.sub(r'\b(\w+)( \1\b)+', r'\1', e)
107
+ punc = set(punctuation)
108
+ newtext = []
109
+ for k, g in groupby(e):
110
+ if k in punc:
111
+ newtext.append(k)
112
+ else:
113
+ newtext.extend(g)
114
+ e = ''.join(newtext)
115
+ tmp_ls.append(e)
116
+
117
+ text = tmp_ls
118
+ return text
119
+
120
+
121
+ def remove_multi_space(text):
122
+ tmp_ls = []
123
+ for e in listify(text):
124
+ tmp_ls.append(' '.join(e.split()))
125
+
126
+ text = tmp_ls
127
+ return text
128
+
129
+
130
+ clean_text_funcs = compose(*[fix_html, remove_control_char, remove_remaining_control_chars, remove_unicode_symbols,
131
+ standardise_punc, remove_news_tags, replace_urls, replace_usernames, remove_duplicate_punctuation, remove_multi_space])