Optional texts in processor.
Browse files- processor.py +10 -9
processor.py
CHANGED
@@ -38,7 +38,7 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
38 |
img = F.normalize(img, self.image_mean, self.image_std)
|
39 |
return img
|
40 |
|
41 |
-
def __call__(self, images, texts):
|
42 |
"""
|
43 |
Parameters
|
44 |
----------
|
@@ -55,14 +55,15 @@ class CondViTProcessor(ImageProcessingMixin):
|
|
55 |
texts : Union[str, List[str]]
|
56 |
"""
|
57 |
# Single Image
|
|
|
58 |
if isinstance(images, Image.Image):
|
59 |
-
|
60 |
-
|
|
|
|
|
61 |
)
|
62 |
|
63 |
-
|
64 |
-
data=
|
65 |
-
|
66 |
-
|
67 |
-
}
|
68 |
-
)
|
|
|
38 |
img = F.normalize(img, self.image_mean, self.image_std)
|
39 |
return img
|
40 |
|
41 |
+
def __call__(self, images, texts=None):
|
42 |
"""
|
43 |
Parameters
|
44 |
----------
|
|
|
55 |
texts : Union[str, List[str]]
|
56 |
"""
|
57 |
# Single Image
|
58 |
+
data = {}
|
59 |
if isinstance(images, Image.Image):
|
60 |
+
data["pixel_values"] = self.process_img(images)
|
61 |
+
else:
|
62 |
+
data["pixel_values"] = torch.stack(
|
63 |
+
[self.process_img(img) for img in images]
|
64 |
)
|
65 |
|
66 |
+
if texts is not None:
|
67 |
+
data["texts"] = texts
|
68 |
+
|
69 |
+
return BatchFeature(data=data)
|
|
|
|