zamborg commited on
Commit
09da12b
1 Parent(s): b0943d9

show resize code update

Browse files
Files changed (1) hide show
  1. model.py +7 -8
model.py CHANGED
@@ -26,13 +26,7 @@ class ImageLoader():
26
  torchvision.transforms.Resize(256),
27
  torchvision.transforms.CenterCrop(224),
28
  torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225))])
29
- self.show_manip = torchvision.transforms.Compose(
30
- [
31
- torchvision.transforms.ToTensor(),
32
- torchvision.transforms.Resize(500, max_size=500),
33
- torchvision.transforms.ToPILImage()
34
- ]
35
- )
36
 
37
  def load(self, im_path):
38
  im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
@@ -51,7 +45,12 @@ class ImageLoader():
51
  return text.lower()
52
 
53
  def show_resize(self, image):
54
- return self.show_manip(image)
 
 
 
 
 
55
 
56
 
57
  class VirTexModel():
 
26
  torchvision.transforms.Resize(256),
27
  torchvision.transforms.CenterCrop(224),
28
  torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225))])
29
+ self.show_size=500
 
 
 
 
 
 
30
 
31
  def load(self, im_path):
32
  im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
 
45
  return text.lower()
46
 
47
  def show_resize(self, image):
48
+ # ugh we need to do this manually cuz this is pytorch==0.8 not 1.9 lol
49
+ image = torchvision.transforms.functional.to_tensor(image)
50
+ x,y = image.shape[-2:]
51
+ ratio = float(self.show_size/max((x,y)))
52
+ image = torchvision.transforms.functional.resize(image, [int(x * ratio), int(y * ratio)])
53
+ return torch.transforms.functional.to_pil_image(image)
54
 
55
 
56
  class VirTexModel():