Spaces:
Runtime error
Runtime error
show resize code update
Browse files
model.py
CHANGED
@@ -26,13 +26,7 @@ class ImageLoader():
|
|
26 |
torchvision.transforms.Resize(256),
|
27 |
torchvision.transforms.CenterCrop(224),
|
28 |
torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225))])
|
29 |
-
self.
|
30 |
-
[
|
31 |
-
torchvision.transforms.ToTensor(),
|
32 |
-
torchvision.transforms.Resize(500, max_size=500),
|
33 |
-
torchvision.transforms.ToPILImage()
|
34 |
-
]
|
35 |
-
)
|
36 |
|
37 |
def load(self, im_path):
|
38 |
im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
|
@@ -51,7 +45,12 @@ class ImageLoader():
|
|
51 |
return text.lower()
|
52 |
|
53 |
def show_resize(self, image):
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
|
57 |
class VirTexModel():
|
|
|
26 |
torchvision.transforms.Resize(256),
|
27 |
torchvision.transforms.CenterCrop(224),
|
28 |
torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225))])
|
29 |
+
self.show_size=500
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
def load(self, im_path):
|
32 |
im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
|
|
|
45 |
return text.lower()
|
46 |
|
47 |
def show_resize(self, image):
|
48 |
+
# ugh we need to do this manually cuz this is pytorch==0.8 not 1.9 lol
|
49 |
+
image = torchvision.transforms.functional.to_tensor(image)
|
50 |
+
x,y = image.shape[-2:]
|
51 |
+
ratio = float(self.show_size/max((x,y)))
|
52 |
+
image = torchvision.transforms.functional.resize(image, [int(x * ratio), int(y * ratio)])
|
53 |
+
return torch.transforms.functional.to_pil_image(image)
|
54 |
|
55 |
|
56 |
class VirTexModel():
|