darkstorm2150 commited on
Commit
ba2b82b
1 Parent(s): 0e6b2f0

Update app.py

Browse files

Restoring until fix is found later

Files changed (1) hide show
  1. app.py +27 -34
app.py CHANGED
@@ -22,8 +22,18 @@ class Model:
22
  def __init__(self, name, path=""):
23
  self.name = name
24
  self.path = path
25
- self.pipe_t2i = None
26
- self.pipe_i2i = None
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  models = [
@@ -41,19 +51,6 @@ MODELS = {m.name: m for m in models}
41
 
42
  device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
43
 
44
- def get_model(name):
45
- model = MODELS[name]
46
-
47
- if model.pipe_t2i is None:
48
- model.pipe_t2i = StableDiffusionPipeline.from_pretrained(
49
- model.path, torch_dtype=torch.float16, safety_checker=SAFETY_CHECKER
50
- )
51
- model.pipe_t2i.scheduler = DPMSolverMultistepScheduler.from_config(
52
- model.pipe_t2i.scheduler.config
53
- )
54
- model.pipe_i2i = StableDiffusionImg2ImgPipeline(**model.pipe_t2i.components)
55
-
56
- return model
57
 
58
  def error_str(error, title="Error"):
59
  return (
@@ -63,6 +60,7 @@ def error_str(error, title="Error"):
63
  else ""
64
  )
65
 
 
66
  def inference(
67
  model_name,
68
  prompt,
@@ -137,12 +135,9 @@ def txt_to_img(
137
  ):
138
  pipe = MODELS[model_name].pipe_t2i
139
 
140
- if pipe is not None:
141
- if torch.cuda.is_available():
142
- pipe = pipe.to("cuda")
143
- pipe.enable_xformers_memory_efficient_attention()
144
- else:
145
- raise ValueError(f"Unable to find pipeline for model: {model_name}")
146
 
147
  result = pipe(
148
  prompt,
@@ -155,12 +150,12 @@ def txt_to_img(
155
  generator=generator,
156
  )
157
 
158
- if pipe is not None:
159
- pipe.to("cpu")
160
- torch.cuda.empty_cache()
161
 
162
  return replace_nsfw_images(result)
163
 
 
164
  def img_to_img(
165
  model_name,
166
  prompt,
@@ -175,14 +170,11 @@ def img_to_img(
175
  generator,
176
  seed,
177
  ):
178
- pipe = model.pipe_i2i
179
 
180
- if pipe is not None:
181
- if torch.cuda.is_available():
182
- pipe = pipe.to("cuda")
183
- pipe.enable_xformers_memory_efficient_attention()
184
- else:
185
- raise ValueError(f"Unable to find pipeline for model: {model_name}")
186
 
187
  ratio = min(height / img.height, width / img.width)
188
  img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
@@ -198,18 +190,19 @@ def img_to_img(
198
  generator=generator,
199
  )
200
 
201
- if pipe is not None:
202
- pipe.to("cpu")
203
- torch.cuda.empty_cache()
204
 
205
  return replace_nsfw_images(result)
206
 
 
207
  def replace_nsfw_images(results):
208
  for i in range(len(results.images)):
209
  if results.nsfw_content_detected[i]:
210
  results.images[i] = Image.open("nsfw.png")
211
  return results.images
212
 
 
213
  with gr.Blocks(css="style.css") as demo:
214
  gr.HTML(
215
  f"""
 
22
  def __init__(self, name, path=""):
23
  self.name = name
24
  self.path = path
25
+
26
+ if path != "":
27
+ self.pipe_t2i = StableDiffusionPipeline.from_pretrained(
28
+ path, torch_dtype=torch.float16, safety_checker=SAFETY_CHECKER
29
+ )
30
+ self.pipe_t2i.scheduler = DPMSolverMultistepScheduler.from_config(
31
+ self.pipe_t2i.scheduler.config
32
+ )
33
+ self.pipe_i2i = StableDiffusionImg2ImgPipeline(**self.pipe_t2i.components)
34
+ else:
35
+ self.pipe_t2i = None
36
+ self.pipe_i2i = None
37
 
38
 
39
  models = [
 
51
 
52
  device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def error_str(error, title="Error"):
56
  return (
 
60
  else ""
61
  )
62
 
63
+
64
  def inference(
65
  model_name,
66
  prompt,
 
135
  ):
136
  pipe = MODELS[model_name].pipe_t2i
137
 
138
+ if torch.cuda.is_available():
139
+ pipe = pipe.to("cuda")
140
+ pipe.enable_xformers_memory_efficient_attention()
 
 
 
141
 
142
  result = pipe(
143
  prompt,
 
150
  generator=generator,
151
  )
152
 
153
+ pipe.to("cpu")
154
+ torch.cuda.empty_cache()
 
155
 
156
  return replace_nsfw_images(result)
157
 
158
+
159
  def img_to_img(
160
  model_name,
161
  prompt,
 
170
  generator,
171
  seed,
172
  ):
173
+ pipe = MODELS[model_name].pipe_i2i
174
 
175
+ if torch.cuda.is_available():
176
+ pipe = pipe.to("cuda")
177
+ pipe.enable_xformers_memory_efficient_attention()
 
 
 
178
 
179
  ratio = min(height / img.height, width / img.width)
180
  img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS)
 
190
  generator=generator,
191
  )
192
 
193
+ pipe.to("cpu")
194
+ torch.cuda.empty_cache()
 
195
 
196
  return replace_nsfw_images(result)
197
 
198
+
199
  def replace_nsfw_images(results):
200
  for i in range(len(results.images)):
201
  if results.nsfw_content_detected[i]:
202
  results.images[i] = Image.open("nsfw.png")
203
  return results.images
204
 
205
+
206
  with gr.Blocks(css="style.css") as demo:
207
  gr.HTML(
208
  f"""