nightfury commited on
Commit
ae6071d
1 Parent(s): ff0ef6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -52,7 +52,8 @@ def predict(radio, dict, word_mask, prompt=""):
52
  if(radio == "draw a mask above"):
53
  #with autocast("cuda"):
54
  #with autocast(device): #enable=(False if device=='cpu' else True)):
55
- with autocast(enabled=True, dtype=torch.bfloat16):
 
56
  init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
57
  mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
58
  else:
@@ -71,7 +72,8 @@ def predict(radio, dict, word_mask, prompt=""):
71
  os.remove(filename)
72
  #with autocast("cuda"):
73
  #with autocast(device): #enable=(False if device=='cpu' else True)):
74
- with autocast(enabled=True, dtype=torch.bfloat16):
 
75
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
76
  return images[0]
77
 
52
  if(radio == "draw a mask above"):
53
  #with autocast("cuda"):
54
  #with autocast(device): #enable=(False if device=='cpu' else True)):
55
+ #with autocast(enabled=True, dtype=torch.bfloat16):
56
+ with torch.cuda.amp.autocast(True):
57
  init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
58
  mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
59
  else:
72
  os.remove(filename)
73
  #with autocast("cuda"):
74
  #with autocast(device): #enable=(False if device=='cpu' else True)):
75
+ #with autocast(enabled=True, dtype=torch.bfloat16):
76
+ with torch.cuda.amp.autocast(True):
77
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
78
  return images[0]
79