Spaces: briaai / (Running on Zero)

Eyalgut committed
Commit 818ca9e · 1 Parent(s): 1510fcd

Update app.py

Files changed (1):
  app.py +36 -21
app.py CHANGED
@@ -23,29 +23,29 @@ scheduler = EulerAncestralDiscreteScheduler(
 )
 pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16,scheduler=scheduler).to("cuda")
 
-print("Optimizing BRIA-2.2 - this could take a while")
-t=time.time()
-pipe.unet = torch.compile(
-    pipe.unet, mode="reduce-overhead", fullgraph=True # 600 secs compilation
-)
-with torch.no_grad():
-    outputs = pipe(
-        prompt="an apple",
-        num_inference_steps=30,
-    )
+# print("Optimizing BRIA-2.2 - this could take a while")
+# t=time.time()
+# pipe.unet = torch.compile(
+#     pipe.unet, mode="reduce-overhead", fullgraph=True # 600 secs compilation
+# )
+# with torch.no_grad():
+#     outputs = pipe(
+#         prompt="an apple",
+#         num_inference_steps=30,
+#     )
 
-# This will avoid future compilations on different shapes
-unet_compiled = torch._dynamo.run(pipe.unet)
-unet_compiled.config=pipe.unet.config
-unet_compiled.add_embedding = Dummy()
-unet_compiled.add_embedding.linear_1 = Dummy()
-unet_compiled.add_embedding.linear_1.in_features = pipe.unet.add_embedding.linear_1.in_features
-pipe.unet = unet_compiled
+# # This will avoid future compilations on different shapes
+# unet_compiled = torch._dynamo.run(pipe.unet)
+# unet_compiled.config=pipe.unet.config
+# unet_compiled.add_embedding = Dummy()
+# unet_compiled.add_embedding.linear_1 = Dummy()
+# unet_compiled.add_embedding.linear_1.in_features = pipe.unet.add_embedding.linear_1.in_features
+# pipe.unet = unet_compiled
 
-print(f"Optimizing finished successfully after {time.time()-t} secs")
+# print(f"Optimizing finished successfully after {time.time()-t} secs")
 
 @spaces.GPU(enable_queue=True)
-def infer(prompt):
+def infer(prompt,negative_prompt,seed,resolution):
     print(f"""
     —/n
     {prompt}
@@ -53,7 +53,16 @@ def infer(prompt):
 
     # generator = torch.Generator("cuda").manual_seed(555)
     t=time.time()
-    image = pipe(prompt,num_inference_steps=30, negative_prompt=default_negative_prompt).images[0]
+    if negative_prompt=="":
+        negative_prompt = default_negative_prompt
+
+    if seed==-1:
+        generator=None
+    else:
+        generator = torch.Generator("cuda").manual_seed(seed)
+
+    w,h = resolution
+    image = pipe(prompt,num_inference_steps=30, negative_prompt=negative_prompt,generator=generator,width=w,height=h).images[0]
     print(f'gen time is {time.time()-t} secs')
 
     # Future
@@ -82,6 +91,9 @@ with gr.Blocks(css=css) as demo:
     with gr.Group():
         with gr.Column():
             prompt_in = gr.Textbox(label="Prompt", value="A red colored sports car")
+            negative_prompt = gr.Textbox(label="Negative Prompt", value="")
+            resolution = gr.Dropdown(value=(1024,1024), show_label=True, label="Resolution", choices=[(1024,1024),(1344, 768)])
+            seed = gr.Textbox(label="Seed", value=-1)
             submit_btn = gr.Button("Generate")
             result = gr.Image(label="BRIA-2.2 Result")
 
@@ -105,7 +117,10 @@ with gr.Blocks(css=css) as demo:
     submit_btn.click(
         fn = infer,
         inputs = [
-            prompt_in
+            prompt_in,
+            negative_prompt,
+            seed,
+            resolution
         ],
         outputs = [
             result
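
For context on the block this commit comments out: it follows the usual torch.compile warm-up pattern. The UNet is compiled, one dummy generation triggers the (roughly 600 s, per the inline comment) compilation, and the module is then wrapped with torch._dynamo.run so later calls reuse prior compiles instead of recompiling on new shapes. Because that wrapper no longer exposes the module attributes the SDXL pipeline reads, the original code patches config and add_embedding back on with Dummy objects. A condensed sketch of the same pattern, assuming pipe is already loaded on "cuda":

import time
import torch

t = time.time()
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
with torch.no_grad():
    pipe(prompt="an apple", num_inference_steps=30)  # first call triggers compilation

frozen = torch._dynamo.run(pipe.unet)  # reuse prior compiles, never recompile
frozen.config = pipe.unet.config       # restore attributes the pipeline reads;
pipe.unet = frozen                     # the original also stubs add_embedding.linear_1.in_features
print(f"warm-up finished after {time.time() - t:.0f} secs")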
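One caveat in the new infer signature: gr.Textbox delivers its value as a string, so seed arrives as "-1" rather than -1, the seed==-1 branch never matches, and torch.Generator.manual_seed would reject a string. A minimal sketch of the needed coercion (parse_seed is an illustrative name, not from the commit):

import torch

def parse_seed(seed):
    seed = int(seed)  # gr.Textbox values arrive as strings
    if seed == -1:
        return None   # no generator -> non-deterministic sampling
    return torch.Generator("cuda").manual_seed(seed)

Inside infer, generator = parse_seed(seed) would replace the two-branch block; alternatively, a gr.Number component would avoid the cast entirely.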
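The resolution dropdown is similarly version-sensitive: recent Gradio releases treat a tuple choice as a (label, value) pair, so (1024,1024) can reach infer as the single int 1024 and break the w,h = resolution unpacking. A hedged alternative using string-encoded choices (the "WxH" format is an illustration, not from the commit):

import gradio as gr

resolution = gr.Dropdown(value="1024x1024", label="Resolution",
                         choices=["1024x1024", "1344x768"])

# inside infer:
w, h = (int(x) for x in resolution.split("x"))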