neverix commited on
Commit
a40667a
1 Parent(s): 2b04aed

Fix bug? (?)

Browse files
Files changed (1) hide show
  1. data_loader.py +2 -2
data_loader.py CHANGED
@@ -226,7 +226,7 @@ class FileDataset(Dataset):
226
  if "labels" in sample:
227
  # return UDP as 4chn XYZV float tensor
228
  sample["labels"] = torch.from_numpy(
229
- sample["labels"].transpose((2, 0, 1)))
230
  assert (sample["labels"].dtype == torch.float32)
231
 
232
  if "image_np" in sample:
@@ -270,4 +270,4 @@ class FileDataset(Dataset):
270
  "character_masks": character_masks
271
  })
272
  # do not make fake labels in inference
273
- return sample
 
226
  if "labels" in sample:
227
  # return UDP as 4chn XYZV float tensor
228
  sample["labels"] = torch.from_numpy(
229
+ sample["labels"].transpose((2, 0, 1)).astype(np.float32))
230
  assert (sample["labels"].dtype == torch.float32)
231
 
232
  if "image_np" in sample:
 
270
  "character_masks": character_masks
271
  })
272
  # do not make fake labels in inference
273
+ return sample