add resize note
Browse files- app.py +13 -0
- requirements.txt +1 -0
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from pathlib import Path
|
2 |
|
|
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
5 |
import torch
|
@@ -19,8 +20,17 @@ def greet(input_image):
|
|
19 |
with torch.no_grad():
|
20 |
image = input_image
|
21 |
image = np.array(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
dsm_image = torch.from_numpy(image).permute(2, 0, 1)
|
|
|
23 |
image_size = image.shape[:2]
|
|
|
|
|
24 |
image = img_to_tensor(
|
25 |
image,
|
26 |
normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
|
@@ -36,6 +46,9 @@ def greet(input_image):
|
|
36 |
f"No manipulation found (manipulation probability {pred:.2f})."
|
37 |
)
|
38 |
|
|
|
|
|
|
|
39 |
overlay = draw_segmentation_masks(
|
40 |
dsm_image, masks=out_map[0, ...] > opt.mask_threshold
|
41 |
)
|
|
|
1 |
from pathlib import Path
|
2 |
|
3 |
+
import albumentations as A
|
4 |
import gradio as gr
|
5 |
import numpy as np
|
6 |
import torch
|
|
|
20 |
with torch.no_grad():
|
21 |
image = input_image
|
22 |
image = np.array(image)
|
23 |
+
h, w = image.shape[:2]
|
24 |
+
if max(h, w) > 1024:
|
25 |
+
transform = A.LongestMaxSize(1024)
|
26 |
+
else:
|
27 |
+
transform = None
|
28 |
+
|
29 |
dsm_image = torch.from_numpy(image).permute(2, 0, 1)
|
30 |
+
|
31 |
image_size = image.shape[:2]
|
32 |
+
if transform is not None:
|
33 |
+
image = transform(image=image)["image"]
|
34 |
image = img_to_tensor(
|
35 |
image,
|
36 |
normalize={"mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD},
|
|
|
46 |
f"No manipulation found (manipulation probability {pred:.2f})."
|
47 |
)
|
48 |
|
49 |
+
if transform is not None:
|
50 |
+
output_string += f"\nNote: Image was too large ({h}, {w}) and was resized to fit the model, which may decrease accuracy. We recommend image size smaller than 1024x1024."
|
51 |
+
|
52 |
overlay = draw_segmentation_masks(
|
53 |
dsm_image, masks=out_map[0, ...] > opt.mask_threshold
|
54 |
)
|
requirements.txt
CHANGED
@@ -25,3 +25,4 @@ termcolor==2.4.0
|
|
25 |
timm==0.9.12
|
26 |
tqdm
|
27 |
markupsafe==2.0.1
|
|
|
|
25 |
timm==0.9.12
|
26 |
tqdm
|
27 |
markupsafe==2.0.1
|
28 |
+
gradio
|