Thastp commited on
Commit
c4cc546
·
verified ·
1 Parent(s): f4b97bc

Upload processor

Browse files
Files changed (1) hide show
  1. image_processing_efficientnet.py +47 -8
image_processing_efficientnet.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
2
  from timm import create_model
3
  from timm.data import resolve_data_config
@@ -6,19 +10,54 @@ from timm.data.transforms_factory import create_transform
6
  class EfficientNetImageProcessor(BaseImageProcessor):
7
  model_input_names = ["pixel_values"]
8
 
9
- def __init__(self,
10
- model_name: str,
11
- **kwargs
12
- ):
13
- super().__init__(**kwargs)
14
-
15
  self.model_name = model_name
16
  self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
 
18
 
19
- def preprocess(self, image):
20
  transforms = create_transform(**self.config)
21
- data = {'pixel_values': transforms(image).unsqueeze(0)}
 
 
 
 
 
22
  return BatchFeature(data=data)
23
 
24
  __all__ = [
 
1
+ from PIL import Image
2
+ from torch import Tensor, stack
3
+ from typing import Union, List
4
+
5
  from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
6
  from timm import create_model
7
  from timm.data import resolve_data_config
 
10
  class EfficientNetImageProcessor(BaseImageProcessor):
11
  model_input_names = ["pixel_values"]
12
 
13
+ def __init__(
14
+ self,
15
+ model_name: str,
16
+ **kwargs,
17
+ ):
 
18
  self.model_name = model_name
19
  self.config = resolve_data_config({}, model=create_model(model_name, pretrained=False))
20
+ super().__init__(**kwargs)
21
+
22
+ def preprocess(
23
+ self,
24
+ images: Union[List[Union[Image.Image, Tensor]], Image.Image, Tensor],
25
+ ) -> BatchFeature:
26
+ """
27
+ Preprocesses input images by applying transformations and returning them as a BatchFeature.
28
+
29
+ Parameters
30
+ ----------
31
+ images : Union[List[PIL.Image.Image, torch.Tensor], PIL.Image.Image, torch.Tensor]
32
+ A single image or a list of images in one of the accepted formats.
33
+
34
+ Returns
35
+ -------
36
+ BatchFeature
37
+ A batch of transformed images
38
+ """
39
+ images = [images] if not isinstance(images, list) else images
40
+
41
+ # TEST: empty list
42
+ if len(images) == 0:
43
+ raise ValueError("Received an empty list of images")
44
 
45
+ # TEST: validate input type
46
+ test_image = images[0]
47
+ if not isinstance(images[0], (Image.Image, Tensor)):
48
+ raise TypeError(
49
+ f"Expected image to be of type PIL.Image.Image, torch.Tensor, or numpy.ndarray, "
50
+ f"but got {type(test_image).__name__} instead."
51
+ )
52
 
53
+ # Apply transformations
54
  transforms = create_transform(**self.config)
55
+ transformed_images = [transforms(image) for image in images]
56
+
57
+ # Convert to batch tensor
58
+ transformed_image_tensors = stack(transformed_images)
59
+
60
+ data = {'pixel_values': transformed_image_tensors}
61
  return BatchFeature(data=data)
62
 
63
  __all__ = [