Spaces:
Running
Running
TedYeh
commited on
Commit
•
bf809f1
1
Parent(s):
a38dfb6
update predictor
Browse files- predictor.py +2 -2
predictor.py
CHANGED
@@ -198,7 +198,7 @@ def evaluation(model, epoch, device, dataloaders):
|
|
198 |
print(preds)
|
199 |
|
200 |
def inference(inp_img, classes = ['big', 'small'], epoch = 6):
|
201 |
-
device = torch.device("cuda")
|
202 |
translator= Translator(to_lang="zh-TW")
|
203 |
|
204 |
model = CUPredictor()
|
@@ -218,7 +218,7 @@ def inference(inp_img, classes = ['big', 'small'], epoch = 6):
|
|
218 |
image_tensor = trans(inp_img)
|
219 |
image_tensor = image_tensor.unsqueeze(0)
|
220 |
with torch.no_grad():
|
221 |
-
inputs = image_tensor
|
222 |
outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
|
223 |
_, preds = torch.max(outputs_c, 1)
|
224 |
idx = preds.numpy()[0]
|
|
|
198 |
print(preds)
|
199 |
|
200 |
def inference(inp_img, classes = ['big', 'small'], epoch = 6):
|
201 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
202 |
translator= Translator(to_lang="zh-TW")
|
203 |
|
204 |
model = CUPredictor()
|
|
|
218 |
image_tensor = trans(inp_img)
|
219 |
image_tensor = image_tensor.unsqueeze(0)
|
220 |
with torch.no_grad():
|
221 |
+
inputs = image_tensor.to(device)
|
222 |
outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
|
223 |
_, preds = torch.max(outputs_c, 1)
|
224 |
idx = preds.numpy()[0]
|