Yuanhao Zhai commited on
Commit
60b5ed2
1 Parent(s): 35a188c

add resize note

Browse files
Files changed (2) hide show
  1. app.py +13 -0
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from pathlib import Path
2
 
 
3
  import gradio as gr
4
  import numpy as np
5
  import torch
@@ -19,8 +20,17 @@ def greet(input_image):
19
  with torch.no_grad():
20
  image = input_image
21
  image = np.array(image)
 
 
 
 
 
 
22
  dsm_image = torch.from_numpy(image).permute(2, 0, 1)
 
23
  image_size = image.shape[:2]
 
 
24
  image = img_to_tensor(
25
  image,
26
  normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
@@ -36,6 +46,9 @@ def greet(input_image):
36
  f"No manipulation found (manipulation probability {pred:.2f})."
37
  )
38
 
 
 
 
39
  overlay = draw_segmentation_masks(
40
  dsm_image, masks=out_map[0, ...] > opt.mask_threshold
41
  )
 
1
  from pathlib import Path
2
 
3
+ import albumentations as A
4
  import gradio as gr
5
  import numpy as np
6
  import torch
 
20
  with torch.no_grad():
21
  image = input_image
22
  image = np.array(image)
23
+ h, w = image.shape[:2]
24
+ if max(h, w) > 1024:
25
+ transform = A.LongestMaxSize(1024)
26
+ else:
27
+ transform = None
28
+
29
  dsm_image = torch.from_numpy(image).permute(2, 0, 1)
30
+
31
  image_size = image.shape[:2]
32
+ if transform is not None:
33
+ image = transform(image=image)["image"]
34
  image = img_to_tensor(
35
  image,
36
  normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
 
46
  f"No manipulation found (manipulation probability {pred:.2f})."
47
  )
48
 
49
+ if transform is not None:
50
+ output_string += f"\nNote: Image was too large ({h}, {w}) and was resized to fit the model, which may decrease accuracy. We recommend image size smaller than 1024x1024."
51
+
52
  overlay = draw_segmentation_masks(
53
  dsm_image, masks=out_map[0, ...] > opt.mask_threshold
54
  )
requirements.txt CHANGED
@@ -25,3 +25,4 @@ termcolor==2.4.0
25
  timm==0.9.12
26
  tqdm
27
  markupsafe==2.0.1
 
 
25
  timm==0.9.12
26
  tqdm
27
  markupsafe==2.0.1
28
+ gradio