Update pipeline.py
Browse files- pipeline.py +1 -1
pipeline.py
CHANGED
@@ -84,7 +84,7 @@ class MARModel(DiffusionPipeline):
|
|
84 |
cfg_schedule = kwargs.get("cfg_schedule", "constant")
|
85 |
temperature = kwargs.get("temperature", 1.0)
|
86 |
class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])
|
87 |
-
class_labels.long().to(device)
|
88 |
|
89 |
# generate the tokens and images
|
90 |
with torch.cuda.amp.autocast():
|
|
|
84 |
cfg_schedule = kwargs.get("cfg_schedule", "constant")
|
85 |
temperature = kwargs.get("temperature", 1.0)
|
86 |
class_labels = kwargs.get("class_labels", [207, 360, 388, 113, 355, 980, 323, 979])
|
87 |
+
class_labels = torch.Tensor(class_labels).long().to(device)
|
88 |
|
89 |
# generate the tokens and images
|
90 |
with torch.cuda.amp.autocast():
|