nickyreinert-vml
commited on
Commit
•
a494f64
1
Parent(s):
45e93a3
fixing minor bug
Browse files- helpers.py +3 -2
helpers.py
CHANGED
@@ -47,14 +47,15 @@ def get_pipeline(config):
|
|
47 |
config["model"],
|
48 |
use_safetensors = get_bool(config["use_safetensors"]),
|
49 |
torch_dtype = get_data_type(config["data_type"]),
|
50 |
-
variant = get_variant(config["variant"])).to(config["device"])
|
51 |
else:
|
52 |
from diffusers import DiffusionPipeline
|
53 |
pipeline = DiffusionPipeline.from_pretrained(
|
54 |
config["model"],
|
55 |
use_safetensors = get_bool(config["use_safetensors"]),
|
56 |
torch_dtype = get_data_type(config["data_type"]),
|
57 |
-
variant = get_variant(config["variant"])
|
|
|
58 |
|
59 |
return pipeline
|
60 |
|
|
|
47 |
config["model"],
|
48 |
use_safetensors = get_bool(config["use_safetensors"]),
|
49 |
torch_dtype = get_data_type(config["data_type"]),
|
50 |
+
variant = get_variant(config["variant"])).to(config["device"])
|
51 |
else:
|
52 |
from diffusers import DiffusionPipeline
|
53 |
pipeline = DiffusionPipeline.from_pretrained(
|
54 |
config["model"],
|
55 |
use_safetensors = get_bool(config["use_safetensors"]),
|
56 |
torch_dtype = get_data_type(config["data_type"]),
|
57 |
+
variant = get_variant(config["variant"])
|
58 |
+
).to(config["device"])
|
59 |
|
60 |
return pipeline
|
61 |
|