Spaces:
Runtime error
Runtime error
RamAnanth1
committed on
Commit
•
9788c55
1
Parent(s):
09808dd
Create utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torchvision.transforms as T
|
5 |
+
|
6 |
+
totensor = T.ToTensor()
|
7 |
+
topil = T.ToPILImage()
|
8 |
+
|
9 |
+
def recover_image(image, init_image, mask, background=False):
    """Composite two PIL images according to a mask and return a PIL image.

    By default (background=False) the masked region is taken from *image*
    and the unmasked region from *init_image*; with background=True the
    roles of the two images are swapped.

    Args:
        image: PIL image supplying one side of the composite.
        init_image: PIL image supplying the other side.
        mask: PIL mask image; after ToTensor it is expected in [0, 1].
        background: when True, keep *init_image* where the mask is set.

    Returns:
        The blended result as a PIL image.
    """
    img_t = totensor(image)
    mask_t = totensor(mask)
    init_t = totensor(init_image)
    # Select which tensor the mask keeps vs. which fills the remainder.
    kept, filler = (init_t, img_t) if background else (img_t, init_t)
    blended = mask_t * kept + (1 - mask_t) * filler
    return topil(blended)
18 |
+
|
19 |
+
def preprocess(image):
    """Convert a PIL image to a float32 torch tensor scaled to [-1, 1].

    The image is first downscaled so its width and height are exact
    multiples of 32 (snapping each dimension down), then converted to a
    tensor of shape (1, C, H, W).

    Args:
        image: input PIL image.

    Returns:
        torch.FloatTensor of shape (1, C, H, W) with values in [-1, 1].
    """
    # Snap both dimensions down to the nearest multiple of 32.
    width, height = (dim - dim % 32 for dim in image.size)
    resized = image.resize((width, height), resample=Image.LANCZOS)
    arr = np.array(resized).astype(np.float32) / 255.0
    arr = arr[None].transpose(0, 3, 1, 2)  # HWC -> batched NCHW
    tensor = torch.from_numpy(arr)
    return 2.0 * tensor - 1.0
|
27 |
+
|
28 |
+
def prepare_mask_and_masked_image(image, mask):
    """Build the (mask, masked_image) tensor pair for inpainting.

    *image* is converted to a float32 tensor of shape (1, 3, H, W) scaled
    to [-1, 1]; *mask* is converted to a binarized (threshold 0.5) float32
    tensor of shape (1, 1, H, W). The masked image keeps only the pixels
    where the binary mask is 0 and zeroes out the rest.

    Args:
        image: input PIL image (converted to RGB).
        mask: PIL mask image (converted to grayscale).

    Returns:
        Tuple of (mask tensor, masked image tensor).
    """
    img_arr = np.array(image.convert("RGB"))
    img_arr = img_arr[None].transpose(0, 3, 1, 2)  # HWC -> batched NCHW
    img_t = torch.from_numpy(img_arr).to(dtype=torch.float32) / 127.5 - 1.0

    mask_arr = np.array(mask.convert("L")).astype(np.float32) / 255.0
    mask_arr = mask_arr[None, None]  # add batch and channel dimensions
    # Binarize at 0.5, matching the original in-place thresholding.
    mask_arr = np.where(mask_arr < 0.5, 0.0, 1.0).astype(np.float32)
    mask_t = torch.from_numpy(mask_arr)

    # Zero out the pixels the mask selects; keep the background pixels.
    masked_img = img_t * (mask_t < 0.5)

    return mask_t, masked_img
|
43 |
+
|
44 |
+
def prepare_image(image):
    """Convert a PIL image to a float32 CHW torch tensor scaled to [-1, 1].

    Args:
        image: input PIL image (converted to RGB).

    Returns:
        torch.FloatTensor of shape (C, H, W) with values in [-1, 1].
    """
    arr = np.array(image.convert("RGB"))[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(arr).to(dtype=torch.float32) / 127.5 - 1.0
    # Drop the batch dimension added above.
    return tensor[0]
|