altndrr commited on
Commit
101e533
1 Parent(s): 0483742

Update return types

Browse files
Files changed (2) hide show
  1. modeling_cased.py +3 -4
  2. transforms_cased.py +20 -16
modeling_cased.py CHANGED
@@ -197,8 +197,7 @@ class CaSEDModel(PreTrainedModel):
197
 
198
  return vocabularies
199
 
200
- @torch.no_grad()
201
- def forward(self, images: dict, alpha: Optional[float] = None) -> torch.Tensor():
202
  """Forward pass.
203
 
204
  Args:
@@ -248,8 +247,8 @@ class CaSEDModel(PreTrainedModel):
248
  vocabularies.append(vocabulary)
249
 
250
  # get the scores
251
- samples_p = torch.stack(samples_p, dim=0)
252
- scores = sample_p.cpu()
253
 
254
  # define the results
255
  results = {"vocabularies": vocabularies, "scores": scores}
 
197
 
198
  return vocabularies
199
 
200
+ def forward(self, images: dict, alpha: Optional[float] = None) -> torch.Tensor:
 
201
  """Forward pass.
202
 
203
  Args:
 
247
  vocabularies.append(vocabulary)
248
 
249
  # get the scores
250
+ samples_p = torch.cat(samples_p, dim=0)
251
+ scores = samples_p.cpu()
252
 
253
  # define the results
254
  results = {"vocabularies": vocabularies, "scores": scores}
transforms_cased.py CHANGED
@@ -28,7 +28,7 @@ class BaseTextTransform(ABC):
28
  """Base class for string transforms."""
29
 
30
  @abstractmethod
31
- def __call__(self, text: str):
32
  raise NotImplementedError
33
 
34
  def __repr__(self) -> str:
@@ -38,7 +38,7 @@ class BaseTextTransform(ABC):
38
  class DropFileExtensions(BaseTextTransform):
39
  """Remove file extensions from the input text."""
40
 
41
- def __call__(self, text: str):
42
  """
43
  Args:
44
  text (str): Text to remove file extensions from.
@@ -51,7 +51,7 @@ class DropFileExtensions(BaseTextTransform):
51
  class DropNonAlpha(BaseTextTransform):
52
  """Remove non-alpha words from the input text."""
53
 
54
- def __call__(self, text: str):
55
  """
56
  Args:
57
  text (str): Text to remove non-alpha words from.
@@ -72,7 +72,7 @@ class DropShortWords(BaseTextTransform):
72
  super().__init__()
73
  self.min_length = min_length
74
 
75
- def __call__(self, text: str):
76
  """
77
  Args:
78
  text (str): Text to remove short words from.
@@ -92,7 +92,7 @@ class DropSpecialCharacters(BaseTextTransform):
92
  hyphen, period, apostrophe, or ampersand.
93
  """
94
 
95
- def __call__(self, text: str):
96
  """
97
  Args:
98
  text (str): Text to remove special characters from.
@@ -108,7 +108,7 @@ class DropTokens(BaseTextTransform):
108
  Tokens are defined as strings enclosed in angle brackets, e.g. <token>.
109
  """
110
 
111
- def __call__(self, text: str):
112
  """
113
  Args:
114
  text (str): Text to remove tokens from.
@@ -121,7 +121,7 @@ class DropTokens(BaseTextTransform):
121
  class DropURLs(BaseTextTransform):
122
  """Remove URLs from the input text."""
123
 
124
- def __call__(self, text: str):
125
  """
126
  Args:
127
  text (str): Text to remove URLs from.
@@ -142,7 +142,7 @@ class DropWords(BaseTextTransform):
142
  self.words = words
143
  self.pattern = r"\b(?:{})\b".format("|".join(words))
144
 
145
- def __call__(self, text: str):
146
  """
147
  Args:
148
  text (str): Text to remove words from.
@@ -177,7 +177,7 @@ class FilterPOS(BaseTextTransform):
177
  elif engine == "flair":
178
  self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
179
 
180
- def __call__(self, text: str):
181
  """
182
  Args:
183
  text (str): Text to remove words with specific POS tags from.
@@ -234,7 +234,7 @@ class FrequencyMinWordCount(BaseTextTransform):
234
  super().__init__()
235
  self.min_count = min_count
236
 
237
- def __call__(self, text: str):
238
  """
239
  Args:
240
  text (str): Text to remove infrequent words from.
@@ -270,7 +270,7 @@ class FrequencyTopK(BaseTextTransform):
270
  super().__init__()
271
  self.top_k = top_k
272
 
273
- def __call__(self, text: str):
274
  """
275
  Args:
276
  text (str): Text to remove infrequent words from.
@@ -297,7 +297,7 @@ class FrequencyTopK(BaseTextTransform):
297
  class ReplaceSeparators(BaseTextTransform):
298
  """Replace underscores and dashes with spaces."""
299
 
300
- def __call__(self, text: str):
301
  """
302
  Args:
303
  text (str): Text to replace separators in.
@@ -313,7 +313,7 @@ class ReplaceSeparators(BaseTextTransform):
313
  class RemoveDuplicates(BaseTextTransform):
314
  """Remove duplicate words from the input text."""
315
 
316
- def __call__(self, text: str):
317
  """
318
  Args:
319
  text (str): Text to remove duplicate words from.
@@ -337,7 +337,11 @@ class TextCompose:
337
  def __init__(self, transforms: list[BaseTextTransform]) -> None:
338
  self.transforms = transforms
339
 
340
- def __call__(self, text: Union[str, list[str]]) -> Any:
 
 
 
 
341
  if isinstance(text, list):
342
  text = " ".join(text)
343
 
@@ -357,7 +361,7 @@ class TextCompose:
357
  class ToLowercase(BaseTextTransform):
358
  """Convert text to lowercase."""
359
 
360
- def __call__(self, text: str):
361
  """
362
  Args:
363
  text (str): Text to convert to lowercase.
@@ -374,7 +378,7 @@ class ToSingular(BaseTextTransform):
374
  super().__init__()
375
  self.transform = inflect.engine().singular_noun
376
 
377
- def __call__(self, text: str):
378
  """
379
  Args:
380
  text (str): Text to convert to singular form.
 
28
  """Base class for string transforms."""
29
 
30
  @abstractmethod
31
+ def __call__(self, text: str) -> str:
32
  raise NotImplementedError
33
 
34
  def __repr__(self) -> str:
 
38
  class DropFileExtensions(BaseTextTransform):
39
  """Remove file extensions from the input text."""
40
 
41
+ def __call__(self, text: str) -> str:
42
  """
43
  Args:
44
  text (str): Text to remove file extensions from.
 
51
  class DropNonAlpha(BaseTextTransform):
52
  """Remove non-alpha words from the input text."""
53
 
54
+ def __call__(self, text: str) -> str:
55
  """
56
  Args:
57
  text (str): Text to remove non-alpha words from.
 
72
  super().__init__()
73
  self.min_length = min_length
74
 
75
+ def __call__(self, text: str) -> str:
76
  """
77
  Args:
78
  text (str): Text to remove short words from.
 
92
  hyphen, period, apostrophe, or ampersand.
93
  """
94
 
95
+ def __call__(self, text: str) -> str:
96
  """
97
  Args:
98
  text (str): Text to remove special characters from.
 
108
  Tokens are defined as strings enclosed in angle brackets, e.g. <token>.
109
  """
110
 
111
+ def __call__(self, text: str) -> str:
112
  """
113
  Args:
114
  text (str): Text to remove tokens from.
 
121
  class DropURLs(BaseTextTransform):
122
  """Remove URLs from the input text."""
123
 
124
+ def __call__(self, text: str) -> str:
125
  """
126
  Args:
127
  text (str): Text to remove URLs from.
 
142
  self.words = words
143
  self.pattern = r"\b(?:{})\b".format("|".join(words))
144
 
145
+ def __call__(self, text: str) -> str:
146
  """
147
  Args:
148
  text (str): Text to remove words from.
 
177
  elif engine == "flair":
178
  self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
179
 
180
+ def __call__(self, text: str) -> str:
181
  """
182
  Args:
183
  text (str): Text to remove words with specific POS tags from.
 
234
  super().__init__()
235
  self.min_count = min_count
236
 
237
+ def __call__(self, text: str) -> str:
238
  """
239
  Args:
240
  text (str): Text to remove infrequent words from.
 
270
  super().__init__()
271
  self.top_k = top_k
272
 
273
+ def __call__(self, text: str) -> str:
274
  """
275
  Args:
276
  text (str): Text to remove infrequent words from.
 
297
  class ReplaceSeparators(BaseTextTransform):
298
  """Replace underscores and dashes with spaces."""
299
 
300
+ def __call__(self, text: str) -> str:
301
  """
302
  Args:
303
  text (str): Text to replace separators in.
 
313
  class RemoveDuplicates(BaseTextTransform):
314
  """Remove duplicate words from the input text."""
315
 
316
+ def __call__(self, text: str) -> str:
317
  """
318
  Args:
319
  text (str): Text to remove duplicate words from.
 
337
  def __init__(self, transforms: list[BaseTextTransform]) -> None:
338
  self.transforms = transforms
339
 
340
+ def __call__(self, text: Union[str, list[str]]) -> list[str]:
341
+ """
342
+ Args:
343
+ text (Union[str, list[str]]): Text to transform.
344
+ """
345
  if isinstance(text, list):
346
  text = " ".join(text)
347
 
 
361
  class ToLowercase(BaseTextTransform):
362
  """Convert text to lowercase."""
363
 
364
+ def __call__(self, text: str) -> str:
365
  """
366
  Args:
367
  text (str): Text to convert to lowercase.
 
378
  super().__init__()
379
  self.transform = inflect.engine().singular_noun
380
 
381
+ def __call__(self, text: str) -> str:
382
  """
383
  Args:
384
  text (str): Text to convert to singular form.