cocktailpeanut commited on
Commit
afffb74
1 Parent(s): 6e73b66
Files changed (1) hide show
  1. app.py +10 -1
app.py CHANGED
@@ -17,6 +17,14 @@ from src.pix2pix_turbo import Pix2Pix_Turbo
17
 
18
  model = Pix2Pix_Turbo("sketch_to_image_stochastic")
19
 
 
 
 
 
 
 
 
 
20
  style_list = [
21
  {
22
  "name": "No Style",
@@ -86,7 +94,8 @@ def run(image, prompt, prompt_template, style_name, seed, val_r):
86
  image_pil = TF.to_pil_image(image_t.to(torch.float32))
87
  print(f"r_val={val_r}, seed={seed}")
88
  with torch.no_grad():
89
- c_t = image_t.unsqueeze(0).cuda().float()
 
90
  torch.manual_seed(seed)
91
  B,C,H,W = c_t.shape
92
  noise = torch.randn((1,4,H//8, W//8), device=c_t.device)
 
17
 
18
  model = Pix2Pix_Turbo("sketch_to_image_stochastic")
19
 
20
+ if torch.backends.mps.is_available():
21
+ device = "mps"
22
+ #torch_dtype = torch.float32
23
+ elif torch.cuda.is_available():
24
+ device = "cuda"
25
+ else:
26
+ device = "cpu"
27
+
28
  style_list = [
29
  {
30
  "name": "No Style",
 
94
  image_pil = TF.to_pil_image(image_t.to(torch.float32))
95
  print(f"r_val={val_r}, seed={seed}")
96
  with torch.no_grad():
97
+ #c_t = image_t.unsqueeze(0).cuda().float()
98
+ c_t = image_t.unsqueeze(0).to(device).float()
99
  torch.manual_seed(seed)
100
  B,C,H,W = c_t.shape
101
  noise = torch.randn((1,4,H//8, W//8), device=c_t.device)