artelabsuper commited on
Commit
459c031
1 Parent(s): 7f268fe

add local normalization

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. test.py +1 -0
app.py CHANGED
@@ -35,6 +35,7 @@ def predict(input_image, model_name):
35
  pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
36
  # transform image to torch and do preprocessing
37
  torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0).to(DEVICE)
 
38
  # model predict
39
  with torch.no_grad():
40
  output = generators[MODELS_TYPE.index(model_name)](torch_img)
35
  pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
36
  # transform image to torch and do preprocessing
37
  torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0).to(DEVICE)
38
+ torch_img = (torch_img - torch.min(torch_img)) / (torch.max(torch_img) - torch.min(torch_img))
39
  # model predict
40
  with torch.no_grad():
41
  output = generators[MODELS_TYPE.index(model_name)](torch_img)
test.py CHANGED
@@ -40,6 +40,7 @@ preprocess = transforms.Compose([
40
  ])
41
  input_img = Image.open('demo_imgs/fake.jpg')
42
  torch_img = preprocess(input_img).to(DEVICE).unsqueeze(0).to(DEVICE)
 
43
  with torch.no_grad():
44
  output = generator(torch_img)
45
  sr, sr_dem_selected = output[0], output[1]
40
  ])
41
  input_img = Image.open('demo_imgs/fake.jpg')
42
  torch_img = preprocess(input_img).to(DEVICE).unsqueeze(0).to(DEVICE)
43
+ torch_img = (torch_img - torch.min(torch_img)) / (torch.max(torch_img) - torch.min(torch_img))
44
  with torch.no_grad():
45
  output = generator(torch_img)
46
  sr, sr_dem_selected = output[0], output[1]