jadechoghari commited on
Commit
e662c48
·
verified ·
1 Parent(s): 57d2891

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +1 -1
pipeline.py CHANGED
@@ -84,7 +84,7 @@ class MARModel(DiffusionPipeline):
84
  cfg_schedule = kwargs.get("cfg_schedule", "constant")
85
  temperature = kwargs.get("temperature", 1.0)
86
  class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])
87
- class_labels.long().to(device)
88
 
89
  # generate the tokens and images
90
  with torch.cuda.amp.autocast():
 
84
  cfg_schedule = kwargs.get("cfg_schedule", "constant")
85
  temperature = kwargs.get("temperature", 1.0)
86
  class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])
87
+ class_labels = torch.Tensor(class_labels).long().to(device)
88
 
89
  # generate the tokens and images
90
  with torch.cuda.amp.autocast():