altndrr committed
Commit 7ff77f3
1 Parent(s): 4f4ad71
config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "alpha": 0.5,
+   "architectures": [
+     "CaSEDModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_cased.CaSEDConfig",
+     "AutoModel": "modeling_cased.CaSEDModel"
+   },
+   "index_name": "cc12m",
+   "model_type": "cased",
+   "retrieval_num_results": 10,
+   "torch_dtype": "float32",
+   "transformers_version": "4.29.2"
+ }
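The auto_map above is what lets the transformers auto classes resolve the custom code added in this commit. A minimal usage sketch follows; the repository id and image path are placeholders (not part of this commit), loading requires trust_remote_code=True, and the first call downloads the CLIP weights and the cc12m retrieval index.

from PIL import Image
from transformers import AutoModel

# Placeholder repo id/path; point this at the repository containing these files.
model = AutoModel.from_pretrained("path/to/cased-repo", trust_remote_code=True)

# CaSEDModel bundles its own CLIPProcessor, so it can preprocess PIL images directly.
image = Image.open("example.jpg")  # placeholder image
inputs = model.processor(images=[image], return_tensors="pt")

outputs = model(inputs)
print(outputs["vocabularies"][0])  # candidate category names mined from the index
print(outputs["scores"][0])        # alpha-blended image/text scores per candidate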
configuration_cased.py ADDED
@@ -0,0 +1,26 @@
+ from transformers.modeling_utils import PretrainedConfig
+
+
+ class CaSEDConfig(PretrainedConfig):
+     """Configuration class for CaSED.
+
+     Args:
+         index_name (str, optional): Name of the index. Defaults to "cc12m".
+         alpha (float, optional): Weight of the vision scores. Defaults to 0.5.
+         retrieval_num_results (int, optional): Number of results to return. Defaults to 10.
+     """
+
+     model_type = "cased"
+     is_composition = True
+
+     def __init__(
+         self,
+         index_name: str = "cc12m",
+         alpha: float = 0.5,
+         retrieval_num_results: int = 10,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.index_name = index_name
+         self.alpha = alpha
+         self.retrieval_num_results = retrieval_num_results
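These defaults mirror config.json and can be overridden at construction time. A small sketch, assuming the file is importable from the working directory; the override values are illustrative only.

from configuration_cased import CaSEDConfig

# Illustrative overrides of the documented defaults.
config = CaSEDConfig(alpha=0.7, retrieval_num_results=25)
print(config.index_name, config.alpha, config.retrieval_num_results)  # cc12m 0.7 25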
modeling_cased.py ADDED
@@ -0,0 +1,252 @@
+ import os
+ import tarfile
+ from pathlib import Path
+ from typing import Optional
+
+ import faiss
+ import numpy as np
+ import pyarrow as pa
+ import requests
+ import torch
+ from tqdm import tqdm
+ from transformers import CLIPModel, CLIPProcessor
+ from transformers.modeling_utils import PreTrainedModel
+
+ from .configuration_cased import CaSEDConfig
+ from .transforms_cased import default_vocabulary_transforms
+
+ DATABASES = {
+     "cc12m": {
+         "url": "https://storage-cased.alessandroconti.me/cc12m.tar.gz",
+         "cache_subdir": "./cc12m/vit-l-14/",
+     },
+ }
+
+
+ class MetadataProvider:
+     """Metadata provider.
+
+     It uses arrow files to store metadata and retrieve it efficiently.
+
+     Code reference:
+     - https://github.dev/rom1504/clip-retrieval
+     """
+
+     def __init__(self, arrow_folder: Path):
+         arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
+         self.table = pa.concat_tables(
+             [
+                 pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
+                 for arrow_file in arrow_files
+             ]
+         )
+
+     def get(self, ids: np.ndarray, cols: Optional[list] = None):
+         """Get arrow metadata from ids.
+
+         Args:
+             ids (np.ndarray): Ids to retrieve.
+             cols (Optional[list], optional): Columns to retrieve. Defaults to None.
+         """
+         if cols is None:
+             cols = self.table.schema.names
+         else:
+             cols = list(set(self.table.schema.names) & set(cols))
+         t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)])
+         return t.select(cols).to_pandas().to_dict("records")
+
+
+ class CaSEDModel(PreTrainedModel):
+     """Transformers module for Category Search from External Databases (CaSED).
+
+     Reference:
+     - Conti et al. Vocabulary-free Image Classification. arXiv 2023.
+
+     Args:
+         config (CaSEDConfig): Configuration class for CaSED.
+     """
+
+     config_class = CaSEDConfig
+
+     def __init__(self, config: CaSEDConfig):
+         super().__init__(config)
+
+         # load CLIP
+         model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+         self.vision_encoder = model.vision_model
+         self.vision_proj = model.visual_projection
+         self.language_encoder = model.text_model
+         self.language_proj = model.text_projection
+         self.logit_scale = model.logit_scale.exp()
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+         # load transforms
+         self.vocabulary_transforms = default_vocabulary_transforms()
+
+         # set hparams
+         self.hparams = {}
+         self.hparams["alpha"] = config.alpha
+         self.hparams["index_name"] = config.index_name
+         self.hparams["retrieval_num_results"] = config.retrieval_num_results
+
+         # set cache dir
+         self.hparams["cache_dir"] = Path(os.path.expanduser("~/.cache/cased"))
+         os.makedirs(self.hparams["cache_dir"], exist_ok=True)
+
+         # download databases
+         self.prepare_data()
+
+         # load faiss indices and metadata providers
+         self.resources = {}
+         for name, items in DATABASES.items():
+             database_path = self.hparams["cache_dir"] / "databases" / items["cache_subdir"]
+             text_index_fp = database_path / "text.index"
+             metadata_fp = database_path / "metadata/"
+
+             text_index = faiss.read_index(
+                 str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
+             )
+             metadata_provider = MetadataProvider(metadata_fp)
+
+             self.resources[name] = {
+                 "device": self.device,
+                 "model": "ViT-L-14",
+                 "text_index": text_index,
+                 "metadata_provider": metadata_provider,
+             }
+
+     def prepare_data(self):
+         """Download data if needed."""
+         databases_path = Path(self.hparams["cache_dir"]) / "databases"
+
+         for name, items in DATABASES.items():
+             url = items["url"]
+             database_path = Path(databases_path, name)
+             if database_path.exists():
+                 continue
+
+             # download data
+             target_path = Path(databases_path, name + ".tar.gz")
+             os.makedirs(target_path.parent, exist_ok=True)
+             with requests.get(url, stream=True) as r:
+                 r.raise_for_status()
+                 total_bytes_size = int(r.headers.get('content-length', 0))
+                 chunk_size = 8192
+                 p_bar = tqdm(
+                     desc="Downloading cc12m index",
+                     total=total_bytes_size,
+                     unit='iB',
+                     unit_scale=True,
+                 )
+                 with open(target_path, 'wb') as f:
+                     for chunk in r.iter_content(chunk_size=chunk_size):
+                         f.write(chunk)
+                         p_bar.update(len(chunk))
+                 p_bar.close()
+
+             # extract data
+             tar = tarfile.open(target_path, "r:gz")
+             tar.extractall(target_path.parent)
+             tar.close()
+             target_path.unlink()
+
+     @torch.no_grad()
+     def query_index(self, sample_z: torch.Tensor) -> list:
+         """Query the external database index.
+
+         Args:
+             sample_z (torch.Tensor): Sample to query the index.
+         """
+         # get the index
+         resources = self.resources[self.hparams["index_name"]]
+         text_index = resources["text_index"]
+         metadata_provider = resources["metadata_provider"]
+
+         # query the index
+         sample_z = sample_z.squeeze(0)
+         sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
+         query_input = sample_z.cpu().detach().numpy().tolist()
+         query = np.expand_dims(np.array(query_input).astype("float32"), 0)
+
+         distances, idxs, _ = text_index.search_and_reconstruct(
+             query, self.hparams["retrieval_num_results"]
+         )
+         results = idxs[0]
+         nb_results = np.where(results == -1)[0]
+         nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
+         indices = results[:nb_results]
+         distances = distances[0][:nb_results]
+
+         if len(distances) == 0:
+             return []
+
+         # get the metadata
+         results = []
+         metadata = metadata_provider.get(indices[:20], ["caption"])
+         for key, (d, i) in enumerate(zip(distances, indices)):
+             output = {}
+             meta = None if key + 1 > len(metadata) else metadata[key]
+             if meta is not None:
+                 output.update(meta)
+             output["id"] = i.item()
+             output["similarity"] = d.item()
+             results.append(output)
+
+         # get the captions only
+         vocabularies = [result["caption"] for result in results]
+
+         return vocabularies
+
+     @torch.no_grad()
+     def forward(self, images: dict, alpha: Optional[float] = None) -> dict:
+         """Forward pass.
+
+         Args:
+             images (dict): Dictionary with the images. The expected keys are:
+                 - pixel_values (torch.Tensor): Pixel values of the images.
+             alpha (Optional[float]): Alpha value for the interpolation.
+         """
+         # forward the images
+         images["pixel_values"] = images["pixel_values"].to(self.device)
+         images_z = self.vision_proj(self.vision_encoder(**images)[1])
+
+         vocabularies, samples_p = [], []
+         for image_z in images_z:
+             # generate a single text embedding from the unfiltered vocabulary
+             vocabulary = self.query_index(image_z)
+             text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
+             text["input_ids"] = text["input_ids"][:, :77].to(self.device)
+             text["attention_mask"] = text["attention_mask"][:, :77].to(self.device)
+             text_z = self.language_encoder(**text)[1]
+             text_z = self.language_proj(text_z)
+
+             # filter the vocabulary, embed it, and get its mean embedding
+             vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
+             text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
+             text = {k: v.to(self.device) for k, v in text.items()}
+             vocabulary_z = self.language_encoder(**text)[1]
+             vocabulary_z = self.language_proj(vocabulary_z)
+             vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)
+
+             # get the image and text predictions
+             image_z = image_z / image_z.norm(dim=-1, keepdim=True)
+             text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+             image_p = (torch.matmul(image_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
+             text_p = (torch.matmul(text_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
+
+             # average the image and text predictions
+             alpha = alpha or self.hparams["alpha"]
+             sample_p = alpha * image_p + (1 - alpha) * text_p
+
+             # save the results
+             samples_p.append(sample_p)
+             vocabularies.append(vocabulary)
+
+         # get the scores
+         samples_p = torch.stack(samples_p, dim=0)
+         scores = samples_p.cpu().tolist()
+
+         # define the results
+         results = {"vocabularies": vocabularies, "scores": scores}
+
+         return results
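For clarity on the retrieval step implemented above, the sketch below queries the cc12m index directly with an image embedding and prints the raw captions before the vocabulary transforms are applied. The repository id and image path are placeholders; the faiss index is downloaded to ~/.cache/cased the first time the model is built.

import torch
from PIL import Image
from transformers import AutoModel

# Placeholder repo id/path for the repository containing these files.
model = AutoModel.from_pretrained("path/to/cased-repo", trust_remote_code=True)

image = Image.open("example.jpg")  # placeholder image
inputs = model.processor(images=[image], return_tensors="pt")

with torch.no_grad():
    # same projection used in forward(): CLIP vision pooled output -> joint space
    image_z = model.vision_proj(model.vision_encoder(pixel_values=inputs["pixel_values"])[1])

captions = model.query_index(image_z[0])  # nearest cc12m captions for this image
print(captions[:5])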
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:91c5a2012ab49580ef33645ef578ab2eab491ace7ed63e856f9ef340f73e0e9e
+ size 1710665929
transforms_cased.py ADDED
@@ -0,0 +1,438 @@
+ import re
+ from abc import ABC, abstractmethod
+ from typing import Any, Union
+
+ import inflect
+ import nltk
+ from flair.data import Sentence
+ from flair.models import SequenceTagger
+
+ __all__ = [
+     "DropFileExtensions",
+     "DropNonAlpha",
+     "DropShortWords",
+     "DropSpecialCharacters",
+     "DropTokens",
+     "DropURLs",
+     "DropWords",
+     "FilterPOS",
+     "FrequencyMinWordCount",
+     "FrequencyTopK",
+     "ReplaceSeparators",
+     "ToLowercase",
+     "ToSingular",
+ ]
+
+
+ class BaseTextTransform(ABC):
+     """Base class for string transforms."""
+
+     @abstractmethod
+     def __call__(self, text: str):
+         raise NotImplementedError
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}()"
+
+
+ class DropFileExtensions(BaseTextTransform):
+     """Remove file extensions from the input text."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove file extensions from.
+         """
+         text = re.sub(r"\.\w+", "", text)
+
+         return text
+
+
+ class DropNonAlpha(BaseTextTransform):
+     """Remove non-alphabetic characters from the input text."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove non-alphabetic characters from.
+         """
+         text = re.sub(r"[^a-zA-Z\s]", "", text)
+
+         return text
+
+
+ class DropShortWords(BaseTextTransform):
+     """Remove short words from the input text.
+
+     Args:
+         min_length (int): Minimum length of words to keep.
+     """
+
+     def __init__(self, min_length) -> None:
+         super().__init__()
+         self.min_length = min_length
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove short words from.
+         """
+         text = " ".join([word for word in text.split() if len(word) >= self.min_length])
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(min_length={self.min_length})"
+
+
+ class DropSpecialCharacters(BaseTextTransform):
+     """Remove special characters from the input text.
+
+     Special characters are defined as any character that is not a word character, whitespace,
+     hyphen, period, apostrophe, or ampersand.
+     """
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove special characters from.
+         """
+         text = re.sub(r"[^\w\s\-\.\'\&]", "", text)
+
+         return text
+
+
+ class DropTokens(BaseTextTransform):
+     """Remove tokens from the input text.
+
+     Tokens are defined as strings enclosed in angle brackets, e.g. <token>.
+     """
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove tokens from.
+         """
+         text = re.sub(r"<[^>]+>", "", text)
+
+         return text
+
+
+ class DropURLs(BaseTextTransform):
+     """Remove URLs from the input text."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove URLs from.
+         """
+         text = re.sub(r"http\S+", "", text)
+
+         return text
+
+
+ class DropWords(BaseTextTransform):
+     """Remove words from the input text.
+
+     It is case-insensitive and supports singular and plural forms of the words.
+     """
+
+     def __init__(self, words: list[str]) -> None:
+         super().__init__()
+         self.words = words
+         self.pattern = r"\b(?:{})\b".format("|".join(words))
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove words from.
+         """
+         text = re.sub(self.pattern, "", text, flags=re.IGNORECASE)
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(pattern={self.pattern})"
+
+
+ class FilterPOS(BaseTextTransform):
+     """Filter words by POS tags.
+
+     Args:
+         tags (list): List of POS tags to keep.
+         engine (str): POS tagger to use. Must be one of "nltk" or "flair". Defaults to "nltk".
+         keep_compound_nouns (bool): Whether to keep composed words. Defaults to True.
+     """
+
+     def __init__(self, tags: list, engine: str = "nltk", keep_compound_nouns: bool = True) -> None:
+         super().__init__()
+         self.tags = tags
+         self.engine = engine
+         self.keep_compound_nouns = keep_compound_nouns
+
+         if engine == "nltk":
+             nltk.download("averaged_perceptron_tagger", quiet=True)
+             nltk.download("punkt", quiet=True)
+             self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
+         elif engine == "flair":
+             self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to keep only words with the given POS tags from.
+         """
+         if self.engine == "nltk":
+             word_tags = self.tagger(text)
+             text = " ".join([word for word, tag in word_tags if tag in self.tags])
+         elif self.engine == "flair":
+             sentence = Sentence(text)
+             self.tagger(sentence)
+             text = " ".join([token.text for token in sentence.tokens if token.tag in self.tags])
+
+         if self.keep_compound_nouns:
+             compound_nouns = []
+
+             if self.engine == "nltk":
+                 for i in range(len(word_tags) - 1):
+                     if word_tags[i][1] == "NN" and word_tags[i + 1][1] == "NN":
+                         # if they are the same word, skip
+                         if word_tags[i][0] == word_tags[i + 1][0]:
+                             continue
+
+                         compound_noun = word_tags[i][0] + "_" + word_tags[i + 1][0]
+                         compound_nouns.append(compound_noun)
+             elif self.engine == "flair":
+                 for i in range(len(sentence.tokens) - 1):
+                     if sentence.tokens[i].tag == "NN" and sentence.tokens[i + 1].tag == "NN":
+                         # if they are the same word, skip
+                         if sentence.tokens[i].text == sentence.tokens[i + 1].text:
+                             continue
+
+                         compound_noun = sentence.tokens[i].text + "_" + sentence.tokens[i + 1].text
+                         compound_nouns.append(compound_noun)
+
+             text = " ".join([text, " ".join(compound_nouns)])
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(tags={self.tags}, engine={self.engine})"
+
+
+ class FrequencyMinWordCount(BaseTextTransform):
+     """Keep only words that occur more than a minimum number of times in the input text.
+
+     If the threshold is too strong and no words pass the threshold, the threshold is reduced to
+     the most frequent word.
+
+     Args:
+         min_count (int): Minimum number of occurrences of a word to keep.
+     """
+
+     def __init__(self, min_count) -> None:
+         super().__init__()
+         self.min_count = min_count
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove infrequent words from.
+         """
+         if self.min_count <= 1:
+             return text
+
+         words = text.split()
+         word_counts = {word: words.count(word) for word in words}
+
+         # if nothing passes the threshold, reduce the threshold to the most frequent word
+         max_word_count = max(word_counts.values() or [0])
+         min_count = max_word_count if self.min_count > max_word_count else self.min_count
+
+         text = " ".join([word for word in words if word_counts[word] >= min_count])
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(min_count={self.min_count})"
+
+
+ class FrequencyTopK(BaseTextTransform):
+     """Keep only the top k most frequent words in the input text.
+
+     In case of a tie, all words with the same count as the last word are kept.
+
+     Args:
+         top_k (int): Number of top words to keep.
+     """
+
+     def __init__(self, top_k: int) -> None:
+         super().__init__()
+         self.top_k = top_k
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove infrequent words from.
+         """
+         if self.top_k < 1:
+             return text
+
+         words = text.split()
+         word_counts = {word: words.count(word) for word in words}
+         top_words = sorted(word_counts, key=word_counts.get, reverse=True)
+
+         # in case of a tie, keep all words with the same count as the k-th word
+         top_words = top_words[: self.top_k]
+         top_words = [word for word in word_counts if word_counts[word] >= word_counts[top_words[-1]]]
+
+         text = " ".join([word for word in words if word in top_words])
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}(top_k={self.top_k})"
+
+
+ class ReplaceSeparators(BaseTextTransform):
+     """Replace underscores and dashes with spaces."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to replace separators in.
+         """
+         text = re.sub(r"[_\-]", " ", text)
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}()"
+
+
+ class RemoveDuplicates(BaseTextTransform):
+     """Remove duplicate words from the input text."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to remove duplicate words from.
+         """
+         text = " ".join(list(set(text.split())))
+
+         return text
+
+
+ class TextCompose:
+     """Compose several transforms together.
+
+     It differs from the torchvision.transforms.Compose class in that it applies the transforms to
+     a string instead of a PIL Image or Tensor. In addition, it automatically joins the list of
+     input strings into a single string and splits the output string into a list of words.
+
+     Args:
+         transforms (list): List of transforms to compose.
+     """
+
+     def __init__(self, transforms: list[BaseTextTransform]) -> None:
+         self.transforms = transforms
+
+     def __call__(self, text: Union[str, list[str]]) -> Any:
+         if isinstance(text, list):
+             text = " ".join(text)
+
+         for t in self.transforms:
+             text = t(text)
+         return text.split()
+
+     def __repr__(self) -> str:
+         format_string = self.__class__.__name__ + "("
+         for t in self.transforms:
+             format_string += "\n"
+             format_string += f"    {t}"
+         format_string += "\n)"
+         return format_string
+
+
+ class ToLowercase(BaseTextTransform):
+     """Convert text to lowercase."""
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to convert to lowercase.
+         """
+         text = text.lower()
+
+         return text
+
+
+ class ToSingular(BaseTextTransform):
+     """Convert plural words to singular form."""
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.transform = inflect.engine().singular_noun
+
+     def __call__(self, text: str):
+         """
+         Args:
+             text (str): Text to convert to singular form.
+         """
+         words = text.split()
+         for i, word in enumerate(words):
+             if not word.endswith("s"):
+                 continue
+
+             if word[-2:] in ["ss", "us", "is"]:
+                 continue
+
+             if word[-3:] in ["ies", "oes"]:
+                 continue
+
+             words[i] = self.transform(word) or word
+
+         text = " ".join(words)
+
+         return text
+
+     def __repr__(self) -> str:
+         return f"{self.__class__.__name__}()"
+
+
+ def default_vocabulary_transforms() -> TextCompose:
+     """Preprocess input text with preprocessing transforms."""
+     words_to_drop = [
+         "image",
+         "photo",
+         "picture",
+         "thumbnail",
+         "logo",
+         "symbol",
+         "clipart",
+         "portrait",
+         "painting",
+         "illustration",
+         "icon",
+         "profile",
+     ]
+     pos_tags = ["NN", "NNS", "NNP", "NNPS", "JJ", "JJR", "JJS", "VBG", "VBN"]
+
+     transforms = []
+     transforms.append(DropTokens())
+     transforms.append(DropURLs())
+     transforms.append(DropSpecialCharacters())
+     transforms.append(DropFileExtensions())
+     transforms.append(ReplaceSeparators())
+     transforms.append(DropShortWords(min_length=3))
+     transforms.append(DropNonAlpha())
+     transforms.append(ToLowercase())
+     transforms.append(ToSingular())
+     transforms.append(DropWords(words=words_to_drop))
+     transforms.append(FrequencyMinWordCount(min_count=2))
+     transforms.append(FilterPOS(tags=pos_tags, engine="flair", keep_compound_nouns=False))
+     transforms.append(RemoveDuplicates())
+
+     transforms = TextCompose(transforms)
+
+     return transforms
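To illustrate what the pipeline above does to retrieved captions, a small sketch: the captions are made up, the flair POS model is downloaded on first use, and the exact output can vary in order because RemoveDuplicates uses a set.

from transforms_cased import default_vocabulary_transforms

transform = default_vocabulary_transforms()

# Illustrative captions, similar in spirit to what the cc12m index returns.
captions = [
    "A photo of a golden retriever playing in the park",
    "golden retriever puppy - stock image",
    "<b>Dog</b> playing with a ball, thumbnail.jpg",
]
print(transform(captions))
# e.g. ['golden', 'retriever', 'playing'] (order may differ)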