lsb commited on
Commit
12d035e
1 Parent(s): 7c3d7a0
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -11,14 +11,14 @@ import numpy as np
11
  from PIL import Image
12
  from datetime import datetime
13
 
14
- preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
15
- preferred_dtype = torch.float32 if preferred_device == 'cpu' else torch.float16
16
 
17
  def label_func(fn): return path/"labels"/f"{fn.stem}_P{fn.suffix}"
18
 
19
  segmodel = load_learner("camvid-512.pkl")
20
 
21
- if preferred_device == "cuda":
22
  segmodel = segmodel.to_fp16()
23
 
24
  inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
@@ -29,7 +29,7 @@ inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
29
 
30
  working_size = (512, 512)
31
 
32
- default_inpainting_prompt = "watercolor of a leafy pedestrian mall at golden hour with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"
33
 
34
  seg_vocabulary = ['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',
35
  'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',
@@ -64,7 +64,7 @@ def app(img, prompt):
64
  image=img,
65
  mask_image=mask,
66
  strength=0.95,
67
- num_inference_steps=50,
68
  ).images[0]
69
  end_time = datetime.now().timestamp()
70
  draw = ImageDraw.Draw(overlay_img)
 
11
  from PIL import Image
12
  from datetime import datetime
13
 
14
+ preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
15
+ preferred_dtype = torch.float16 if preferred_device == 'cuda' else torch.float32
16
 
17
  def label_func(fn): return path/"labels"/f"{fn.stem}_P{fn.suffix}"
18
 
19
  segmodel = load_learner("camvid-512.pkl")
20
 
21
+ if preferred_dtype == torch.float16:
22
  segmodel = segmodel.to_fp16()
23
 
24
  inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
 
29
 
30
  working_size = (512, 512)
31
 
32
+ default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"
33
 
34
  seg_vocabulary = ['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',
35
  'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',
 
64
  image=img,
65
  mask_image=mask,
66
  strength=0.95,
67
+ num_inference_steps=20,
68
  ).images[0]
69
  end_time = datetime.now().timestamp()
70
  draw = ImageDraw.Draw(overlay_img)