natexcvi commited on
Commit
cad80c7
1 Parent(s): 45dae24

Fix fecnet

Browse files
Files changed (1) hide show
  1. model/fecnet.py +4 -3
model/fecnet.py CHANGED
@@ -53,10 +53,10 @@ class FECNetModel:
53
  model_weights = torch.load(model_path, map_location=torch.device("cpu"))
54
  model.load_state_dict(model_weights)
55
  model.eval()
56
- return model.double()
57
 
58
  def predict(self, image: np.ndarray):
59
- pred = self.model.forward(image)
60
  return pred
61
 
62
  def distance(a, b):
@@ -64,9 +64,10 @@ class FECNetModel:
64
 
65
  def embed_image(self, image) -> np.ndarray:
66
  image = cv2.imdecode(image, cv2.IMREAD_COLOR)
 
67
  image = cv2.resize(image, (224, 224))
68
  image = np.transpose(image, (2, 0, 1))
69
  image = np.expand_dims(image, axis=0)
70
- image = torch.from_numpy(image).double()
71
  pred = self.predict(image)
72
  return pred.detach().numpy()
 
53
  model_weights = torch.load(model_path, map_location=torch.device("cpu"))
54
  model.load_state_dict(model_weights)
55
  model.eval()
56
+ return model
57
 
58
  def predict(self, image: np.ndarray):
59
+ pred = self.model(image)
60
  return pred
61
 
62
  def distance(a, b):
 
64
 
65
  def embed_image(self, image) -> np.ndarray:
66
  image = cv2.imdecode(image, cv2.IMREAD_COLOR)
67
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
68
  image = cv2.resize(image, (224, 224))
69
  image = np.transpose(image, (2, 0, 1))
70
  image = np.expand_dims(image, axis=0)
71
+ image = torch.from_numpy(image.astype(np.float32))
72
  pred = self.predict(image)
73
  return pred.detach().numpy()