thak123 commited on
Commit
337c718
1 Parent(s): dfe2e18

Update engine.py

Browse files
Files changed (1) hide show
  1. engine.py +3 -1
engine.py CHANGED
@@ -108,7 +108,9 @@ def predict_fn(data_loader, model, device, extract_features=False):
108
  mask=mask,
109
  token_type_ids=token_type_ids
110
  ).cpu().detach().numpy().tolist())
111
-
 
 
112
  fin_outputs.extend(torch.argmax(
113
  outputs, dim=1).cpu().detach().numpy().tolist())
114
 
 
108
  mask=mask,
109
  token_type_ids=token_type_ids
110
  ).cpu().detach().numpy().tolist())
111
+ print("1",torch.argmax(outputs, dim=1))
112
+ print("2",torch.argmax(outputs, dim=1).cpu())
113
+ print("3",torch.argmax(outputs, dim=1).cpu().numpy())
114
  fin_outputs.extend(torch.argmax(
115
  outputs, dim=1).cpu().detach().numpy().tolist())
116