altndrr committed on
Commit
50c0cb3
1 Parent(s): 8c43e37

Use altndrr/cased model

app.py CHANGED
@@ -2,8 +2,8 @@ from typing import Optional
 
 import gradio as gr
 import torch
-
-from src.nn import CaSED
+from PIL import Image
+from transformers import AutoModel, CLIPProcessor
 
 PAPER_TITLE = "Vocabulary-free Image Classification"
 PAPER_DESCRIPTION = """
@@ -37,14 +37,17 @@ To assign a label to an image, we:
 """
 PAPER_URL = "https://arxiv.org/abs/2306.00917"
 
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-model = CaSED().to(DEVICE).eval()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(device)
+processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
 
 def vic(filename: str, alpha: Optional[float] = None):
-    # get the outputs of the model
-    vocabulary, scores = model(filename, alpha=alpha)
+    images = processor(images=[Image.open(filename)], return_tensors="pt", padding=True)
+    outputs = model(images, alpha=alpha)
+    vocabulary = outputs["vocabularies"][0]
+    scores = outputs["scores"][0]
     confidences = dict(zip(vocabulary, scores))
 
     return confidences
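Taken together, the updated app.py runs inference roughly as sketched below. The sketch is assembled from the diff above; the Gradio interface wiring at the end is an assumption for illustration, since it is not part of this change.

```python
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, CLIPProcessor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained("altndrr/cased", trust_remote_code=True).to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")


def vic(filename: str, alpha: Optional[float] = None):
    # Preprocess the image and let the remote-code model retrieve, filter,
    # and score a candidate vocabulary for it.
    images = processor(images=[Image.open(filename)], return_tensors="pt", padding=True)
    outputs = model(images, alpha=alpha)
    vocabulary = outputs["vocabularies"][0]
    scores = outputs["scores"][0]
    return dict(zip(vocabulary, scores))


# Hypothetical interface wiring (the real app.py layout is not shown in this diff).
demo = gr.Interface(fn=vic, inputs=gr.Image(type="filepath"), outputs=gr.Label(num_top_classes=5))
demo.launch()
```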
artifacts/models/databases/.gitkeep DELETED
File without changes
artifacts/models/retrieval/indices.json DELETED
@@ -1,3 +0,0 @@
-{
-  "ViT-L-14_CC12M": "./artifacts/models/databases/cc12m/vit-l-14/"
-}
src/nn.py DELETED
@@ -1,186 +0,0 @@
-import json
-import tarfile
-from pathlib import Path
-from typing import Optional
-
-import faiss
-import gdown
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPModel, CLIPProcessor
-
-from src.retrieval import ArrowMetadataProvider
-from src.transforms import TextCompose, default_vocabulary_transforms
-
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-RETRIEVAL_DATABASES = {
-    "cc12m": "https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
-}
-
-
-class CaSED(torch.nn.Module):
-    """Torch module for Category Search from External Databases (CaSED).
-
-    Args:
-        index_name (str): Name of the faiss index to use.
-        vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.
-
-    Extra hparams:
-        alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
-        artifact_dir (str): Path to the directory where the databases are stored. Defaults to
-            "artifacts/".
-        retrieval_num_results (int): Number of results to return. Defaults to 10.
-    """
-
-    def __init__(
-        self,
-        index_name: str = "ViT-L-14_CC12M",
-        vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
-        **kwargs,
-    ):
-        super().__init__()
-
-        # load CLIP
-        model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(DEVICE)
-        self.index_name = index_name
-        self.vocabulary_transforms = vocabulary_transforms
-        self.vision_encoder = model.vision_model
-        self.vision_proj = model.visual_projection
-        self.language_encoder = model.text_model
-        self.language_proj = model.text_projection
-        self.logit_scale = model.logit_scale.exp()
-        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-
-        # set hparams
-        kwargs["alpha"] = kwargs.get("alpha", 0.5)
-        kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
-        kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
-        self.hparams = kwargs
-
-        # download databases
-        self.prepare_data()
-
-        # load faiss indices and metadata providers
-        indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
-        indices_fp = indices_list_dir / "indices.json"
-        self.indices = json.load(open(indices_fp))
-        self.resources = {}
-        for name, index_fp in self.indices.items():
-            text_index_fp = Path(index_fp) / "text.index"
-            metadata_fp = Path(index_fp) / "metadata/"
-
-            text_index = faiss.read_index(
-                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
-            )
-            metadata_provider = ArrowMetadataProvider(metadata_fp)
-
-            self.resources[name] = {
-                "device": DEVICE,
-                "model": "ViT-L-14",
-                "text_index": text_index,
-                "metadata_provider": metadata_provider,
-            }
-
-    def prepare_data(self):
-        """Download data if needed."""
-        databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"
-
-        for name, url in RETRIEVAL_DATABASES.items():
-            database_path = Path(databases_path, name)
-            if database_path.exists():
-                continue
-
-            # download data
-            target_path = Path(databases_path, name + ".tar.gz")
-            try:
-                gdown.download(url, str(target_path), quiet=False)
-                tar = tarfile.open(target_path, "r:gz")
-                tar.extractall(target_path.parent)
-                tar.close()
-                target_path.unlink()
-            except FileNotFoundError:
-                print(f"Could not download {url}.")
-                print(f"Please download it manually and place it in {target_path.parent}.")
-
-    @torch.no_grad()
-    def query_index(self, sample_z: torch.Tensor) -> torch.Tensor:
-        # get the index
-        resources = self.resources[self.index_name]
-        text_index = resources["text_index"]
-        metadata_provider = resources["metadata_provider"]
-
-        # query the index
-        sample_z = sample_z.squeeze(0)
-        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
-        query_input = sample_z.cpu().detach().numpy().tolist()
-        query = np.expand_dims(np.array(query_input).astype("float32"), 0)
-
-        distances, idxs, _ = text_index.search_and_reconstruct(
-            query, self.hparams["retrieval_num_results"]
-        )
-        results = idxs[0]
-        nb_results = np.where(results == -1)[0]
-        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
-        indices = results[:nb_results]
-        distances = distances[0][:nb_results]
-
-        if len(distances) == 0:
-            return []
-
-        # get the metadata
-        results = []
-        metadata = metadata_provider.get(indices[:20], ["caption"])
-        for key, (d, i) in enumerate(zip(distances, indices)):
-            output = {}
-            meta = None if key + 1 > len(metadata) else metadata[key]
-            if meta is not None:
-                output.update(meta)
-            output["id"] = i.item()
-            output["similarity"] = d.item()
-            results.append(output)
-
-        # get the captions only
-        vocabularies = [result["caption"] for result in results]
-
-        return vocabularies
-
-    @torch.no_grad()
-    def forward(self, image_fp: str, alpha: Optional[float] = None) -> torch.Tensor():
-        # forward the image
-        image = self.processor(images=Image.open(image_fp), return_tensors="pt")
-        image["pixel_values"] = image["pixel_values"].to(DEVICE)
-        image_z = self.vision_proj(self.vision_encoder(**image)[1])
-
-        # generate a single text embedding from the unfiltered vocabulary
-        vocabulary = self.query_index(image_z)
-        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
-        text["input_ids"] = text["input_ids"][:, :77].to(DEVICE)
-        text["attention_mask"] = text["attention_mask"][:, :77].to(DEVICE)
-        text_z = self.language_encoder(**text)[1]
-        text_z = self.language_proj(text_z)
-
-        # filter the vocabulary, embed it, and get its mean embedding
-        vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
-        text = self.processor(text=vocabulary, return_tensors="pt", padding=True)
-        text = {k: v.to(DEVICE) for k, v in text.items()}
-        vocabulary_z = self.language_encoder(**text)[1]
-        vocabulary_z = self.language_proj(vocabulary_z)
-        vocabulary_z = vocabulary_z / vocabulary_z.norm(dim=-1, keepdim=True)
-
-        # get the image and text predictions
-        image_z = image_z / image_z.norm(dim=-1, keepdim=True)
-        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
-        image_p = (torch.matmul(image_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
-        text_p = (torch.matmul(text_z, vocabulary_z.T) * self.logit_scale).softmax(dim=-1)
-
-        # average the image and text predictions
-        alpha = alpha or self.hparams["alpha"]
-        sample_p = alpha * image_p + (1 - alpha) * text_p
-
-        # get the scores
-        scores = sample_p[0].cpu().tolist()
-
-        return vocabulary, scores
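The `alpha` argument kept by the new `vic` signature corresponds to the blending step in the deleted `forward` above: it weights the image-based prediction against the caption-centroid prediction over the candidate vocabulary. A toy illustration of that step (the probability values are made up):

```python
import torch

# Hypothetical softmax outputs over 3 candidate category names.
image_p = torch.tensor([[0.7, 0.2, 0.1]])  # image embedding vs. vocabulary
text_p = torch.tensor([[0.4, 0.5, 0.1]])   # caption centroid vs. vocabulary

alpha = 0.5  # default weight used by the deleted CaSED module
sample_p = alpha * image_p + (1 - alpha) * text_p  # -> [[0.55, 0.35, 0.10]]
```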
src/retrieval.py DELETED
@@ -1,30 +0,0 @@
-from pathlib import Path
-from typing import Optional
-
-import numpy as np
-import pyarrow as pa
-
-
-class ArrowMetadataProvider:
-    """The arrow metadata provider provides metadata from contiguous ids using arrow.
-
-    Code taken from: https://github.dev/rom1504/clip-retrieval
-    """
-
-    def __init__(self, arrow_folder: Path):
-        arrow_files = [str(a) for a in sorted(arrow_folder.glob("**/*")) if a.is_file()]
-        self.table = pa.concat_tables(
-            [
-                pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
-                for arrow_file in arrow_files
-            ]
-        )
-
-    def get(self, ids: np.ndarray, cols: Optional[list] = None):
-        """Implement the get method from the arrow metadata provide, get metadata from ids."""
-        if cols is None:
-            cols = self.table.schema.names
-        else:
-            cols = list(set(self.table.schema.names) & set(cols))
-        t = pa.concat_tables([self.table[i:j] for i, j in zip(ids, ids + 1)])
-        return t.select(cols).to_pandas().to_dict("records")
src/transforms.py DELETED
@@ -1,506 +0,0 @@
-import re
-from abc import ABC, abstractmethod
-from typing import Any, Optional, Union, cast
-
-import inflect
-import nltk
-import numpy as np
-import PIL.Image
-import torch
-import torchvision.transforms as T
-import torchvision.transforms.functional as F
-from flair.data import Sentence
-from flair.models import SequenceTagger
-
-__all__ = [
-    "DynamicResize",
-    "DropFileExtensions",
-    "DropNonAlpha",
-    "DropShortWords",
-    "DropSpecialCharacters",
-    "DropTokens",
-    "DropURLs",
-    "DropWords",
-    "FilterPOS",
-    "FrequencyMinWordCount",
-    "FrequencyTopK",
-    "ReplaceSeparators",
-    "ToRGBTensor",
-    "ToLowercase",
-    "ToSingular",
-]
-
-
-class BaseTextTransform(ABC):
-    """Base class for string transforms."""
-
-    @abstractmethod
-    def __call__(self, text: str):
-        raise NotImplementedError
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
-class DynamicResize(T.Resize):
-    """Resize the input PIL Image to the given size.
-
-    Extends the torchvision Resize transform to dynamically evaluate the second dimension of the
-    output size based on the aspect ratio of the first input image.
-    """
-
-    def forward(self, img):
-        if isinstance(self.size, int):
-            _, h, w = F.get_dimensions(img)
-            aspect_ratio = w / h
-            side = self.size
-
-            if aspect_ratio < 1.0:
-                self.size = int(side / aspect_ratio), side
-            else:
-                self.size = side, int(side * aspect_ratio)
-
-        return super().forward(img)
-
-
-class DropFileExtensions(BaseTextTransform):
-    """Remove file extensions from the input text."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove file extensions from.
-        """
-        text = re.sub(r"\.\w+", "", text)
-
-        return text
-
-
-class DropNonAlpha(BaseTextTransform):
-    """Remove non-alpha words from the input text."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove non-alpha words from.
-        """
-        text = re.sub(r"[^a-zA-Z\s]", "", text)
-
-        return text
-
-
-class DropShortWords(BaseTextTransform):
-    """Remove short words from the input text.
-
-    Args:
-        min_length (int): Minimum length of words to keep.
-    """
-
-    def __init__(self, min_length) -> None:
-        super().__init__()
-        self.min_length = min_length
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove short words from.
-        """
-        text = " ".join([word for word in text.split() if len(word) >= self.min_length])
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(min_length={self.min_length})"
-
-
-class DropSpecialCharacters(BaseTextTransform):
-    """Remove special characters from the input text.
-
-    Special characters are defined as any character that is not a word character, whitespace,
-    hyphen, period, apostrophe, or ampersand.
-    """
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove special characters from.
-        """
-        text = re.sub(r"[^\w\s\-\.\'\&]", "", text)
-
-        return text
-
-
-class DropTokens(BaseTextTransform):
-    """Remove tokens from the input text.
-
-    Tokens are defined as strings enclosed in angle brackets, e.g. <token>.
-    """
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove tokens from.
-        """
-        text = re.sub(r"<[^>]+>", "", text)
-
-        return text
-
-
-class DropURLs(BaseTextTransform):
-    """Remove URLs from the input text."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove URLs from.
-        """
-        text = re.sub(r"http\S+", "", text)
-
-        return text
-
-
-class DropWords(BaseTextTransform):
-    """Remove words from the input text.
-
-    It is case-insensitive and supports singular and plural forms of the words.
-    """
-
-    def __init__(self, words: list[str]) -> None:
-        super().__init__()
-        self.words = words
-        self.pattern = r"\b(?:{})\b".format("|".join(words))
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove words from.
-        """
-        text = re.sub(self.pattern, "", text, flags=re.IGNORECASE)
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(pattern={self.pattern})"
-
-
-class FilterPOS(BaseTextTransform):
-    """Filter words by POS tags.
-
-    Args:
-        tags (list): List of POS tags to remove.
-        engine (str): POS tagger to use. Must be one of "nltk" or "flair". Defaults to "nltk".
-        keep_compound_nouns (bool): Whether to keep composed words. Defaults to True.
-    """
-
-    def __init__(self, tags: list, engine: str = "nltk", keep_compound_nouns: bool = True) -> None:
-        super().__init__()
-        self.tags = tags
-        self.engine = engine
-        self.keep_compound_nouns = keep_compound_nouns
-
-        if engine == "nltk":
-            nltk.download("averaged_perceptron_tagger", quiet=True)
-            nltk.download("punkt", quiet=True)
-            self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
-        elif engine == "flair":
-            self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove words with specific POS tags from.
-        """
-        if self.engine == "nltk":
-            word_tags = self.tagger(text)
-            text = " ".join([word for word, tag in word_tags if tag not in self.tags])
-        elif self.engine == "flair":
-            sentence = Sentence(text)
-            self.tagger(sentence)
-            text = " ".join([token.text for token in sentence.tokens if token.tag in self.tags])
-
-        if self.keep_compound_nouns:
-            compound_nouns = []
-
-            if self.engine == "nltk":
-                for i in range(len(word_tags) - 1):
-                    if word_tags[i][1] == "NN" and word_tags[i + 1][1] == "NN":
-                        # if they are the same word, skip
-                        if word_tags[i][0] == word_tags[i + 1][0]:
-                            continue
-
-                        compound_noun = word_tags[i][0] + "_" + word_tags[i + 1][0]
-                        compound_nouns.append(compound_noun)
-            elif self.engine == "flair":
-                for i in range(len(sentence.tokens) - 1):
-                    if sentence.tokens[i].tag == "NN" and sentence.tokens[i + 1].tag == "NN":
-                        # if they are the same word, skip
-                        if sentence.tokens[i].text == sentence.tokens[i + 1].text:
-                            continue
-
-                        compound_noun = sentence.tokens[i].text + "_" + sentence.tokens[i + 1].text
-                        compound_nouns.append(compound_noun)
-
-            text = " ".join([text, " ".join(compound_nouns)])
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(tags={self.tags}, engine={self.engine})"
-
-
-class FrequencyMinWordCount(BaseTextTransform):
-    """Keep only words that occur more than a minimum number of times in the input text.
-
-    If the threshold is too strong and no words pass the threshold, the threshold is reduced to
-    the most frequent word.
-
-    Args:
-        min_count (int): Minimum number of occurrences of a word to keep.
-    """
-
-    def __init__(self, min_count) -> None:
-        super().__init__()
-        self.min_count = min_count
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove infrequent words from.
-        """
-        if self.min_count <= 1:
-            return text
-
-        words = text.split()
-        word_counts = {word: words.count(word) for word in words}
-
-        # if nothing passes the threshold, reduce the threshold to the most frequent word
-        max_word_count = max(word_counts.values() or [0])
-        min_count = max_word_count if self.min_count > max_word_count else self.min_count
-
-        text = " ".join([word for word in words if word_counts[word] >= min_count])
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(min_count={self.min_count})"
-
-
-class FrequencyTopK(BaseTextTransform):
-    """Keep only the top k most frequent words in the input text.
-
-    In case of a tie, all words with the same count as the last word are kept.
-
-    Args:
-        top_k (int): Number of top words to keep.
-    """
-
-    def __init__(self, top_k: int) -> None:
-        super().__init__()
-        self.top_k = top_k
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove infrequent words from.
-        """
-        if self.top_k < 1:
-            return text
-
-        words = text.split()
-        word_counts = {word: words.count(word) for word in words}
-        top_words = sorted(word_counts, key=word_counts.get, reverse=True)
-
-        # in case of a tie, keep all words with the same count
-        top_words = top_words[: self.top_k]
-        top_words = [word for word in top_words if word_counts[word] == word_counts[top_words[-1]]]
-
-        text = " ".join([word for word in words if word in top_words])
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(top_k={self.top_k})"
-
-
-class ReplaceSeparators(BaseTextTransform):
-    """Replace underscores and dashes with spaces."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to replace separators in.
-        """
-        text = re.sub(r"[_\-]", " ", text)
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
-class RemoveDuplicates(BaseTextTransform):
-    """Remove duplicate words from the input text."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to remove duplicate words from.
-        """
-        text = " ".join(list(set(text.split())))
-
-        return text
-
-
-class TextCompose:
-    """Compose several transforms together.
-
-    It differs from the torchvision.transforms.Compose class in that it applies the transforms to
-    a string instead of a PIL Image or Tensor. In addition, it automatically join the list of
-    input strings into a single string and splits the output string into a list of words.
-
-    Args:
-        transforms (list): List of transforms to compose.
-    """
-
-    def __init__(self, transforms: list[BaseTextTransform]) -> None:
-        self.transforms = transforms
-
-    def __call__(self, text: Union[str, list[str]]) -> Any:
-        if isinstance(text, list):
-            text = " ".join(text)
-
-        for t in self.transforms:
-            text = t(text)
-        return text.split()
-
-    def __repr__(self) -> str:
-        format_string = self.__class__.__name__ + "("
-        for t in self.transforms:
-            format_string += "\n"
-            format_string += f" {t}"
-        format_string += "\n)"
-        return format_string
-
-
-class ToRGBTensor(T.ToTensor):
-    """Convert a `PIL Image` or `numpy.ndarray` to tensor.
-
-    Compared with the torchvision `ToTensor` transform, it converts images with a single channel to
-    RGB images. In addition, the conversion to tensor is done only if the input is not already a
-    tensor.
-    """
-
-    def __call__(self, pic: Union[PIL.Image.Image, np.ndarray, torch.Tensor]):
-        """
-        Args:
-            pic (PIL Image | numpy.ndarray | torch.Tensor): Image to be converted to tensor.
-        """
-        img = pic if isinstance(pic, torch.Tensor) else F.to_tensor(pic)
-        img = cast(torch.Tensor, img)
-
-        if img.shape[0] == 1:
-            img = img.repeat(3, 1, 1)
-
-        return img
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
-class ToLowercase(BaseTextTransform):
-    """Convert text to lowercase."""
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to convert to lowercase.
-        """
-        text = text.lower()
-
-        return text
-
-
-class ToSingular(BaseTextTransform):
-    """Convert plural words to singular form."""
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.transform = inflect.engine().singular_noun
-
-    def __call__(self, text: str):
-        """
-        Args:
-            text (str): Text to convert to singular form.
-        """
-        words = text.split()
-        for i, word in enumerate(words):
-            if not word.endswith("s"):
-                continue
-
-            if word[-2:] in ["ss", "us", "is"]:
-                continue
-
-            if word[-3:] in ["ies", "oes"]:
-                continue
-
-            words[i] = self.transform(word) or word
-
-        text = " ".join(words)
-
-        return text
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}()"
-
-
-def default_preprocess(size: Optional[int] = None) -> T.Compose:
-    """Preprocess input images with preprocessing transforms.
-
-    Args:
-        size (int): Size to resize image to.
-    """
-    transforms = []
-    if size is not None:
-        transforms.append(DynamicResize(size, interpolation=T.InterpolationMode.BICUBIC))
-    transforms.append(ToRGBTensor())
-    transforms = T.Compose(transforms)
-
-    return transforms
-
-
-def default_vocabulary_transforms() -> TextCompose:
-    """Preprocess input text with preprocessing transforms."""
-    words_to_drop = [
-        "image",
-        "photo",
-        "picture",
-        "thumbnail",
-        "logo",
-        "symbol",
-        "clipart",
-        "portrait",
-        "painting",
-        "illustration",
-        "icon",
-        "profile",
-    ]
-    pos_tags = ["NN", "NNS", "NNP", "NNPS", "JJ", "JJR", "JJS", "VBG", "VBN"]
-
-    transforms = []
-    transforms.append(DropTokens())
-    transforms.append(DropURLs())
-    transforms.append(DropSpecialCharacters())
-    transforms.append(DropFileExtensions())
-    transforms.append(ReplaceSeparators())
-    transforms.append(DropShortWords(min_length=3))
-    transforms.append(DropNonAlpha())
-    transforms.append(ToLowercase())
-    transforms.append(ToSingular())
-    transforms.append(DropWords(words=words_to_drop))
-    transforms.append(FrequencyMinWordCount(min_count=2))
-    transforms.append(FilterPOS(tags=pos_tags, engine="flair", keep_compound_nouns=False))
-    transforms.append(RemoveDuplicates())
-
-    transforms = TextCompose(transforms)
-
-    return transforms
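The deleted `default_vocabulary_transforms` pipeline was what `CaSED` used to distill retrieved captions into candidate category names. A minimal usage sketch of the pre-deletion module follows; the example captions are illustrative, and the exact output depends on the flair POS tagger.

```python
from src.transforms import default_vocabulary_transforms

# Build the default caption-cleaning pipeline: drop URLs, tokens, short words,
# generic terms such as "photo", singularize nouns, filter by POS tag, deduplicate.
transforms = default_vocabulary_transforms()

retrieved_captions = [
    "A photo of a golden retriever puppy.",
    "golden retriever puppies playing in the park",
]
# Returns a list of candidate category words distilled from the captions.
candidate_names = transforms(retrieved_captions)
```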