Spaces:
Sleeping
Sleeping
Update src/inference.py
Browse files- src/inference.py +3 -7
src/inference.py
CHANGED
@@ -17,7 +17,6 @@ from src.utils import (
|
|
17 |
tensor_lab2rgb
|
18 |
)
|
19 |
import numpy as np
|
20 |
-
from tqdm import tqdm
|
21 |
|
22 |
class SwinTExCo:
|
23 |
def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
|
@@ -62,13 +61,13 @@ class SwinTExCo:
|
|
62 |
size=(H,W),
|
63 |
mode="bilinear",
|
64 |
align_corners=False)
|
65 |
-
large_IA_l = torch.cat((large_IA_l, large_current_ab_predict
|
66 |
large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
|
67 |
-
return large_current_rgb_predict
|
68 |
|
69 |
def __proccess_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
|
70 |
large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
|
71 |
-
large_IA_l = large_IA_lab[:, 0:1, :, :]
|
72 |
|
73 |
IA_lab = self.processor(curr_frame)
|
74 |
IA_lab = IA_lab.unsqueeze(0).to(self.device)
|
@@ -113,9 +112,7 @@ class SwinTExCo:
|
|
113 |
I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
|
114 |
features_B = self.embed_net(I_reference_rgb)
|
115 |
|
116 |
-
#PBAR = tqdm(total=int(video.get(cv2.CAP_PROP_FRAME_COUNT)), desc="Colorizing video", unit="frame")
|
117 |
while video.isOpened():
|
118 |
-
#PBAR.update(1)
|
119 |
ret, curr_frame = video.read()
|
120 |
|
121 |
if not ret:
|
@@ -130,7 +127,6 @@ class SwinTExCo:
|
|
130 |
|
131 |
yield IA_predict_rgb
|
132 |
|
133 |
-
#PBAR.close()
|
134 |
video.release()
|
135 |
|
136 |
def predict_image(self, image, ref_image):
|
|
|
17 |
tensor_lab2rgb
|
18 |
)
|
19 |
import numpy as np
|
|
|
20 |
|
21 |
class SwinTExCo:
|
22 |
def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
|
|
|
61 |
size=(H,W),
|
62 |
mode="bilinear",
|
63 |
align_corners=False)
|
64 |
+
large_IA_l = torch.cat((large_IA_l, large_current_ab_predict), dim=1)
|
65 |
large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
|
66 |
+
return large_current_rgb_predict.cpu()
|
67 |
|
68 |
def __proccess_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
|
69 |
large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
|
70 |
+
large_IA_l = large_IA_lab[:, 0:1, :, :].to(self.device)
|
71 |
|
72 |
IA_lab = self.processor(curr_frame)
|
73 |
IA_lab = IA_lab.unsqueeze(0).to(self.device)
|
|
|
112 |
I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
|
113 |
features_B = self.embed_net(I_reference_rgb)
|
114 |
|
|
|
115 |
while video.isOpened():
|
|
|
116 |
ret, curr_frame = video.read()
|
117 |
|
118 |
if not ret:
|
|
|
127 |
|
128 |
yield IA_predict_rgb
|
129 |
|
|
|
130 |
video.release()
|
131 |
|
132 |
def predict_image(self, image, ref_image):
|