nightfury commited on
Commit
b202164
1 Parent(s): 49dc097

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -51,7 +51,8 @@ transform = transforms.Compose([
51
  def predict(radio, dict, word_mask, prompt=""):
52
  if(radio == "draw a mask above"):
53
  #with autocast("cuda"):
54
- with autocast(device): #enable=(False if device=='cpu' else True)):
 
55
  init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
56
  mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
57
  else:
@@ -69,7 +70,8 @@ def predict(radio, dict, word_mask, prompt=""):
69
  mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
70
  os.remove(filename)
71
  #with autocast("cuda"):
72
- with autocast(device): #enable=(False if device=='cpu' else True)):
 
73
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
74
  return images[0]
75
 
 
51
  def predict(radio, dict, word_mask, prompt=""):
52
  if(radio == "draw a mask above"):
53
  #with autocast("cuda"):
54
+ #with autocast(device): #enable=(False if device=='cpu' else True)):
55
+ with autocast(enabled=True, dtype=torch.bfloat16, device='cpu'):
56
  init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
57
  mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
58
  else:
 
70
  mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
71
  os.remove(filename)
72
  #with autocast("cuda"):
73
+ #with autocast(device): #enable=(False if device=='cpu' else True)):
74
+ with autocast(enabled=True, dtype=torch.bfloat16, device='cpu'):
75
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
76
  return images[0]
77