darkstorm2150 commited on
Commit
9f4b8d0
1 Parent(s): 8ee47a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -138,9 +138,13 @@ def txt_to_img(
138
  ):
139
  pipe = MODELS[model_name].pipe_t2i
140
 
141
- if torch.cuda.is_available():
142
- pipe = pipe.to("cuda")
143
- pipe.enable_xformers_memory_efficient_attention()
 
 
 
 
144
 
145
  result = pipe(
146
  prompt,
@@ -153,8 +157,9 @@ def txt_to_img(
153
  generator=generator,
154
  )
155
 
156
- pipe.to("cpu")
157
- torch.cuda.empty_cache()
 
158
 
159
  return replace_nsfw_images(result)
160
 
@@ -175,9 +180,13 @@ def img_to_img(
175
  ):
176
  pipe = model.pipe_i2i
177
 
178
- if torch.cuda.is_available():
179
- pipe = pipe.to("cuda")
180
- pipe.enable_xformers_memory_efficient_attention()
 
 
 
 
181
 
182
  ratio = min(height / img.height, width / img.width)
183
  img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
@@ -193,8 +202,9 @@ def img_to_img(
193
  generator=generator,
194
  )
195
 
196
- pipe.to("cpu")
197
- torch.cuda.empty_cache()
 
198
 
199
  return replace_nsfw_images(result)
200
 
 
138
  ):
139
  pipe = MODELS[model_name].pipe_t2i
140
 
141
+ if pipe is not None:
142
+ if torch.cuda.is_available():
143
+ pipe = pipe.to("cuda")
144
+ pipe.enable_xformers_memory_efficient_attention()
145
+
146
+ else:
147
+ raise ValueError(f"Unable to find pipeline for model: {model_name}")
148
 
149
  result = pipe(
150
  prompt,
 
157
  generator=generator,
158
  )
159
 
160
+ if pipe is not None:
161
+ pipe.to("cpu")
162
+ torch.cuda.empty_cache()
163
 
164
  return replace_nsfw_images(result)
165
 
 
180
  ):
181
  pipe = model.pipe_i2i
182
 
183
+ if pipe is not None:
184
+ if torch.cuda.is_available():
185
+ pipe = pipe.to("cuda")
186
+ pipe.enable_xformers_memory_efficient_attention()
187
+
188
+ else:
189
+ raise ValueError(f"Unable to find pipeline for model: {model_name}")
190
 
191
  ratio = min(height / img.height, width / img.width)
192
  img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
 
202
  generator=generator,
203
  )
204
 
205
+ if pipe is not None:
206
+ pipe.to("cpu")
207
+ torch.cuda.empty_cache()
208
 
209
  return replace_nsfw_images(result)
210