enable better output

#3
by radames HF staff - opened
Files changed (2) hide show
  1. app.py +8 -5
  2. requirements.txt +5 -1
app.py CHANGED
@@ -13,6 +13,8 @@ current_steps = 15
13
 
14
  pipe = DiffusionPipeline.from_pretrained("timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None)
15
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
 
 
16
 
17
  device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
18
 
@@ -62,9 +64,9 @@ def inference(
62
  )
63
 
64
  # return replace_nsfw_images(result)
65
- return result.images, f"Done. Seed: {seed}"
66
  except Exception as e:
67
- return None, error_str(e)
68
 
69
 
70
  def replace_nsfw_images(results):
@@ -119,7 +121,8 @@ with gr.Blocks(css="style.css") as demo:
119
  state_info = gr.Textbox(label="State", show_label=False, max_lines=2).style(
120
  container=False
121
  )
122
- error_output = gr.Markdown()
 
123
 
124
  with gr.Column(scale=45):
125
  with gr.Tab("Options"):
@@ -180,7 +183,7 @@ with gr.Blocks(css="style.css") as demo:
180
  height,
181
  seed,
182
  ]
183
- outputs = [gallery, error_output]
184
  prompt.submit(inference, inputs=inputs, outputs=outputs)
185
  generate.click(inference, inputs=inputs, outputs=outputs)
186
 
@@ -197,4 +200,4 @@ with gr.Blocks(css="style.css") as demo:
197
  print(f"Space built in {time.time() - start_time:.2f} seconds")
198
 
199
  demo.queue(concurrency_count=1)
200
- demo.launch()
 
13
 
14
  pipe = DiffusionPipeline.from_pretrained("timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None)
15
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
16
+ pipe.enable_xformers_memory_efficient_attention()
17
+ pipe.unet.to(memory_format=torch.channels_last)
18
 
19
  device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
20
 
 
64
  )
65
 
66
  # return replace_nsfw_images(result)
67
+ return result.images, result.nsfw_content_detected, seed
68
  except Exception as e:
69
+ return None, None, error_str(e)
70
 
71
 
72
  def replace_nsfw_images(results):
 
121
  state_info = gr.Textbox(label="State", show_label=False, max_lines=2).style(
122
  container=False
123
  )
124
+ nsfw_output = gr.JSON()
125
+ error_output = gr.JSON()
126
 
127
  with gr.Column(scale=45):
128
  with gr.Tab("Options"):
 
183
  height,
184
  seed,
185
  ]
186
+ outputs = [gallery, nsfw_output, error_output]
187
  prompt.submit(inference, inputs=inputs, outputs=outputs)
188
  generate.click(inference, inputs=inputs, outputs=outputs)
189
 
 
200
  print(f"Space built in {time.time() - start_time:.2f} seconds")
201
 
202
  demo.queue(concurrency_count=1)
203
+ demo.launch(debug=True, show_api=False)
requirements.txt CHANGED
@@ -7,4 +7,8 @@ scipy
7
  ftfy
8
  psutil
9
  accelerate
10
- safetensors
 
 
 
 
 
7
  ftfy
8
  psutil
9
  accelerate
10
+ safetensors
11
+ transformers
12
+ safetensors
13
+ --pre
14
+ xformers