jcarnero commited on
Commit
5d460ce
·
1 Parent(s): 96998e8

tests for transformation. centercrop transform failing.

Browse files
.vscode/settings.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "python.testing.pytestArgs": [
3
+ "tests"
4
+ ],
5
+ "python.testing.unittestEnabled": false,
6
+ "python.testing.pytestEnabled": true,
7
+ "python.analysis.inlayHints.pytestParameters": true
8
+ }
deployment/transforms.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+ import torchvision.transforms as tvtfms
3
+
4
+
5
+ def CenterCropPad(size: tuple[Literal[460], Literal[460]]):
6
+ return tvtfms.CenterCrop(size)
tests/__init__.py ADDED
File without changes
tests/test_transforms.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+
3
+ from pathlib import Path
4
+ from typing import List
5
+ import numpy as np
6
+ from PIL import Image
7
+ from fastai.vision.data import PILImage
8
+ import fastai.vision.augment as fastai_aug
9
+
10
+ from deployment.transforms import CenterCropPad
11
+
12
+ DATA_PATH = "data/kaggle/200-bird-species-with-11788-images"
13
+
14
+
15
+ def get_birds_images(path: Path) -> List[str]:
16
+ with open(path / "images.txt", "r") as file:
17
+ lines = [
18
+ path.resolve() / "images" / line.strip().split()[1]
19
+ for line in file.readlines()
20
+ ]
21
+ return lines
22
+
23
+
24
+ class TestTransforms:
25
+ im_idx = 0
26
+
27
+ @pytest.fixture
28
+ def img_paths(self) -> List[str]:
29
+ path = Path(DATA_PATH) / "CUB_200_2011"
30
+ return get_birds_images(path.resolve())
31
+
32
+ @pytest.fixture
33
+ def im_fastai(self, img_paths) -> PILImage:
34
+ fname = img_paths[self.im_idx]
35
+ return PILImage.create(fname)
36
+
37
+ @pytest.fixture
38
+ def im_pil(self, img_paths) -> Image:
39
+ fname = img_paths[self.im_idx]
40
+ return Image.open(fname)
41
+
42
+ def testImageFastaiEqualsPillow(self, im_fastai, im_pil):
43
+ assert (np.array(im_fastai) == np.array(im_pil)).all()
44
+
45
+ def testCropPadFastaiEqualsTorch(self, im_fastai, im_pil):
46
+ crop_fastai = fastai_aug.CropPad((460, 460))
47
+ crop_torch = CenterCropPad((460, 460))
48
+
49
+ assert (np.array(crop_fastai(im_fastai)) == np.array(crop_torch(im_pil))).all()