Update pintar.py
Browse files
pintar.py
CHANGED
@@ -29,6 +29,8 @@ def Normalize(inputs):
|
|
29 |
l = inputs[:, :, 0:1]
|
30 |
ab = inputs[:, :, 1:3]
|
31 |
l = l - 50
|
|
|
|
|
32 |
lab = np.concatenate((l, ab), 2)
|
33 |
return lab.astype('float32')
|
34 |
|
@@ -77,12 +79,12 @@ if __name__ == "__main__":
|
|
77 |
img_lab = img_lab.to(device).unsqueeze(0)
|
78 |
|
79 |
with torch.no_grad():
|
80 |
-
img_resize = F.interpolate(img_lab
|
81 |
-
img_L_resize = F.interpolate(img_resize[:, :1, :, :]
|
82 |
|
83 |
color_vector = colorEncoder(img_resize)
|
84 |
fake_ab = colorUNet((img_L_resize, color_vector))
|
85 |
-
fake_ab = F.interpolate(fake_ab, size=(img.shape[0], img.shape[1]), mode='bilinear',
|
86 |
|
87 |
fake_img = torch.cat((img_lab[:, :1, :, :], fake_ab), 1)
|
88 |
fake_img = Lab2RGB_out(fake_img)
|
|
|
29 |
l = inputs[:, :, 0:1]
|
30 |
ab = inputs[:, :, 1:3]
|
31 |
l = l - 50
|
32 |
+
l = l / 50 # Normalizar L al rango [-1, 1]
|
33 |
+
ab = ab / 110 # Normalizar ab al rango [-1, 1]
|
34 |
lab = np.concatenate((l, ab), 2)
|
35 |
return lab.astype('float32')
|
36 |
|
|
|
79 |
img_lab = img_lab.to(device).unsqueeze(0)
|
80 |
|
81 |
with torch.no_grad():
|
82 |
+
img_resize = F.interpolate(img_lab, size=(256, 256), mode='bilinear', align_corners=False)
|
83 |
+
img_L_resize = F.interpolate(img_resize[:, :1, :, :], size=(256, 256), mode='bilinear', align_corners=False)
|
84 |
|
85 |
color_vector = colorEncoder(img_resize)
|
86 |
fake_ab = colorUNet((img_L_resize, color_vector))
|
87 |
+
fake_ab = F.interpolate(fake_ab, size=(img.shape[0], img.shape[1]), mode='bilinear', align_corners=False)
|
88 |
|
89 |
fake_img = torch.cat((img_lab[:, :1, :, :], fake_ab), 1)
|
90 |
fake_img = Lab2RGB_out(fake_img)
|