mrbeliever committed
Commit a2f8ee0 · verified · 1 Parent(s): 56d8f4f

Update app.py

Files changed (1): app.py (+74 -58)
app.py CHANGED

@@ -1,12 +1,15 @@
+import spaces
 import gradio as gr
 import re
 from PIL import Image
+import os
+import numpy as np
 import torch
 from diffusers import FluxImg2ImgPipeline
 
-# Set up the device and pipeline
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
+
 pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(device)
 
 def sanitize_prompt(prompt):
@@ -18,87 +21,100 @@ def convert_to_fit_size(original_width_and_height, maximum_size=2048):
     width, height = original_width_and_height
     if width <= maximum_size and height <= maximum_size:
         return width, height
-
     scaling_factor = maximum_size / max(width, height)
-    return int(width * scaling_factor), int(height * scaling_factor)
+    new_width = int(width * scaling_factor)
+    new_height = int(height * scaling_factor)
+    return new_width, new_height
 
-def adjust_to_multiple_of_32(width, height):
-    return width - (width % 32), height - (height % 32)
+def adjust_to_multiple_of_32(width: int, height: int):
+    width = width - (width % 32)
+    height = height - (height % 32)
+    return width, height
 
+@spaces.GPU(duration=120)
 def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
-    def process_img2img(image, prompt, strength, seed, num_inference_steps):
+    progress(0, desc="Starting")
+    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
         if image is None:
             return None
         generator = torch.Generator(device).manual_seed(seed)
-        width, height = adjust_to_multiple_of_32(*convert_to_fit_size(image.size))
+        width, height = convert_to_fit_size(image.size)
+        width, height = adjust_to_multiple_of_32(width, height)
        image = image.resize((width, height), Image.LANCZOS)
-        output = pipe(
-            prompt=prompt,
-            image=image,
-            generator=generator,
-            strength=strength,
-            width=width,
-            height=height,
-            guidance_scale=0,
-            num_inference_steps=num_inference_steps
-        )
+        output = pipe(prompt=prompt, image=image, generator=generator, strength=strength, width=width, height=height, guidance_scale=0, num_inference_steps=num_inference_steps, max_sequence_length=256)
         return output.images[0]
-
-    return process_img2img(image, prompt, strength, seed, inference_step)
+    output = process_img2img(image, prompt, strength, seed, inference_step)
+    return output
+
+def read_file(path: str) -> str:
+    with open(path, 'r', encoding='utf-8') as f:
+        content = f.read()
+    return content
 
-# Minimal CSS for black outline and container styling
 css = """
 #demo-container {
-    border: 2px solid black;
-    padding: 10px;
-    width: 100%;
-    max-width: 750px;
-    margin: auto;
+    border: 4px solid black;
+    border-radius: 8px;
+    padding: 20px;
+    margin: 20px auto;
+    max-width: 800px;
+}
+
+#image_upload, #output-img {
+    border: 4px solid black;
+    border-radius: 8px;
+    width: 256px;
+    height: 256px;
+    object-fit: cover;
 }
-#image_upload, #output-img, #generate_button {
-    border: 2px solid black;
+
+#run_button {
+    font-weight: bold;
+    border: 4px solid black;
+    border-radius: 8px;
+    padding: 10px 20px;
+}
+
+#col-left, #col-right {
+    max-width: 640px;
+    margin: 0 auto;
+}
+.grid-container {
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    gap: 10px;
+}
+.text {
+    font-size: 16px;
 }
 """
 
 with gr.Blocks(css=css, elem_id="demo-container") as demo:
     with gr.Column():
-        gr.HTML("<h1 style='text-align:center;'>Image to Image Generation</h1>")
+        gr.HTML(read_file("demo_header.html"))
+        # Removed or commented out the demo_tools.html line
+        # gr.HTML(read_file("demo_tools.html"))
         with gr.Row():
             with gr.Column():
-                image = gr.Image(
-                    height=400,
-                    sources=['upload', 'clipboard'],
-                    image_mode='RGB',
-                    elem_id="image_upload",
-                    type="pil",
-                    label="Upload Image"
-                )
-                prompt = gr.Textbox(
-                    label="Prompt",
-                    value="A woman",
-                    placeholder="Describe the output image",
-                    elem_id="prompt"
-                )
-                btn = gr.Button("Generate", elem_id="generate_button", variant="primary")
+                image = gr.Image(width=256, height=256, sources=['upload', 'clipboard'], image_mode='RGB', elem_id="image_upload", type="pil", label="Upload")
+                prompt = gr.Textbox(label="Prompt", value="a woman", placeholder="Your prompt", elem_id="prompt")
+                btn = gr.Button("Generate", elem_id="run_button", variant="primary")
                 with gr.Accordion(label="Advanced Settings", open=False):
-                    strength = gr.Number(value=0.75, minimum=0, maximum=1, step=0.01, label="Strength")
+                    strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="Strength")
                     seed = gr.Number(value=100, minimum=0, step=1, label="Seed")
-                    inference_step = gr.Number(value=4, minimum=1, step=1, label="Inference Steps")
+                    inference_step = gr.Number(value=4, minimum=1, step=4, label="Inference Steps")
             with gr.Column():
-                image_out = gr.Image(
-                    height=400,
-                    sources=[],
-                    label="Generated Output",
-                    elem_id="output-img",
-                    format="jpg"
-                )
+                image_out = gr.Image(width=256, height=256, label="Output", elem_id="output-img", format="jpg")
+
+        gr.HTML(read_file("demo_footer.html"))
 
-    btn.click(
-        process_images,
-        inputs=[image, prompt, strength, seed, inference_step],
+    gr.on(
+        triggers=[btn.click, prompt.submit],
+        fn=process_images,
+        inputs=[image, prompt, strength, seed, inference_step],
         outputs=[image_out]
     )
 
-# Enable queue mode and CORS support
-demo.queue(concurrency_count=3, cors_allow_origins=["*"])
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
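A few notes on what the new code does. The resize now happens in two explicit steps: convert_to_fit_size scales the image so its longest side is at most 2048 pixels, and adjust_to_multiple_of_32 rounds both dimensions down so they divide cleanly through the model's downsampling. A minimal standalone sketch of that logic, with a worked example:

def convert_to_fit_size(original_width_and_height, maximum_size=2048):
    width, height = original_width_and_height
    if width <= maximum_size and height <= maximum_size:
        return width, height
    # Shrink so the longest side equals maximum_size, preserving aspect ratio
    scaling_factor = maximum_size / max(width, height)
    return int(width * scaling_factor), int(height * scaling_factor)

def adjust_to_multiple_of_32(width, height):
    # Round each dimension down to the nearest multiple of 32
    return width - (width % 32), height - (height % 32)

# Example: a 3000x2000 upload
w, h = convert_to_fit_size((3000, 2000))   # -> (2048, 1365)
w, h = adjust_to_multiple_of_32(w, h)      # -> (2048, 1344)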
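The advanced settings are also tightened: strength is now capped at 0.75 and the step counter moves in increments of 4. A plausible reading of the cap: diffusers img2img pipelines skip the start of the noise schedule, running roughly int(num_inference_steps * strength) denoising steps, so with schnell's 4-step default:

# Approximate effective denoising steps in img2img (an assumption about
# diffusers internals, not something this commit states)
num_inference_steps = 4
for strength in (0.25, 0.5, 0.75, 1.0):
    print(strength, int(num_inference_steps * strength))
# 0.25 -> 1 step, 0.5 -> 2, 0.75 -> 3, 1.0 -> 4 (input image essentially ignored)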
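The pipe(...) call, now flattened to one line, additionally passes max_sequence_length=256 (the cap the schnell checkpoint supports for its T5 prompt encoder) and keeps guidance_scale=0, as expected for a guidance-distilled model. The same call is easy to reproduce outside the UI; a sketch, with illustrative file names and dimensions already multiples of 32:

import torch
from PIL import Image
from diffusers import FluxImg2ImgPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to(device)

init = Image.open("input.jpg").convert("RGB").resize((1024, 768))
generator = torch.Generator(device).manual_seed(100)  # fixed seed -> repeatable result
out = pipe(
    prompt="a woman",
    image=init,
    generator=generator,
    strength=0.75,
    guidance_scale=0,          # schnell is guidance-distilled
    num_inference_steps=4,
    max_sequence_length=256,   # schnell's T5 context limit
).images[0]
out.save("output.jpg")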
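Event wiring moves from btn.click(...) to gr.on, which binds one handler to several triggers, so pressing Enter in the prompt box now generates as well as clicking the button. A minimal sketch of the pattern:

import gradio as gr

def shout(text):
    return text.upper()

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Result")
    # One handler, two triggers: button click and textbox submit (Enter)
    gr.on(triggers=[btn.click, box.submit], fn=shout, inputs=[box], outputs=[out])

demo.launch()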
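Finally, @spaces.GPU(duration=120) is the ZeroGPU hook from the Hugging Face spaces package: the decorated function is allocated a GPU on demand and must return within 120 seconds, while module-level code runs on whatever device was available at import. The old demo.queue(concurrency_count=3, cors_allow_origins=["*"]) was dropped rather than ported; note that concurrency_count no longer exists in Gradio 4, where the rough equivalent would be:

# Gradio 4 replacement for the removed queue configuration (assumes Gradio >= 4)
demo.queue(default_concurrency_limit=3)
demo.launch()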