vittore committed
Commit c811dfe · 1 Parent(s): f165e6e

torch with gpu

Files changed (1): app.py +158 -4
app.py CHANGED
@@ -25,10 +25,7 @@ from illusion_style import css
 BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"


-if torch.cuda.is_available():
-    device='gpu'
-else:
-    device='cpu'
+device = 'cuda'

 # Initialize both pipelines
 vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
@@ -42,6 +39,163 @@ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
     torch_dtype=torch.float16,
 ).to(device)

+#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
+#main_pipe.unet.to(memory_format=torch.channels_last)
+#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
+#model_id = "stabilityai/sd-x2-latent-upscaler"
+image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
+
+
+#image_pipe.unet = torch.compile(image_pipe.unet, mode="reduce-overhead", fullgraph=True)
+#upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+#upscaler.to("cuda")
+
+
+# Sampler map
+SAMPLER_MAP = {
+    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
+    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
+}
+
+def center_crop_resize(img, output_size=(512, 512)):
+    width, height = img.size
+
+    # Calculate dimensions to crop to the center
+    new_dimension = min(width, height)
+    left = (width - new_dimension) / 2
+    top = (height - new_dimension) / 2
+    right = (width + new_dimension) / 2
+    bottom = (height + new_dimension) / 2
+
+    # Crop and resize
+    img = img.crop((left, top, right, bottom))
+    img = img.resize(output_size)
+
+    return img
+
+def common_upscale(samples, width, height, upscale_method, crop=False):
+    if crop == "center":
+        old_width = samples.shape[3]
+        old_height = samples.shape[2]
+        old_aspect = old_width / old_height
+        new_aspect = width / height
+        x = 0
+        y = 0
+        if old_aspect > new_aspect:
+            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
+        elif old_aspect < new_aspect:
+            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
+        s = samples[:, :, y:old_height - y, x:old_width - x]
+    else:
+        s = samples
+
+    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+
+def upscale(samples, upscale_method, scale_by):
+    #s = samples.copy()
+    width = round(samples["images"].shape[3] * scale_by)
+    height = round(samples["images"].shape[2] * scale_by)
+    s = common_upscale(samples["images"], width, height, upscale_method, "disabled")
+    return s
+
+def check_inputs(prompt: str, control_image: Image.Image):
+    if control_image is None:
+        raise gr.Error("Please select or upload an Input Illusion")
+    if prompt is None or prompt == "":
+        raise gr.Error("Prompt is required")
+
+def convert_to_pil(base64_image):
+    pil_image = Image.open(base64_image)
+    return pil_image
+
+def convert_to_base64(pil_image):
+    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
+        pil_image.save(temp_file.name)
+    return temp_file.name
+
+# Inference function
+@spaces.GPU
+def inference(
+    control_image: Image.Image,
+    prompt: str,
+    negative_prompt: str,
+    guidance_scale: float = 8.0,
+    controlnet_conditioning_scale: float = 1,
+    control_guidance_start: float = 1,
+    control_guidance_end: float = 1,
+    upscaler_strength: float = 0.5,
+    seed: int = -1,
+    sampler = "DPM++ Karras SDE",
+    progress = gr.Progress(track_tqdm=True),
+    profile: gr.OAuthProfile | None = None,
+):
+    start_time = time.time()
+    start_time_struct = time.localtime(start_time)
+    start_time_formatted = time.strftime("%H:%M:%S", start_time_struct)
+    print(f"Inference started at {start_time_formatted}")
+
+    # Generate the initial image
+    #init_image = init_pipe(prompt).images[0]
+
+    # Rest of your existing code
+    control_image_small = center_crop_resize(control_image)
+    control_image_large = center_crop_resize(control_image, (1024, 1024))
+
+    main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
+    my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
+    generator = torch.Generator(device=device).manual_seed(my_seed)
+
+    out = main_pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        image=control_image_small,
+        guidance_scale=float(guidance_scale),
+        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
+        generator=generator,
+        control_guidance_start=float(control_guidance_start),
+        control_guidance_end=float(control_guidance_end),
+        num_inference_steps=15,
+        output_type="latent"
+    )
+    upscaled_latents = upscale(out, "nearest-exact", 2)
+    out_image = image_pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        control_image=control_image_large,
+        image=upscaled_latents,
+        guidance_scale=float(guidance_scale),
+        generator=generator,
+        num_inference_steps=20,
+        strength=upscaler_strength,
+        control_guidance_start=float(control_guidance_start),
+        control_guidance_end=float(control_guidance_end),
+        controlnet_conditioning_scale=float(controlnet_conditioning_scale)
+    )
+    end_time = time.time()
+    end_time_struct = time.localtime(end_time)
+    end_time_formatted = time.strftime("%H:%M:%S", end_time_struct)
+    print(f"Inference ended at {end_time_formatted}, taking {end_time - start_time}s")
+
+    # Save image + metadata
+    user_history.save_image(
+        label=prompt,
+        image=out_image["images"][0],
+        profile=profile,
+        metadata={
+            "prompt": prompt,
+            "negative_prompt": negative_prompt,
+            "guidance_scale": guidance_scale,
+            "controlnet_conditioning_scale": controlnet_conditioning_scale,
+            "control_guidance_start": control_guidance_start,
+            "control_guidance_end": control_guidance_end,
+            "upscaler_strength": upscaler_strength,
+            "seed": seed,
+            "sampler": sampler,
+        },
+    )
+
+    return out_image["images"][0], gr.update(visible=True), gr.update(visible=True), my_seed
+
 def greet(name):
     return "Hello " + name + "!!"
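A note on the device line in the first hunk: PyTorch device strings are values like "cuda", "cpu", or "mps"; "gpu" is not one of them, and torch.Generator(device='gpu') or .to('gpu') raises a RuntimeError, so the hard-coded assignment reads 'cuda' above. If a CPU fallback is still wanted, the usual guard is a one-liner, sketched here:

import torch

# Standard device-selection idiom: prefer CUDA when visible, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

On a ZeroGPU Space, however, the @spaces.GPU decorator used below attaches the GPU only while a decorated function runs, so moving the pipelines to "cuda" unconditionally at import time appears to be the intended pattern here.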
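The added inference path is two-stage: main_pipe renders a 512×512 ControlNet pass with output_type="latent", upscale()/common_upscale() stretch those latents 2× with nearest-exact interpolation, and image_pipe refines the result at 1024×1024 as img2img with strength=upscaler_strength. A minimal, self-contained sketch of the resize step on Stable Diffusion's 4-channel latents (512×512 pixels correspond to 64×64 latents):

import torch
import torch.nn.functional as F

latents = torch.randn(1, 4, 64, 64)  # stand-in for out["images"] (NCHW)

# The core of upscale(out, "nearest-exact", 2): double height and width.
upscaled = F.interpolate(latents, size=(128, 128), mode="nearest-exact")
print(upscaled.shape)  # torch.Size([1, 4, 128, 128])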
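As a usage sketch (the file names are hypothetical, and the module's pipelines are assumed loaded), the handler can be exercised without the Gradio UI; the two gr.update values it returns only matter to the interface:

from PIL import Image

pattern = Image.open("spiral.png")  # hypothetical input illusion image

result, _, _, used_seed = inference(
    control_image=pattern,
    prompt="a medieval village, intricate detail",
    negative_prompt="low quality",
    seed=-1,  # -1 draws a fresh random seed; it comes back as used_seed
)
result.save("illusion_out.png")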
201