tpob commited on
Commit
4f17a92
1 Parent(s): dbfac41

fix: incorrect condition control flow

Browse files

this addresses the issue when using GeneformerPrecollator if not using numpy array as tensor backend, a ValueError is raise like `type of tensor([ 2943, 24469, 12039, ..., 4506, 13856, 10511]) unknown: <class 'torch.Tensor'>. Should be one of a python, numpy, pytorch or tensorflow object.`. I believe this is a bug, as I am clearly using a pytorch tensor. And in huggingface `transformers` library, [similar code wrote here](https://github.com/huggingface/transformers/blob/21dc5859421cf0d7d82d374b10f533611745a8c5/src/transformers/feature_extraction_sequence_utils.py#L166)

Files changed (1) hide show
  1. geneformer/pretrainer.py +1 -1
geneformer/pretrainer.py CHANGED
@@ -381,7 +381,7 @@ class GeneformerPreCollator(SpecialTokensMixin):
381
  return_tensors = "tf" if return_tensors is None else return_tensors
382
  elif is_torch_available() and _is_torch(first_element):
383
  return_tensors = "pt" if return_tensors is None else return_tensors
384
- if isinstance(first_element, np.ndarray):
385
  return_tensors = "np" if return_tensors is None else return_tensors
386
  else:
387
  raise ValueError(
 
381
  return_tensors = "tf" if return_tensors is None else return_tensors
382
  elif is_torch_available() and _is_torch(first_element):
383
  return_tensors = "pt" if return_tensors is None else return_tensors
384
+ elif isinstance(first_element, np.ndarray):
385
  return_tensors = "np" if return_tensors is None else return_tensors
386
  else:
387
  raise ValueError(