mokady commited on
Commit
799c48f
1 Parent(s): 40d97c0

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +138 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ hf_token = os.environ.get("HF_TOKEN")
4
+ import spaces
5
+ from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler, AutoencoderKL
6
+ import torch
7
+ import time
8
+
9
+ class Dummy():
10
+ pass
11
+
12
+ resolutions = ["1024 1024","1280 768","1344 768","768 1344","768 1280" ]
13
+
14
+ # Load pipeline
15
+
16
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
17
+ unet = UNet2DConditionModel.from_pretrained("briaai/BRIA-2.2-FAST", torch_dtype=torch.float16)
18
+ pipe = DiffusionPipeline.from_pretrained("briaai/BRIA-2.2", torch_dtype=torch.float16, unet=unet, vae=vae)
19
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
20
+ pipe.to('cuda')
21
+ del unet
22
+ del vae
23
+
24
+
25
+ pipe.force_zeros_for_empty_prompt = False
26
+
27
+ print("Optimizing BRIA 2.2 FAST - this could take a while")
28
+ t=time.time()
29
+ pipe.unet = torch.compile(
30
+ pipe.unet, mode="reduce-overhead", fullgraph=True # 600 secs compilation
31
+ )
32
+ with torch.no_grad():
33
+ outputs = pipe(
34
+ prompt="an apple",
35
+ num_inference_steps=8,
36
+ )
37
+
38
+ # This will avoid future compilations on different shapes
39
+ unet_compiled = torch._dynamo.run(pipe.unet)
40
+ unet_compiled.config=pipe.unet.config
41
+ unet_compiled.add_embedding = Dummy()
42
+ unet_compiled.add_embedding.linear_1 = Dummy()
43
+ unet_compiled.add_embedding.linear_1.in_features = pipe.unet.add_embedding.linear_1.in_features
44
+ pipe.unet = unet_compiled
45
+
46
+ print(f"Optimizing finished successfully after {time.time()-t} secs")
47
+
48
+ @spaces.GPU(enable_queue=True)
49
+ def infer(prompt,seed,resolution):
50
+ print(f"""
51
+ —/n
52
+ {prompt}
53
+ """)
54
+
55
+ # generator = torch.Generator("cuda").manual_seed(555)
56
+ t=time.time()
57
+
58
+ if seed=="-1":
59
+ generator=None
60
+ else:
61
+ try:
62
+ seed=int(seed)
63
+ generator = torch.Generator("cuda").manual_seed(seed)
64
+ except:
65
+ generator=None
66
+
67
+ w,h = resolution.split()
68
+ w,h = int(w),int(h)
69
+ image = pipe(prompt,num_inference_steps=8,generator=generator,width=w,height=h).images[0]
70
+ print(f'gen time is {time.time()-t} secs')
71
+
72
+ # Future
73
+ # Add amound of steps
74
+ # if nsfw:
75
+ # raise gr.Error("Generated image is NSFW")
76
+
77
+ return image
78
+
79
+ css = """
80
+ #col-container{
81
+ margin: 0 auto;
82
+ max-width: 580px;
83
+ }
84
+ """
85
+ with gr.Blocks(css=css) as demo:
86
+ with gr.Column(elem_id="col-container"):
87
+ gr.Markdown("## BRIA 2.2 FAST")
88
+ gr.HTML('''
89
+ <p style="margin-bottom: 10px; font-size: 94%">
90
+ This is a demo for
91
+ <a href="https://huggingface.co/briaai/BRIA-2.2-FAST" target="_blank">BRIA 2.2 FAST </a>.
92
+ This is a fast version of BRIA 2.2 text-to-image model, still trained on licensed data, and so provides full legal liability coverage for copyright and privacy infringement.
93
+ Try it for free in our webapp demo <a href="https://labs.bria.ai/" </a>.
94
+
95
+ Are you a startup or a student? We encourage you to apply for our Startup Plan
96
+ <a href="https://pages.bria.ai/the-visual-generative-ai-platform-for-builders-startups-plan?_gl=1*cqrl81*_ga*MTIxMDI2NzI5OC4xNjk5NTQ3MDAz*_ga_WRN60H46X4*MTcwOTM5OTMzNC4yNzguMC4xNzA5Mzk5MzM0LjYwLjAuMA..) </a>
97
+ This program are designed to support emerging businesses and academic pursuits with our cutting-edge technology.
98
+
99
+ </p>
100
+ ''')
101
+ with gr.Group():
102
+ with gr.Column():
103
+ prompt_in = gr.Textbox(label="Prompt", value="A smiling man with wavy brown hair and a trimmed beard")
104
+ resolution = gr.Dropdown(value=resolutions[0], show_label=True, label="Resolution", choices=resolutions)
105
+ seed = gr.Textbox(label="Seed", value=-1)
106
+ submit_btn = gr.Button("Generate")
107
+ result = gr.Image(label="BRIA 2.2 FAST Result")
108
+
109
+ # gr.Examples(
110
+ # examples = [
111
+ # "Dragon, digital art, by Greg Rutkowski",
112
+ # "Armored knight holding sword",
113
+ # "A flat roof villa near a river with black walls and huge windows",
114
+ # "A calm and peaceful office",
115
+ # "Pirate guinea pig"
116
+ # ],
117
+ # fn = infer,
118
+ # inputs = [
119
+ # prompt_in
120
+ # ],
121
+ # outputs = [
122
+ # result
123
+ # ]
124
+ # )
125
+
126
+ submit_btn.click(
127
+ fn = infer,
128
+ inputs = [
129
+ prompt_in,
130
+ seed,
131
+ resolution
132
+ ],
133
+ outputs = [
134
+ result
135
+ ]
136
+ )
137
+
138
+ demo.queue().launch(show_api=False)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers
2
+ diffusers
3
+ torch
4
+ torchvision
5
+ accelerate
6
+ spaces
7
+ gradio