nightfury commited on
Commit
18e0cee
1 Parent(s): 8f1bed9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -53,7 +53,11 @@ def download_image(url):
53
  response = requests.get(url)
54
  return PIL.Image.open(BytesIO(response.content)).convert("RGB")
55
 
56
- device = "cpu" #"cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
57
 
58
  model_id_or_path = "CompVis/stable-diffusion-v1-4"
59
  pipe = StableDiffusionInpaintingPipeline.from_pretrained(
@@ -118,7 +122,7 @@ def predict(radio, dict, word_mask, prompt=""):
118
 
119
  #model = model.to(torch.device(device))
120
  img = img.to(torch.device(device))
121
- prompt = labels.to(torch.device(device))
122
  #---------
123
 
124
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
@@ -149,7 +153,7 @@ def predict(radio, dict, word_mask, prompt=""):
149
 
150
  #model = model.to(torch.device(device))
151
  img = img.to(torch.device(device))
152
- prompt = labels.to(torch.device(device))
153
 
154
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
155
  filename = f"{uuid.uuid4()}.png"
 
53
  response = requests.get(url)
54
  return PIL.Image.open(BytesIO(response.content)).convert("RGB")
55
 
56
+ #device = "cpu" #"cuda" if torch.cuda.is_available() else "cpu"
57
+
58
+ device = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)
59
+ print(“The model will be running on”, device, “device”)
60
+ # Convert model parameters and buffers to CPU or Cuda
61
 
62
  model_id_or_path = "CompVis/stable-diffusion-v1-4"
63
  pipe = StableDiffusionInpaintingPipeline.from_pretrained(
 
122
 
123
  #model = model.to(torch.device(device))
124
  img = img.to(torch.device(device))
125
+ prompt = prompt.to(torch.device(device))
126
  #---------
127
 
128
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
 
153
 
154
  #model = model.to(torch.device(device))
155
  img = img.to(torch.device(device))
156
+ prompt = prompt.to(torch.device(device))
157
 
158
  init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
159
  filename = f"{uuid.uuid4()}.png"