File size: 12,831 Bytes
cd1e8dc
628e6c3
cd1e8dc
 
 
 
 
 
 
 
 
 
628e6c3
35f97ba
628e6c3
426fb9c
822b647
426fb9c
 
 
1195790
 
 
 
 
 
ba29a7c
 
 
2ba8aac
ba29a7c
a30c911
2ba8aac
822b647
 
 
 
 
 
 
 
 
 
 
 
 
a30c911
 
 
 
ba29a7c
 
663f236
 
 
 
 
ba29a7c
 
cd1e8dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bff
cd1e8dc
 
 
 
 
a5e9129
 
 
cd1e8dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35f97ba
 
be5cb04
f2819ff
ce09356
 
ccf1a03
cd1e8dc
 
 
 
 
 
 
 
4598830
cd1e8dc
 
 
4598830
 
 
cd1e8dc
 
 
e57f9cc
cd1e8dc
 
9702a1f
cd1e8dc
 
a5e9129
 
cd1e8dc
 
9702a1f
a5e9129
cd1e8dc
 
 
 
7b66f42
cd1e8dc
 
 
 
 
 
 
 
 
788a013
dc72f49
788a013
 
 
 
cd1e8dc
788a013
 
 
 
cd1e8dc
 
68696f0
71f4cfe
 
5461399
cd1e8dc
 
 
 
 
 
 
71f4cfe
 
cd1e8dc
 
 
71f4cfe
 
 
783c45d
cd1e8dc
 
 
783c45d
cd1e8dc
 
 
a5e9129
cd1e8dc
628e6c3
 
1195790
628e6c3
 
 
 
 
 
cd1e8dc
 
628e6c3
 
 
 
 
cd1e8dc
628e6c3
 
 
 
 
 
 
 
 
 
 
 
cd1e8dc
 
 
 
 
 
628e6c3
 
 
 
 
 
 
 
3f2581f
628e6c3
14811bd
4598830
 
2d0240c
628e6c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba29a7c
 
 
a30c911
bf28d41
f7a5714
e5488f2
f7a5714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14811bd
f7a5714
14811bd
f7a5714
 
ba29a7c
 
 
 
 
 
628e6c3
 
 
cd1e8dc
628e6c3
cd1e8dc
 
 
628e6c3
 
4598830
628e6c3
cd1e8dc
628e6c3
 
 
 
 
 
88bd8ea
 
628e6c3
 
a5e9129
 
 
628e6c3
a5e9129
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
# This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
# The original license file is LICENSE.ControlNet in this repo.
import functools
import os

import cv2
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from diffusers.utils import load_image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed


# The UI below targets the Gradio 3.28.3 API. Reinstalling the package cannot
# change the `gradio` module that has already been imported in this process
# (presumably why the previous in-place reinstall "didn't work"), so after
# installing, the interpreter is restarted to import the pinned version. The
# environment flag stops an endless restart loop if the import still reports
# a different version after the reinstall.
if gr.__version__ != "3.28.3" and not os.environ.get("SYNDROM_GRADIO_REINSTALLED"):
    import subprocess
    import sys

    # `python -m pip` installs into this interpreter's environment, unlike a
    # bare `pip` resolved from PATH. pip replaces any installed version.
    subprocess.run([sys.executable, "-m", "pip", "install", "gradio==3.28.3"], check=True)
    os.environ["SYNDROM_GRADIO_REINSTALLED"] = "1"
    os.execv(sys.executable, [sys.executable, *sys.argv])

title_description = """
# SynDRoM
## Synthetic Data augmentation for Robotic Manipulation

"""

description = """
Our project is to use diffusion model to change the texture of our robotic arm simulation.
To do so, we first get our simulated images. After, we process these images to get Canny Edge maps. Finally, we can get brand new images by using ControlNet.
Therefore, we are able to change our simulation texture, and still keep the image composition.


Our objectif for the sprint is to perform data augmentation using ControlNet. So we look for having a model that can augment an image quickly.
To do so, we trained many Controlnets from scratch with different datasets : 
* [Coyo-700M](https://github.com/kakaobrain/coyo-dataset)
* [Bridge](https://sites.google.com/view/bridgedata)

A method to accelerate the inference of diffusion model is by simply generating small images. So we decided to work with low resolution images.
After downloading the datasets, we processed them by resizing images to a 128 resolution. 
The smallest side of the image (width or height) is resized to 128 and the other side is resized keeping the initial ratio.
After, we retrieve the Canny Edge Map of the images. We performed this preprocess for every datasets we use during the sprint. 


We train four different Controlnets. For each one of them, we processed the datasets differently. You can find the description of the processing in the readme file attached to the model repo
[Our ControlNet repo](https://huggingface.co/Baptlem/baptlem-controlnet)

For now, we benchmarked our model on a node of 4 Titan RTX 24Go. We were able to generate a batch of 4 images in a average time of 1.3 seconds!
We also have access to nodes composed of 8 A100 80Go GPUs. The benchmark on one of these nodes will come soon.
 

"""

traj_description = """ 
We generated a trajectory of our simulated environment. We will then use it with our different models.
We made these videos on our Titan RTX node.
The prompt we use for every video is "A robotic arm with a gripper and a small cube on a table, super realistic, industrial background"
"""


def create_key(seed=0):
    """Return a JAX PRNG key derived from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key

def load_controlnet(controlnet_version):
    """Load one ControlNet variant together with its Flax parameters.

    Args:
        controlnet_version: Subfolder of the ``Baptlem/baptlem-controlnet``
            Hub repository that holds the weights (e.g. ``"coyo-500k"``).

    Returns:
        ``(controlnet, controlnet_params)`` as produced by
        ``FlaxControlNetModel.from_pretrained`` (weights kept in float32).
    """
    repo_id = "Baptlem/baptlem-controlnet"
    load_options = dict(
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    model, model_params = FlaxControlNetModel.from_pretrained(repo_id, **load_options)
    return model, model_params


@functools.lru_cache(maxsize=1)
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Build the Flax Stable Diffusion pipeline wired to one of our ControlNets.

    The most recent result is memoised. ``pipe_inference`` calls this on every
    request, and reloading the full set of weights each time is slow. Only one
    entry is kept, so switching models frees the previous pipeline.
    NOTE(review): reusing the same pipeline object presumably also lets the
    jitted pipeline call reuse its compiled program instead of recompiling;
    verify. Because the returned objects are shared between calls, callers
    must not mutate ``pipe`` or ``params`` in place.

    Args:
        controlnet_version: ControlNet variant, forwarded to ``load_controlnet``.
        sb_path: Hub id of the Stable Diffusion base model. Its ``scheduler``
            subfolder and its ``flax`` revision (in bfloat16) are loaded.

    Returns:
        ``(pipe, params)``. ``pipe`` is the pipeline with its scheduler
        replaced by a ``FlaxDPMSolverMultistepScheduler``. ``params`` is the
        pipeline parameter dict extended with the ``"controlnet"`` and
        ``"scheduler"`` entries.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler",
    )

    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        revision="flax",
        dtype=jnp.bfloat16,
    )

    # Swap in the DPM-Solver scheduler and register the parameters of the
    # components that were loaded separately.
    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params

    

# Hugging Face Hub repository of the trained ControlNets and the default
# variant. Note that load_controlnet currently hard-codes the same repository
# id rather than reading controlnet_path.
controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"

# Lower / upper thresholds handed to cv2.Canny by preprocess_canny.
low_threshold, high_threshold = 100, 200

# Start-up diagnostics: the working directory (the demo reads its trajectory
# videos from ./trajectory_hf relative to it), its contents, and the Gradio
# version actually imported.
print(os.path.abspath('.'))
print(os.listdir("."))
print("Gradio version:", gr.__version__)
# Nothing is loaded at import time: pipe_inference calls load_sb_pipe for the
# selected model on each request, so do not claim the models are loaded here.
print("Models are loaded on demand.")
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    model="coyo-500k",
    seed=0,
    negative_prompt="",
    ):
    """Run ControlNet-conditioned generation for one input image.

    Args:
        image: Input image. The Gradio ``Image`` component delivers a numpy
            array; anything ``np.array`` accepts also works.
        prompt: Text prompt.
        is_canny: If True, ``image`` is already an edge map and is only
            resized. Otherwise a Canny edge map is computed from it.
        num_samples: Number of images to return.
        resolution: Target length of the image's smaller side.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Guidance scale forwarded to the pipeline.
        model: ControlNet variant (subfolder of the model repository).
        seed: Seed for the JAX PRNG.
        negative_prompt: Negative prompt.

    Returns:
        A list of ``num_samples`` PIL images.

    Raises:
        gr.Error: If no image was provided or ``num_samples`` is below 1.
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    # Gradio sliders may hand back floats. List repetition, the resize target
    # size and the PRNG seed all need ints.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)
    if num_samples < 1:
        raise gr.Error("The number of images must be at least 1.")

    print("Loading pipe")
    pipe, params = load_sb_pipe(model)

    if not isinstance(image, np.ndarray):
        image = np.array(image)

    processed_image = resize_image(image, resolution)  # -> PIL.Image
    if not is_canny:
        # Condition on the Canny edge map of the resized image instead.
        _, processed_image = preprocess_canny(processed_image, resolution)

    # `shard` splits the leading batch axis evenly across the local devices,
    # so a batch that is not a multiple of the device count would crash.
    # Round the batch up and drop the surplus images at the end. When
    # num_samples is already a multiple, nothing changes.
    num_devices = jax.local_device_count()
    batch_size = num_devices * ((num_samples + num_devices - 1) // num_devices)

    # One PRNG key per device.
    rng = jax.random.split(create_key(seed), num_devices)

    prompt_ids = shard(pipe.prepare_text_inputs([prompt] * batch_size))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompt] * batch_size))
    image_inputs = shard(pipe.prepare_image_inputs([processed_image] * batch_size))
    p_params = replicate(params)

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=image_inputs,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Merge the (device, per-device batch) axes back into a single batch axis.
    images = np.asarray(output.reshape((batch_size,) + output.shape[-3:]))
    return pipe.numpy_to_pil(images[:num_samples])

def resize_image(image, resolution):
    """Resize ``image`` so its smaller side equals ``resolution``.

    The other side is scaled by the original aspect ratio (truncated to an
    int). Nearest-neighbour interpolation is used.

    Args:
        image: ``np.ndarray`` or anything ``np.array`` can convert (e.g. a PIL
            image).
        resolution: Target length of the smaller side, in pixels.

    Returns:
        The resized image as a ``PIL.Image``.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)
    height, width = pixels.shape[:2]
    aspect = width / height

    # cv2.resize takes its target size as (width, height).
    if aspect > 1:
        target_size = (int(resolution * aspect), resolution)
    elif aspect < 1:
        target_size = (resolution, int(resolution / aspect))
    else:
        target_size = (resolution, resolution)

    resized = cv2.resize(pixels, target_size, interpolation=cv2.INTER_NEAREST)
    return Image.fromarray(resized)
    
    
def preprocess_canny(image, resolution=128):
    """Compute a three-channel Canny edge map of ``image``.

    Uses the module-level ``low_threshold`` / ``high_threshold`` constants.

    Args:
        image: ``np.ndarray`` or anything ``np.array`` can convert (e.g. a PIL
            image). Presumably 8-bit, as ``cv2.Canny`` expects; TODO confirm
            at callers.
        resolution: Unused; kept for call compatibility.

    Returns:
        ``(resized_image, processed_image)``: the input converted to a PIL
        image, and the edge map (one channel copied into three) as a PIL image.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)

    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    # Copy the single edge channel into three identical channels.
    edges_rgb = np.repeat(edges[:, :, None], 3, axis=2)

    return Image.fromarray(pixels), Image.fromarray(edges_rgb)


def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI around an inference callable.

    Args:
        process: Callable run when the prompt is submitted or the Run button
            is clicked. It receives the values of the widgets listed in
            ``inputs`` below, positionally and in that order. Its return
            value is shown in the output gallery.
        max_images: Upper bound of the "Images" slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The ``gr.Blocks`` demo (not launched).
    """
    # ControlNet variants offered in the dropdown. Each one also has a
    # pre-rendered trajectory video shown at the bottom of the page.
    model_choices = ["coyo-500k", "bridge-2M", "coyo1M-bridge2M", "coyo2M-bridge325k"]

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(title_description)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                # The first argument of gr.Button is the text shown on it.
                run_button = gr.Button('Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(
                        label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # The Canny thresholds are the module-level constants
                    # low_threshold / high_threshold, not UI controls.
                    # The resolution is pinned to 128, the size the training
                    # images were resized to.
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    model = gr.Dropdown(choices=model_choices,
                                        value="coyo-500k",
                                        label="Model used for inference",
                                        info="Find all models at https://huggingface.co/Baptlem/baptlem-controlnet")
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')

        with gr.Row():
            gr.Markdown(description)

        # The raw simulated trajectory...
        with gr.Row():
            with gr.Column():
                gr.Markdown(traj_description)
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory.avi",
                         format="avi",
                         interactive=False)

        # ...followed by one row per model showing the same trajectory
        # processed by it.
        for model_name in model_choices:
            with gr.Row():
                with gr.Column():
                    gr.Markdown(f"Trajectory processed with {model_name} model :")
                with gr.Column():
                    gr.Video(f"./trajectory_hf/trajectory_{model_name}.avi",
                             format="avi",
                             interactive=False)

        # Order must match the positional parameters of `process`.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')

    return demo

if __name__ == '__main__':
    # Build the UI around the inference function, enable Gradio's request
    # queue, and start the server.
    demo = create_demo(pipe_inference)
    demo.queue().launch()