code modifications to add the hyperlinks to model and dataset

#3
Files changed (2) hide show
  1. app.py +78 -272
  2. requirements.txt +1 -3
app.py CHANGED
@@ -5,7 +5,7 @@ from flax.training.common_utils import shard
5
  from PIL import Image
6
  from argparse import Namespace
7
  import gradio as gr
8
- import copy # added
9
  import numpy as np
10
  import mediapipe as mp
11
  from mediapipe import solutions
@@ -13,64 +13,44 @@ from mediapipe.framework.formats import landmark_pb2
13
  from mediapipe.tasks import python
14
  from mediapipe.tasks.python import vision
15
  import cv2
16
- import psutil
17
- from gpuinfo import GPUInfo
18
- import time
19
- import gc
20
- import torch
21
 
22
  from diffusers import (
23
  FlaxControlNetModel,
24
  FlaxStableDiffusionControlNetPipeline,
25
  )
26
- right_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
27
- left_style_lm = copy.deepcopy(solutions.drawing_styles.get_default_hand_landmarks_style())
28
- right_style_lm[0].color=(251, 206, 177)
29
- left_style_lm[0].color=(255, 255, 225)
30
-
31
- def draw_landmarks_on_image(rgb_image, detection_result, overlap=False, hand_encoding=False):
32
- hand_landmarks_list = detection_result.hand_landmarks
33
- handedness_list = detection_result.handedness
34
- if overlap:
35
- annotated_image = np.copy(rgb_image)
36
- else:
37
- annotated_image = np.zeros_like(rgb_image)
38
 
39
- # Loop through the detected hands to visualize.
40
- for idx in range(len(hand_landmarks_list)):
41
- hand_landmarks = hand_landmarks_list[idx]
42
- handedness = handedness_list[idx]
43
- # Draw the hand landmarks.
44
- hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
45
- hand_landmarks_proto.landmark.extend([
46
- landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in hand_landmarks
47
- ])
48
- if hand_encoding:
49
- if handedness[0].category_name == "Left":
50
- solutions.drawing_utils.draw_landmarks(
51
- annotated_image,
52
- hand_landmarks_proto,
53
- solutions.hands.HAND_CONNECTIONS,
54
- left_style_lm,
55
- solutions.drawing_styles.get_default_hand_connections_style())
56
- if handedness[0].category_name == "Right":
57
- solutions.drawing_utils.draw_landmarks(
58
- annotated_image,
59
- hand_landmarks_proto,
60
- solutions.hands.HAND_CONNECTIONS,
61
- right_style_lm,
62
- solutions.drawing_styles.get_default_hand_connections_style())
63
- else:
64
- solutions.drawing_utils.draw_landmarks(
65
- annotated_image,
66
- hand_landmarks_proto,
67
- solutions.hands.HAND_CONNECTIONS,
68
- solutions.drawing_styles.get_default_hand_landmarks_style(),
69
- solutions.drawing_styles.get_default_hand_connections_style())
70
-
71
- return annotated_image
72
 
73
- def generate_annotation(img, overlap=False, hand_encoding=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  """img(input): numpy array
75
  annotated_image(output): numpy array
76
  """
@@ -88,260 +68,91 @@ def generate_annotation(img, overlap=False, hand_encoding=False):
88
  detection_result = detector.detect(image)
89
 
90
  # STEP 5: Process the classification result. In this case, visualize it.
91
- annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result, overlap=overlap, hand_encoding=hand_encoding)
92
  return annotated_image
93
 
94
-
95
-
96
- std_args = Namespace(
97
- pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
98
- revision="non-ema",
99
- from_pt=True,
100
- controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
101
- controlnet_revision=None,
102
- controlnet_from_pt=False,
103
- )
104
- enc_args = Namespace(
105
- pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
106
- revision="non-ema",
107
- from_pt=True,
108
- controlnet_model_name_or_path="MakiPan/controlnet-encoded-hands-130k",
109
- controlnet_revision=None,
110
- controlnet_from_pt=False,
111
- )
112
-
113
- std_controlnet, std_controlnet_params = FlaxControlNetModel.from_pretrained(
114
- std_args.controlnet_model_name_or_path,
115
- revision=std_args.controlnet_revision,
116
- from_pt=std_args.controlnet_from_pt,
117
- dtype=jnp.float32, # jnp.bfloat16
118
- )
119
- enc_controlnet, enc_controlnet_params = FlaxControlNetModel.from_pretrained(
120
- enc_args.controlnet_model_name_or_path,
121
- revision=enc_args.controlnet_revision,
122
- from_pt=enc_args.controlnet_from_pt,
123
- dtype=jnp.float32, # jnp.bfloat16
124
  )
125
 
126
-
127
-
128
- std_pipeline, std_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
129
- std_args.pretrained_model_name_or_path,
130
- # tokenizer=tokenizer,
131
- controlnet=std_controlnet,
132
- safety_checker=None,
133
  dtype=jnp.float32, # jnp.bfloat16
134
- revision=std_args.revision,
135
- from_pt=std_args.from_pt,
136
  )
137
- enc_pipeline, enc_pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
138
- enc_args.pretrained_model_name_or_path,
 
139
  # tokenizer=tokenizer,
140
- controlnet=enc_controlnet,
141
  safety_checker=None,
142
  dtype=jnp.float32, # jnp.bfloat16
143
- revision=enc_args.revision,
144
- from_pt=enc_args.from_pt,
145
  )
146
 
147
 
148
- std_pipeline_params["controlnet"] = std_controlnet_params
149
- std_pipeline_params = jax_utils.replicate(std_pipeline_params)
150
-
151
- enc_pipeline_params["controlnet"] = enc_controlnet_params
152
- enc_pipeline_params = jax_utils.replicate(enc_pipeline_params)
153
 
154
  rng = jax.random.PRNGKey(0)
155
  num_samples = jax.device_count()
156
  prng_seed = jax.random.split(rng, jax.device_count())
157
- memory = psutil.virtual_memory()
158
 
159
- def infer(prompt, negative_prompt, image, model_type="Standard"):
160
- time_start = time.time()
161
  prompts = num_samples * [prompt]
162
- if model_type=="Standard":
163
- prompt_ids = std_pipeline.prepare_text_inputs(prompts)
164
- elif model_type=="Hand Encoding":
165
- prompt_ids = enc_pipeline.prepare_text_inputs(prompts)
166
- else:
167
- pass
168
  prompt_ids = shard(prompt_ids)
169
 
170
- if model_type=="Standard":
171
- annotated_image = generate_annotation(image, overlap=False, hand_encoding=False)
172
- overlap_image = generate_annotation(image, overlap=True, hand_encoding=False)
173
- elif model_type=="Hand Encoding":
174
- annotated_image = generate_annotation(image, overlap=False, hand_encoding=True)
175
- overlap_image = generate_annotation(image, overlap=True, hand_encoding=True)
176
-
177
- else:
178
- pass
179
  validation_image = Image.fromarray(annotated_image).convert("RGB")
 
 
180
 
181
- if model_type=="Standard":
182
- processed_image = std_pipeline.prepare_image_inputs(num_samples * [validation_image])
183
- processed_image = shard(processed_image)
184
 
185
- negative_prompt_ids = std_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
186
- negative_prompt_ids = shard(negative_prompt_ids)
 
 
 
 
 
 
 
187
 
188
- images = std_pipeline(
189
- prompt_ids=prompt_ids,
190
- image=processed_image,
191
- params=std_pipeline_params,
192
- prng_seed=prng_seed,
193
- num_inference_steps=50,
194
- neg_prompt_ids=negative_prompt_ids,
195
- jit=True,
196
- ).images
197
- elif model_type=="Hand Encoding":
198
- processed_image = enc_pipeline.prepare_image_inputs(num_samples * [validation_image])
199
- processed_image = shard(processed_image)
200
 
201
- negative_prompt_ids = enc_pipeline.prepare_text_inputs([negative_prompt] * num_samples)
202
- negative_prompt_ids = shard(negative_prompt_ids)
203
-
204
- images = enc_pipeline(
205
- prompt_ids=prompt_ids,
206
- image=processed_image,
207
- params=enc_pipeline_params,
208
- prng_seed=prng_seed,
209
- num_inference_steps=50,
210
- neg_prompt_ids=negative_prompt_ids,
211
- jit=True,
212
- ).images
213
-
214
- else:
215
- pass
216
  images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
217
 
218
  results = [i for i in images]
219
-
220
- # running info
221
- time_end = time.time()
222
- time_diff = time_end - time_start
223
- gc.collect()
224
- torch.cuda.empty_cache()
225
- memory = psutil.virtual_memory()
226
- gpu_utilization, gpu_memory = GPUInfo.gpu_usage()
227
- gpu_utilization = gpu_utilization[0] if len(gpu_utilization) > 0 else 0
228
- gpu_memory = gpu_memory[0] if len(gpu_memory) > 0 else 0
229
- system_info = f"""
230
- *Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB.*
231
- *Processing time: {time_diff:.5} seconds.*
232
- *GPU Utilization: {gpu_utilization}%, GPU Memory: {gpu_memory}MiB.*
233
- """
234
- return [overlap_image, annotated_image] + results, system_info
235
 
236
 
237
  with gr.Blocks(theme='gradio/soft') as demo:
238
  gr.Markdown("## Stable Diffusion with Hand Control")
239
  gr.Markdown("This model is a ControlNet model using MediaPipe hand landmarks for control.")
240
- with gr.Box():
241
- gr.Markdown("""<h2><b>Summary 📋</b></h2>""")
242
- with gr.Accordion("Detail information", open=False):
243
- gr.Markdown("""
244
- As Stable diffusion and other diffusion models are notoriously poor at generating realistic hands for our project we decided to train a ControlNet model using MediaPipes landmarks in order to generate more realistic hands avoiding common issues such as unrealistic positions and irregular digits.
245
- <br>
246
- We opted to use the [HAnd Gesture Recognition Image Dataset](https://github.com/hukenovs/hagrid) (HaGRID) and [MediaPipe's Hand Landmarker](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker) to train a control net that could potentially be used independently or as an in-painting tool.
247
- To preprocess the data there were three options we considered:
248
- <ul>
249
- <li>The first was to use Mediapipes built-in draw landmarks function. This was an obvious first choice however we noticed with low training steps that the model couldn't easily distinguish handedness and would often generate the wrong hand for the conditioning image.</li>
250
- <center>
251
- <table><tr>
252
- <td>
253
- <p align="center" style="padding: 10px">
254
- <img alt="Forwarding" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid250k-blip2/--/MakiPan--hagrid250k-blip2/train/29/image/image.jpg" width="200">
255
- <br>
256
- <em style="color: grey">Original Image</em>
257
- </p>
258
- </td>
259
- <td>
260
- <p align="center">
261
- <img alt="Routing" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid250k-blip2/--/MakiPan--hagrid250k-blip2/train/29/conditioning_image/image.jpg" width="200">
262
- <br>
263
- <em style="color: grey">Conditioning Image</em>
264
- </p>
265
- </td>
266
- </tr></table>
267
- </center>
268
- <li>To counter this issue we changed the palm landmark colors with the intention to keep the color similar in order to learn that they provide similar information, but different to make the model know which hands were left or right.</li>
269
- <center>
270
- <table><tr>
271
- <td>
272
- <p align="center" style="padding: 10px">
273
- <img alt="Forwarding" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid-hand-enc-250k/--/MakiPan--hagrid-hand-enc-250k/train/96/image/image.jpg" width="200">
274
- <br>
275
- <em style="color: grey">Original Image</em>
276
- </p>
277
- </td>
278
- <td>
279
- <p align="center">
280
- <img alt="Routing" src="https://datasets-server.huggingface.co/assets/MakiPan/hagrid-hand-enc-250k/--/MakiPan--hagrid-hand-enc-250k/train/96/conditioning_image/image.jpg" width="200">
281
- <br>
282
- <em style="color: grey">Conditioning Image</em>
283
- </p>
284
- </td>
285
- </tr></table>
286
- </center>
287
- <li>The last option was to use <a href="https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html">MediaPipe Holistic</a> to provide pose face and hand landmarks to the ControlNet. This method was promising in theory, however, the HaGRID dataset was not suitable for this method as the Holistic model performs poorly with partial body and obscurely cropped images.</li>
288
- </ul>
289
- We anecdotally determined that when trained at lower steps the encoded hand model performed better than the standard MediaPipe model due to implied handedness. We theorize that with a larger dataset of more full-body hand and pose classifications, Holistic landmarks will provide the best images in the future however for the moment the hand-encoded model performs best.
290
- """)
291
-
292
- # Information links
293
- with gr.Box():
294
- gr.Markdown("""<h2><b>Links 🔗</b></h2>""")
295
- with gr.Accordion("Models 🚀", open=False):
296
- gr.Markdown("""
297
- <h4><a href="https://huggingface.co/Vincent-luo/controlnet-hands">Standard Model</a></h4>
298
- <h4> <a href="https://huggingface.co/MakiPan/controlnet-encoded-hands-130k/">Model using Hand Encoding</a></h4>
299
- """)
300
-
301
- with gr.Accordion("Datasets 💾", open=False):
302
- gr.Markdown("""
303
- <h4> <a href="https://huggingface.co/datasets/MakiPan/hagrid250k-blip2">Dataset for Standard Model</a></h4>
304
- <h4> <a href="https://huggingface.co/datasets/MakiPan/hagrid-hand-enc-250k">Dataset for Hand Encoding Model</a></h4>
305
- """)
306
-
307
- with gr.Accordion("Preprocessing Scripts 📑", open=False):
308
- gr.Markdown("""
309
- <h4> <a href="https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/normal-preprocessing.py">Standard Data Preprocessing Script</a></h4>
310
- <h4> <a href="https://github.com/Maki-DS/Jax-Controlnet-hand-training/blob/main/Hand-encoded-preprocessing.py">Hand Encoding Data Preprocessing Script</a></h4></center>
311
- """)
312
-
313
- # How to use model
314
- with gr.Box():
315
- gr.Markdown("""<h2><b>How to use ⌛️</b></h2>""")
316
- with gr.Accordion("Generate image with ControlnetHand", open=True):
317
- gr.Markdown("""
318
- - Step 1. Select preprocessing method (Standard or Hand encoding)
319
- - Step 2. Describe the image you want to create along with the hand details of the uploaded or captured image
320
- - Step 3. Provide a negative prompt that helps the model not to create redundant details
321
- - Step 4. Upload or capture by webcam a clear image of hands that are prominently visible in the foreground
322
- - Step 5. Submit and enjoy
323
- """)
324
-
325
- # Model input parameters
326
- model_type = gr.Radio(["Standard", "Hand Encoding"], value="Standard", label="Model preprocessing", info="We developed two models, one with standard MediaPipe landmarks, and one with different (but similar) coloring on palm landmarks to distinguish left and right")
327
 
328
  with gr.Row():
329
  with gr.Column():
330
  prompt_input = gr.Textbox(label="Prompt")
331
  negative_prompt = gr.Textbox(label="Negative Prompt")
332
- with gr.Box():
333
- with gr.Tab("Upload Image"):
334
- upload_image = gr.Image(label="Upload Image", source="upload")
335
- with gr.Tab("Webcam"):
336
- webcam_image = gr.Image(label="Webcam", source="webcam")
337
  # output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=3, height='auto')
338
  submit_btn = gr.Button(value = "Submit")
339
  # inputs = [prompt_input, negative_prompt, input_image]
340
  # submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])
341
- system_info = gr.Markdown(f"*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*")
342
  with gr.Column():
343
  output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=2, height='auto')
344
-
345
  gr.Examples(
346
  examples=[
347
  [
@@ -370,18 +181,13 @@ with gr.Blocks(theme='gradio/soft') as demo:
370
  "example4.png"
371
  ],
372
  ],
373
- inputs=[prompt_input, negative_prompt, upload_image, model_type],
374
- outputs=[output_image, system_info],
375
  fn=infer,
376
  cache_examples=True,
377
  )
378
- # check source of image
379
- if upload_image and webcam_image is None:
380
- input_image = upload_image
381
- else:
382
- input_image = webcam_image
383
-
384
- inputs = [prompt_input, negative_prompt, input_image, model_type]
385
- submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image, system_info])
386
 
387
  demo.launch()
 
5
  from PIL import Image
6
  from argparse import Namespace
7
  import gradio as gr
8
+
9
  import numpy as np
10
  import mediapipe as mp
11
  from mediapipe import solutions
 
13
  from mediapipe.tasks import python
14
  from mediapipe.tasks.python import vision
15
  import cv2
 
 
 
 
 
16
 
17
  from diffusers import (
18
  FlaxControlNetModel,
19
  FlaxStableDiffusionControlNetPipeline,
20
  )
 
 
 
 
 
 
 
 
 
 
 
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ # mediapipe annotation
24
+ MARGIN = 10 # pixels
25
+ FONT_SIZE = 1
26
+ FONT_THICKNESS = 1
27
+ HANDEDNESS_TEXT_COLOR = (88, 205, 54) # vibrant green
28
+
29
+ def draw_landmarks_on_image(rgb_image, detection_result):
30
+ hand_landmarks_list = detection_result.hand_landmarks
31
+ handedness_list = detection_result.handedness
32
+ annotated_image = np.zeros_like(rgb_image)
33
+
34
+ # Loop through the detected hands to visualize.
35
+ for idx in range(len(hand_landmarks_list)):
36
+ hand_landmarks = hand_landmarks_list[idx]
37
+ handedness = handedness_list[idx]
38
+
39
+ # Draw the hand landmarks.
40
+ hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
41
+ hand_landmarks_proto.landmark.extend([
42
+ landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in hand_landmarks
43
+ ])
44
+ solutions.drawing_utils.draw_landmarks(
45
+ annotated_image,
46
+ hand_landmarks_proto,
47
+ solutions.hands.HAND_CONNECTIONS,
48
+ solutions.drawing_styles.get_default_hand_landmarks_style(),
49
+ solutions.drawing_styles.get_default_hand_connections_style())
50
+
51
+ return annotated_image
52
+
53
+ def generate_annotation(img):
54
  """img(input): numpy array
55
  annotated_image(output): numpy array
56
  """
 
68
  detection_result = detector.detect(image)
69
 
70
  # STEP 5: Process the classification result. In this case, visualize it.
71
+ annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result)
72
  return annotated_image
73
 
74
+ args = Namespace(
75
+ pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
76
+ revision="non-ema",
77
+ from_pt=True,
78
+ controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
79
+ controlnet_revision=None,
80
+ controlnet_from_pt=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  )
82
 
83
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
84
+ args.controlnet_model_name_or_path,
85
+ revision=args.controlnet_revision,
86
+ from_pt=args.controlnet_from_pt,
 
 
 
87
  dtype=jnp.float32, # jnp.bfloat16
 
 
88
  )
89
+
90
+ pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
91
+ args.pretrained_model_name_or_path,
92
  # tokenizer=tokenizer,
93
+ controlnet=controlnet,
94
  safety_checker=None,
95
  dtype=jnp.float32, # jnp.bfloat16
96
+ revision=args.revision,
97
+ from_pt=args.from_pt,
98
  )
99
 
100
 
101
+ pipeline_params["controlnet"] = controlnet_params
102
+ pipeline_params = jax_utils.replicate(pipeline_params)
 
 
 
103
 
104
  rng = jax.random.PRNGKey(0)
105
  num_samples = jax.device_count()
106
  prng_seed = jax.random.split(rng, jax.device_count())
 
107
 
108
+
109
+ def infer(prompt, negative_prompt, image):
110
  prompts = num_samples * [prompt]
111
+ prompt_ids = pipeline.prepare_text_inputs(prompts)
 
 
 
 
 
112
  prompt_ids = shard(prompt_ids)
113
 
114
+ annotated_image = generate_annotation(image)
 
 
 
 
 
 
 
 
115
  validation_image = Image.fromarray(annotated_image).convert("RGB")
116
+ processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
117
+ processed_image = shard(processed_image)
118
 
119
+ negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
120
+ negative_prompt_ids = shard(negative_prompt_ids)
 
121
 
122
+ images = pipeline(
123
+ prompt_ids=prompt_ids,
124
+ image=processed_image,
125
+ params=pipeline_params,
126
+ prng_seed=prng_seed,
127
+ num_inference_steps=50,
128
+ neg_prompt_ids=negative_prompt_ids,
129
+ jit=True,
130
+ ).images
131
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
134
 
135
  results = [i for i in images]
136
+ return [annotated_image] + results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
 
139
  with gr.Blocks(theme='gradio/soft') as demo:
140
  gr.Markdown("## Stable Diffusion with Hand Control")
141
  gr.Markdown("This model is a ControlNet model using MediaPipe hand landmarks for control.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
  with gr.Row():
144
  with gr.Column():
145
  prompt_input = gr.Textbox(label="Prompt")
146
  negative_prompt = gr.Textbox(label="Negative Prompt")
147
+ input_image = gr.Image(label="Input Image")
 
 
 
 
148
  # output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=3, height='auto')
149
  submit_btn = gr.Button(value = "Submit")
150
  # inputs = [prompt_input, negative_prompt, input_image]
151
  # submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])
152
+
153
  with gr.Column():
154
  output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=2, height='auto')
155
+
156
  gr.Examples(
157
  examples=[
158
  [
 
181
  "example4.png"
182
  ],
183
  ],
184
+ inputs=[prompt_input, negative_prompt, input_image],
185
+ outputs=[output_image],
186
  fn=infer,
187
  cache_examples=True,
188
  )
189
+
190
+ inputs = [prompt_input, negative_prompt, input_image]
191
+ submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])
 
 
 
 
 
192
 
193
  demo.launch()
requirements.txt CHANGED
@@ -7,6 +7,4 @@ git+https://github.com/huggingface/diffusers@main
7
  opencv-python
8
  torch
9
  torchvision
10
- mediapipe==0.9.1
11
- gpuinfo
12
- psutil
 
7
  opencv-python
8
  torch
9
  torchvision
10
+ mediapipe==0.9.1