Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -59,7 +59,7 @@ model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
|
59 |
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
60 |
model_id_or_path,
|
61 |
revision="fp16",
|
62 |
-
torch_dtype=torch.
|
63 |
use_auth_token=auth_token
|
64 |
)
|
65 |
#self.register_buffer('n_', ...)
|
@@ -87,7 +87,7 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
87 |
elif(radio == "type what to keep"):
|
88 |
img = transform(dict["image"]).squeeze(0)
|
89 |
word_masks = [word_mask]
|
90 |
-
with torch.no_grad():
|
91 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
92 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
93 |
filename = f"{uuid.uuid4()}.png"
|
@@ -101,7 +101,7 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
101 |
else:
|
102 |
img = transform(dict["image"]).unsqueeze(0)
|
103 |
word_masks = [word_mask]
|
104 |
-
with torch.no_grad():
|
105 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
106 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
107 |
filename = f"{uuid.uuid4()}.png"
|
|
|
59 |
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
60 |
model_id_or_path,
|
61 |
revision="fp16",
|
62 |
+
torch_dtype=torch.float16, #float16
|
63 |
use_auth_token=auth_token
|
64 |
)
|
65 |
#self.register_buffer('n_', ...)
|
|
|
87 |
elif(radio == "type what to keep"):
|
88 |
img = transform(dict["image"]).squeeze(0)
|
89 |
word_masks = [word_mask]
|
90 |
+
with torch.cuda.amp.autocast(): #with torch.no_grad():
|
91 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
92 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
93 |
filename = f"{uuid.uuid4()}.png"
|
|
|
101 |
else:
|
102 |
img = transform(dict["image"]).unsqueeze(0)
|
103 |
word_masks = [word_mask]
|
104 |
+
with torch.cuda.amp.autocast(): #with torch.no_grad():
|
105 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
106 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
107 |
filename = f"{uuid.uuid4()}.png"
|