zamborg commited on
Commit
f307fe5
1 Parent(s): 09c5885

fixed normalization

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. model.py +8 -3
app.py CHANGED
@@ -54,7 +54,7 @@ else:
54
 
55
  image_dict = imageLoader.transform(image)
56
 
57
- image = imageLoader.to_image(image_dict["image"].squeeze(0))
58
 
59
  show = st.image(image)
60
  show.image(image, "Your Image")
54
 
55
  image_dict = imageLoader.transform(image)
56
 
57
+ # image = imageLoader.to_image(image_dict["image"].squeeze(0))
58
 
59
  show = st.image(image)
60
  show.image(image, "Your Image")
model.py CHANGED
@@ -11,7 +11,7 @@ import torchvision
11
  import wordsegment as ws
12
 
13
  from virtex.config import Config
14
- from virtex.factories import TokenizerFactory, PretrainingModelFactory
15
  from virtex.utils.checkpointing import CheckpointManager
16
 
17
  CONFIG_PATH = "config.yaml"
@@ -21,12 +21,17 @@ SAMPLES_PATH = "./samples/*.jpg"
21
 
22
  class ImageLoader():
23
  def __init__(self):
24
- self.transformer = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
25
- torchvision.transforms.CenterCrop(224),
 
26
  torchvision.transforms.ToTensor()])
27
  def load(self, im_path):
28
  im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
29
  return {"image": im}
 
 
 
 
30
  def transform(self, image):
31
  im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
32
  return {"image": im}
11
  import wordsegment as ws
12
 
13
  from virtex.config import Config
14
+ from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
15
  from virtex.utils.checkpointing import CheckpointManager
16
 
17
  CONFIG_PATH = "config.yaml"
21
 
22
  class ImageLoader():
23
  def __init__(self):
24
+ self.transformer = torchvision.transforms.Compose([ImageTransformsFactory.create("smallest_resize"),
25
+ ImageTransformsFactory.create("center_crop"),
26
+ ImageTransformsFactory.create("normalize"),
27
  torchvision.transforms.ToTensor()])
28
  def load(self, im_path):
29
  im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
30
  return {"image": im}
31
+
32
+ def raw_load(self, im_path):
33
+ im = torch.FloatTensor(Image.open(im_path)).unsqueeze(0)
34
+ return {"image": im}
35
  def transform(self, image):
36
  im = torch.FloatTensor(self.transformer(image)).unsqueeze(0)
37
  return {"image": im}