nightfury commited on
Commit
07c78b9
1 Parent(s): c84b0ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -59,7 +59,7 @@ model_id_or_path = "CompVis/stable-diffusion-v1-4"
59
  pipe = StableDiffusionInpaintingPipeline.from_pretrained(
60
  model_id_or_path,
61
  revision="fp16",
62
- torch_dtype=torch.long, #float16
63
  use_auth_token=auth_token
64
  )
65
  #self.register_buffer('n_', ...)
@@ -87,7 +87,7 @@ def predict(radio, dict, word_mask, prompt=""):
87
  elif(radio == "type what to keep"):
88
  img = transform(dict["image"]).squeeze(0)
89
  word_masks = [word_mask]
90
- with torch.no_grad():
91
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
92
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
93
  filename = f"{uuid.uuid4()}.png"
@@ -101,7 +101,7 @@ def predict(radio, dict, word_mask, prompt=""):
101
  else:
102
  img = transform(dict["image"]).unsqueeze(0)
103
  word_masks = [word_mask]
104
- with torch.no_grad():
105
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
106
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
107
  filename = f"{uuid.uuid4()}.png"
 
59
  pipe = StableDiffusionInpaintingPipeline.from_pretrained(
60
  model_id_or_path,
61
  revision="fp16",
62
+ torch_dtype=torch.float16, #float16
63
  use_auth_token=auth_token
64
  )
65
  #self.register_buffer('n_', ...)
 
87
  elif(radio == "type what to keep"):
88
  img = transform(dict["image"]).squeeze(0)
89
  word_masks = [word_mask]
90
+ with torch.cuda.amp.autocast(): #with torch.no_grad():
91
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
92
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
93
  filename = f"{uuid.uuid4()}.png"
 
101
  else:
102
  img = transform(dict["image"]).unsqueeze(0)
103
  word_masks = [word_mask]
104
+ with torch.cuda.amp.autocast(): #with torch.no_grad():
105
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
106
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
107
  filename = f"{uuid.uuid4()}.png"