Spaces:
Runtime error
Runtime error
rewrite ui
Browse files
app.py
CHANGED
|
@@ -9,20 +9,18 @@ from src.pipeline import FashionPipeline, PipelineOutput
|
|
| 9 |
|
| 10 |
|
| 11 |
config = PipelineConfig()
|
| 12 |
-
fashion_pipeline = FashionPipeline(config, device=torch.device('
|
| 13 |
|
| 14 |
|
| 15 |
def process(input_image: np.ndarray, prompt: str):
|
| 16 |
-
|
| 17 |
output: PipelineOutput = fashion_pipeline(
|
| 18 |
control_image=input_image,
|
| 19 |
prompt=prompt,
|
| 20 |
)
|
| 21 |
|
| 22 |
return [
|
| 23 |
-
output.control_image,
|
| 24 |
-
output.segmentation_mask,
|
| 25 |
output.generated_image,
|
|
|
|
| 26 |
]
|
| 27 |
|
| 28 |
|
|
@@ -60,9 +58,11 @@ with block:
|
|
| 60 |
""")
|
| 61 |
|
| 62 |
with gr.Column():
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
ips = [input_image, prompt]
|
| 65 |
-
run_button.click(fn=process, inputs=ips, outputs=[
|
| 66 |
|
| 67 |
|
| 68 |
block.launch()
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
config = PipelineConfig()
|
| 12 |
+
fashion_pipeline = FashionPipeline(config, device=torch.device('cuda'))
|
| 13 |
|
| 14 |
|
| 15 |
def process(input_image: np.ndarray, prompt: str):
|
|
|
|
| 16 |
output: PipelineOutput = fashion_pipeline(
|
| 17 |
control_image=input_image,
|
| 18 |
prompt=prompt,
|
| 19 |
)
|
| 20 |
|
| 21 |
return [
|
|
|
|
|
|
|
| 22 |
output.generated_image,
|
| 23 |
+
output.segmentation_mask,
|
| 24 |
]
|
| 25 |
|
| 26 |
|
|
|
|
| 58 |
""")
|
| 59 |
|
| 60 |
with gr.Column():
|
| 61 |
+
generated_output = gr.Image(label="Generated", type="numpy", elem_id="generated")
|
| 62 |
+
mask_output = gr.Image(label="Mask", type="numpy", elem_id="mask")
|
| 63 |
+
|
| 64 |
ips = [input_image, prompt]
|
| 65 |
+
run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])
|
| 66 |
|
| 67 |
|
| 68 |
block.launch()
|