Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -53,7 +53,11 @@ def download_image(url):
|
|
53 |
response = requests.get(url)
|
54 |
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
55 |
|
56 |
-
device = "cpu" #"cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
57 |
|
58 |
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
59 |
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
@@ -118,7 +122,7 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
118 |
|
119 |
#model = model.to(torch.device(device))
|
120 |
img = img.to(torch.device(device))
|
121 |
-
prompt =
|
122 |
#---------
|
123 |
|
124 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
@@ -149,7 +153,7 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
149 |
|
150 |
#model = model.to(torch.device(device))
|
151 |
img = img.to(torch.device(device))
|
152 |
-
prompt =
|
153 |
|
154 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
155 |
filename = f"{uuid.uuid4()}.png"
|
|
|
53 |
response = requests.get(url)
|
54 |
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
55 |
|
56 |
+
#device = "cpu" #"cuda" if torch.cuda.is_available() else "cpu"
|
57 |
+
|
58 |
+
device = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)
|
59 |
+
print(“The model will be running on”, device, “device”)
|
60 |
+
# Convert model parameters and buffers to CPU or Cuda
|
61 |
|
62 |
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
63 |
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
|
|
122 |
|
123 |
#model = model.to(torch.device(device))
|
124 |
img = img.to(torch.device(device))
|
125 |
+
prompt = prompt.to(torch.device(device))
|
126 |
#---------
|
127 |
|
128 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
|
|
153 |
|
154 |
#model = model.to(torch.device(device))
|
155 |
img = img.to(torch.device(device))
|
156 |
+
prompt = prompt.to(torch.device(device))
|
157 |
|
158 |
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
159 |
filename = f"{uuid.uuid4()}.png"
|