duongttr committed
Commit da5e78e · 1 Parent(s): 2b79d08

Update src/inference.py

Files changed (1)
  1. src/inference.py  +3  -7
src/inference.py CHANGED

@@ -17,7 +17,6 @@ from src.utils import (
     tensor_lab2rgb
 )
 import numpy as np
-from tqdm import tqdm
 
 class SwinTExCo:
     def __init__(self, weights_path, swin_backbone='swinv2-cr-t-224', device=None):
@@ -62,13 +61,13 @@ class SwinTExCo:
                                                              size=(H,W),
                                                              mode="bilinear",
                                                              align_corners=False)
-        large_IA_l = torch.cat((large_IA_l, large_current_ab_predict.cpu()), dim=1)
+        large_IA_l = torch.cat((large_IA_l, large_current_ab_predict), dim=1)
         large_current_rgb_predict = tensor_lab2rgb(large_IA_l)
-        return large_current_rgb_predict
+        return large_current_rgb_predict.cpu()
 
     def __proccess_sample(self, curr_frame, I_last_lab_predict, I_reference_lab, features_B):
         large_IA_lab = ToTensor()(RGB2Lab()(curr_frame)).unsqueeze(0)
-        large_IA_l = large_IA_lab[:, 0:1, :, :]
+        large_IA_l = large_IA_lab[:, 0:1, :, :].to(self.device)
 
         IA_lab = self.processor(curr_frame)
         IA_lab = IA_lab.unsqueeze(0).to(self.device)
@@ -113,9 +112,7 @@ class SwinTExCo:
         I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(self.device)
         features_B = self.embed_net(I_reference_rgb)
 
-        #PBAR = tqdm(total=int(video.get(cv2.CAP_PROP_FRAME_COUNT)), desc="Colorizing video", unit="frame")
         while video.isOpened():
-            #PBAR.update(1)
             ret, curr_frame = video.read()
 
             if not ret:
@@ -130,7 +127,6 @@ class SwinTExCo:
 
             yield IA_predict_rgb
 
-        #PBAR.close()
         video.release()
 
     def predict_image(self, image, ref_image):
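
Taken together, the three code changes keep the Lab concatenation and colour conversion on the model's device and defer the copy to host memory until the value is returned, alongside removing the unused tqdm progress bar. Below is a minimal sketch of that device-placement pattern, not the repository's actual code: upsample_and_merge is a hypothetical helper, and lab_to_rgb stands in for the repo's tensor_lab2rgb utility.

    import torch
    import torch.nn.functional as F

    def upsample_and_merge(large_l, small_ab, size, lab_to_rgb, device="cuda"):
        # Keep the L channel on the same device as the predicted ab channels.
        large_l = large_l.to(device)
        # Upsample the low-resolution ab prediction on the device.
        large_ab = F.interpolate(small_ab.to(device), size=size,
                                 mode="bilinear", align_corners=False)
        # Concatenate L and ab, then convert Lab -> RGB, still on the device.
        lab = torch.cat((large_l, large_ab), dim=1)
        rgb = lab_to_rgb(lab)
        # Single device-to-host copy, done once at the return.
        return rgb.cpu()

Doing the concat and conversion on the GPU and calling .cpu() only on the final RGB tensor means one device-to-host transfer per frame instead of pulling intermediates back to the host mid-pipeline.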