akiyamasho commited on
Commit
91a3469
1 Parent(s): 0055c8e

MAINT: logs for cuda

Browse files
Files changed (2) hide show
  1. app.py +2 -0
  2. network/Transformer.py +1 -1
app.py CHANGED
@@ -91,8 +91,10 @@ def inference(img, style):
91
  input_image = -1 + 2 * input_image
92
 
93
  if enable_gpu:
 
94
  input_image = Variable(input_image).cuda()
95
  else:
 
96
  input_image = Variable(input_image).float()
97
 
98
  # forward
91
  input_image = -1 + 2 * input_image
92
 
93
  if enable_gpu:
94
+ logger.info(f"CUDA found. Using GPU.")
95
  input_image = Variable(input_image).cuda()
96
  else:
97
+ logger.info(f"CUDA not found. Using CPU.")
98
  input_image = Variable(input_image).float()
99
 
100
  # forward
network/Transformer.py CHANGED
@@ -146,7 +146,7 @@ class Transformer(nn.Module):
146
 
147
  y = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(y))))
148
  y = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(y))))
149
- y = F.tanh(self.deconv03_1(self.refpad12_1(y)))
150
 
151
  return y
152
 
146
 
147
  y = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(y))))
148
  y = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(y))))
149
+ y = torch.tanh(self.deconv03_1(self.refpad12_1(y)))
150
 
151
  return y
152