gk8686 commited on
Commit
f8ebdc7
1 Parent(s): 3df1448

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -38
app.py CHANGED
@@ -14,36 +14,27 @@ from matplotlib import pyplot as plt
14
  from torchvision import transforms
15
  from diffusers import DiffusionPipeline
16
  from diffusers.utils import torch_device
 
 
17
  pipe = DiffusionPipeline.from_pretrained(
18
  "Fantasy-Studio/Paint-by-Example",
19
- torch_dtype=torch.float16,
20
  )
21
- pipe = pipe.to("cuda")
22
-
23
- from share_btn import community_icon_html, loading_icon_html, share_js
24
-
25
- def read_content(file_path: str) -> str:
26
- """read the content of target file
27
- """
28
- with open(file_path, 'r', encoding='utf-8') as f:
29
- content = f.read()
30
-
31
- return content
32
 
 
33
  def predict(dict, reference, scale, seed, step):
34
- width,height=dict["image"].size
35
- if width<height:
36
- factor=width/512.0
37
- width=512
38
- height=int((height/factor)/8.0)*8
39
-
40
  else:
41
- factor=height/512.0
42
- height=512
43
- width=int((width/factor)/8.0)*8
44
- init_image = dict["image"].convert("RGB").resize((width,height))
45
- mask = dict["mask"].convert("RGB").resize((width,height))
46
- generator = torch.Generator('cuda').manual_seed(seed) if seed != 0 else None
47
  output = pipe(
48
  image=init_image,
49
  mask_image=mask,
@@ -52,9 +43,12 @@ def predict(dict, reference, scale, seed, step):
52
  guidance_scale=scale,
53
  num_inference_steps=step,
54
  ).images[0]
55
- return output, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
 
 
56
 
57
 
 
58
  css = '''
59
  .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
60
  #image_upload{min-height:400px}
@@ -93,15 +87,28 @@ css = '''
93
  display: none !important;
94
  }
95
  '''
96
- example={}
97
- ref_dir='examples/reference'
98
- image_dir='examples/image'
99
- ref_list=[os.path.join(ref_dir,file) for file in os.listdir(ref_dir)]
 
 
 
 
 
 
 
 
 
 
 
 
100
  ref_list.sort()
101
- image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir)]
102
  image_list.sort()
103
 
104
 
 
105
  image_blocks = gr.Blocks(css=css)
106
  with image_blocks as demo:
107
  gr.HTML(read_content("header.html"))
@@ -114,8 +121,8 @@ with image_blocks as demo:
114
 
115
  with gr.Column():
116
  image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
117
- guidance = gr.Slider(label="Guidance scale", value=5, maximum=15,interactive=True)
118
- steps = gr.Slider(label="Steps", value=50, minimum=2, maximum=75, step=1,interactive=True)
119
 
120
  seed = gr.Slider(0, 10000, label='Seed (0 = random)', value=0, step=1)
121
 
@@ -129,19 +136,17 @@ with image_blocks as demo:
129
  community_icon = gr.HTML(community_icon_html, visible=True)
130
  loading_icon = gr.HTML(loading_icon_html, visible=True)
131
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
132
-
133
-
134
  with gr.Row():
135
  with gr.Column():
136
  gr.Examples(image_list, inputs=[image],label="Examples - Source Image",examples_per_page=12)
137
  with gr.Column():
138
  gr.Examples(ref_list, inputs=[reference],label="Examples - Reference Image",examples_per_page=12)
139
-
140
  btn.click(fn=predict, inputs=[image, reference, guidance, seed, steps], outputs=[image_out, community_icon, loading_icon, share_button])
141
  share_button.click(None, [], [], _js=share_js)
142
 
143
-
144
-
145
  gr.HTML(
146
  """
147
  <div class="footer">
@@ -154,4 +159,5 @@ with image_blocks as demo:
154
  """
155
  )
156
 
157
- image_blocks.launch()
 
 
14
  from torchvision import transforms
15
  from diffusers import DiffusionPipeline
16
  from diffusers.utils import torch_device
17
+
18
+ # Load the model
19
  pipe = DiffusionPipeline.from_pretrained(
20
  "Fantasy-Studio/Paint-by-Example",
21
+ torch_dtype=torch.float32, # Change to float32 for CPU
22
  )
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ # Define function to predict
25
  def predict(dict, reference, scale, seed, step):
26
+ width, height = dict["image"].size
27
+ if width < height:
28
+ factor = width / 512.0
29
+ width = 512
30
+ height = int((height / factor) / 8.0) * 8
 
31
  else:
32
+ factor = height / 512.0
33
+ height = 512
34
+ width = int((width / factor) / 8.0) * 8
35
+ init_image = dict["image"].convert("RGB").resize((width, height))
36
+ mask = dict["mask"].convert("RGB").resize((width, height))
37
+ generator = torch.Generator().manual_seed(seed) if seed != 0 else None
38
  output = pipe(
39
  image=init_image,
40
  mask_image=mask,
 
43
  guidance_scale=scale,
44
  num_inference_steps=step,
45
  ).images[0]
46
+ return output, gr.update(visible=True), gr.update(visible=True), gr.update(
47
+ visible=True
48
+ )
49
 
50
 
51
+ # Define CSS
52
  css = '''
53
  .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
54
  #image_upload{min-height:400px}
 
87
  display: none !important;
88
  }
89
  '''
90
+
91
+ # Read content function
92
+ def read_content(file_path: str) -> str:
93
+ """read the content of target file
94
+ """
95
+ with open(file_path, 'r', encoding='utf-8') as f:
96
+ content = f.read()
97
+
98
+ return content
99
+
100
+
101
+ # Define example data
102
+ example = {}
103
+ ref_dir = 'examples/reference'
104
+ image_dir = 'examples/image'
105
+ ref_list = [os.path.join(ref_dir, file) for file in os.listdir(ref_dir)]
106
  ref_list.sort()
107
+ image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir)]
108
  image_list.sort()
109
 
110
 
111
+ # Create Gradio Blocks instance
112
  image_blocks = gr.Blocks(css=css)
113
  with image_blocks as demo:
114
  gr.HTML(read_content("header.html"))
 
121
 
122
  with gr.Column():
123
  image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
124
+ guidance = gr.Slider(label="Guidance scale", value=5, maximum=15, interactive=True)
125
+ steps = gr.Slider(label="Steps", value=50, minimum=2, maximum=75, step=1, interactive=True)
126
 
127
  seed = gr.Slider(0, 10000, label='Seed (0 = random)', value=0, step=1)
128
 
 
136
  community_icon = gr.HTML(community_icon_html, visible=True)
137
  loading_icon = gr.HTML(loading_icon_html, visible=True)
138
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=True)
139
+
140
+
141
  with gr.Row():
142
  with gr.Column():
143
  gr.Examples(image_list, inputs=[image],label="Examples - Source Image",examples_per_page=12)
144
  with gr.Column():
145
  gr.Examples(ref_list, inputs=[reference],label="Examples - Reference Image",examples_per_page=12)
146
+
147
  btn.click(fn=predict, inputs=[image, reference, guidance, seed, steps], outputs=[image_out, community_icon, loading_icon, share_button])
148
  share_button.click(None, [], [], _js=share_js)
149
 
 
 
150
  gr.HTML(
151
  """
152
  <div class="footer">
 
159
  """
160
  )
161
 
162
+ # Launch the Gradio interface
163
+ image_blocks.launch()