OriLib committed on
Commit
f3f94c7
1 Parent(s): d29ef6c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -0
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import cv2
3
+ import gradio as gr
4
+ import torch
5
+ from diffusers import (
6
+ AutoencoderKL,
7
+ EulerAncestralDiscreteScheduler,
8
+ )
9
+ from diffusers.utils import load_image
10
+ from replace_bg.model.pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
11
+ from replace_bg.model.controlnet import ControlNetModel
12
+ from replace_bg.utilities import resize_image, remove_bg_from_image, paste_fg_over_image, get_control_image_tensor
13
+
14
# Module-level model setup (runs once at import time):
# load the background-generation ControlNet and the fp16-safe SDXL VAE,
# then assemble the BRIA 2.3 ControlNet pipeline on the first CUDA device.
controlnet = ControlNetModel.from_pretrained("briaai/BRIA-2.3-ControlNet-BG-Gen", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained("briaai/BRIA-2.3", controlnet=controlnet, torch_dtype=torch.float16, vae=vae).to('cuda:0')
# Replace the default sampler with Euler-Ancestral using the standard
# Stable-Diffusion "scaled_linear" beta schedule; steps_offset=1 is the
# usual diffusers setting for SD-family models — TODO confirm against the
# BRIA model card.
pipe.scheduler = EulerAncestralDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1
)
24
+
25
+
26
@spaces.GPU
def generate_(prompt, negative_prompt, control_tensor, num_steps, controlnet_conditioning_scale, seed, image=None, mask=None):
    """Run the ControlNet background-generation pipeline for one image.

    Args:
        prompt: text prompt describing the desired background.
        negative_prompt: text describing what to avoid.
        control_tensor: conditioning tensor built from the foreground
            image and mask (see ``get_control_image_tensor``).
        num_steps: number of diffusion inference steps.
        controlnet_conditioning_scale: weight of the ControlNet signal.
        seed: RNG seed for reproducible sampling.
        image: optional foreground image; when supplied together with
            ``mask``, the foreground is pasted over the generated result.
        mask: optional foreground mask matching ``image``.

    Returns:
        A PIL image — the generated background, with the foreground
        composited on top when ``image`` and ``mask`` are provided.
    """
    generator = torch.Generator("cuda").manual_seed(seed)
    gen_img = pipe(
        negative_prompt=negative_prompt,
        prompt=prompt,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_steps,
        image=control_tensor,
        generator=generator,
    ).images[0]
    # Bug fix: the original unconditionally called
    # paste_fg_over_image(gen_img, image, mask) using free names `image`
    # and `mask` that were never defined in this scope (NameError at
    # runtime). They are now optional parameters; compositing only
    # happens when both are supplied.
    if image is not None and mask is not None:
        return paste_fg_over_image(gen_img, image, mask)
    return gen_img
39
+
40
@spaces.GPU
def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
    """Gradio callback: generate a new background behind ``input_image``.

    Removes the background from the uploaded image, builds the ControlNet
    conditioning tensor, generates a new background from ``prompt``, and
    composites the original foreground back over the generated scene.

    Args:
        input_image: uploaded PIL image (``gr.Image(type="pil")``).
        prompt / negative_prompt: generation prompts.
        num_steps: diffusion inference steps.
        controlnet_conditioning_scale: ControlNet conditioning weight.
        seed: RNG seed.

    Returns:
        A one-element list of PIL images for the output ``gr.Gallery``.
    """
    # Normalize the upload to the model's working resolution.
    # Bug fix: the original then called resize_image(image) and
    # remove_bg_from_image(image_path) on names that were never defined;
    # everything now flows through input_image.
    input_image = resize_image(input_image)
    mask = remove_bg_from_image(input_image)
    control_tensor = get_control_image_tensor(pipe.vae, input_image, mask)

    gen_image = generate_(prompt, negative_prompt, control_tensor, num_steps, controlnet_conditioning_scale, seed)
    # Composite the original foreground over the generated background here,
    # since this scope owns both the image and its mask.
    result = paste_fg_over_image(gen_image, input_image, mask)

    # Bug fix: the original returned [depth_image, images[0]] — `depth_image`
    # was undefined and generate_ yields a single image, not a list.
    return [result]
52
+
53
+
54
+
55
# Gradio UI: single-page demo with a prompt/settings column on the left
# and an output gallery on the right. Component creation order inside
# each `with` context defines the on-screen layout.
block = gr.Blocks().queue()

with block:
    gr.Markdown("## BRIA Generate Background")
    # Static HTML blurb describing the backbone model and its licensing.
    gr.HTML('''
    <p style="margin-bottom: 10px; font-size: 94%">
    This is a demo for ControlNet Depth that using
    <a href="briaai/BRIA-2.3-ControlNet-BG-Gen" target="_blank">BRIA 2.3 text-to-image model</a> as backbone.
    Trained on licensed data, BRIA 2.3 provide full legal liability coverage for copyright and privacy infringement.
    </p>
    ''')
    with gr.Row():
        with gr.Column():
            # Input controls, top to bottom.
            input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
            prompt = gr.Textbox(label="Prompt")
            negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
            num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
            controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
            # randomize=True picks a fresh seed on every page load.
            seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
            run_button = gr.Button(value="Run")


        with gr.Column():
            # Output gallery; `process` returns the list shown here.
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[2], height='auto')
    # Wire the button to the processing callback.
    ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

block.launch(debug = True)