Spaces:
Sleeping
Sleeping
cocktailpeanut
commited on
Commit
•
0daeefc
1
Parent(s):
d5b9c19
mps
Browse files
app.py
CHANGED
@@ -38,7 +38,14 @@ with open(os.path.join(CONFIG), "r") as f:
|
|
38 |
|
39 |
cfg = dict2namespace(config)
|
40 |
|
41 |
-
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
model = instructir.create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks,
|
43 |
middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks, txtdim=cfg.model.textdim)
|
44 |
model = model.to(device)
|
@@ -148,4 +155,4 @@ demo = gr.Interface(
|
|
148 |
)
|
149 |
|
150 |
if __name__ == "__main__":
|
151 |
-
demo.launch()
|
|
|
38 |
|
39 |
cfg = dict2namespace(config)
|
40 |
|
41 |
+
#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
42 |
+
if torch.backends.mps.is_available():
|
43 |
+
device = "mps"
|
44 |
+
torch_dtype = torch.float32
|
45 |
+
elif torch.cuda.is_available():
|
46 |
+
device = "cuda"
|
47 |
+
else:
|
48 |
+
device = "cpu"
|
49 |
model = instructir.create_model(input_channels =cfg.model.in_ch, width=cfg.model.width, enc_blks = cfg.model.enc_blks,
|
50 |
middle_blk_num = cfg.model.middle_blk_num, dec_blks = cfg.model.dec_blks, txtdim=cfg.model.textdim)
|
51 |
model = model.to(device)
|
|
|
155 |
)
|
156 |
|
157 |
if __name__ == "__main__":
|
158 |
+
demo.launch()
|