pseudotheos committed on
Commit
aff1d7c
1 Parent(s): 3f84c0a
Files changed (2)
  1. app.py +274 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,274 @@
+ import os
+ import socket
+ import requests
+ from fastapi import FastAPI, File, UploadFile, Form
+ from fastapi.responses import FileResponse
+ from PIL import Image
+ import torch
+ from diffusers import (
+     DiffusionPipeline,
+     AutoencoderKL,
+     StableDiffusionControlNetPipeline,
+     ControlNetModel,
+     StableDiffusionLatentUpscalePipeline,
+     StableDiffusionImg2ImgPipeline,
+     StableDiffusionControlNetImg2ImgPipeline,
+     DPMSolverMultistepScheduler,
+     EulerDiscreteScheduler
+ )
+ import random
+ import time
+ import tempfile
+
+ app = FastAPI()
+
+ BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
+
+ # Initialize both pipelines; the img2img pipeline shares the main pipeline's components
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
+ controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16)
+ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
+     BASE_MODEL,
+     controlnet=controlnet,
+     vae=vae,
+     safety_checker=None,
+     torch_dtype=torch.float16,
+ ).to("cuda")
+ image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
+
+ # Sampler map
+ SAMPLER_MAP = {
+     "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"),
+     "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
+ }
+
+ def center_crop_resize(img, output_size=(512, 512)):
+     width, height = img.size
+
+     # Calculate dimensions to crop to the center
+     new_dimension = min(width, height)
+     left = (width - new_dimension) / 2
+     top = (height - new_dimension) / 2
+     right = (width + new_dimension) / 2
+     bottom = (height + new_dimension) / 2
+
+     # Crop and resize
+     img = img.crop((left, top, right, bottom))
+     img = img.resize(output_size)
+
+     return img
+
+ def common_upscale(samples, width, height, upscale_method, crop=False):
+     if crop == "center":
+         old_width = samples.shape[3]
+         old_height = samples.shape[2]
+         old_aspect = old_width / old_height
+         new_aspect = width / height
+         x = 0
+         y = 0
+         if old_aspect > new_aspect:
+             x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
+         elif old_aspect < new_aspect:
+             y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
+         s = samples[:, :, y:old_height - y, x:old_width - x]
+     else:
+         s = samples
+
+     return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
+
+ def upscale(samples, upscale_method, scale_by):
+     width = round(samples["images"].shape[3] * scale_by)
+     height = round(samples["images"].shape[2] * scale_by)
+     return common_upscale(samples["images"], width, height, upscale_method, "disabled")
+
+ # Base64 <-> PIL helpers (rely on gradio 3.x's processing_utils; unused by the API below)
+ from gradio import processing_utils
+
+ def convert_to_pil(base64_image):
+     pil_image = processing_utils.decode_base64_to_image(base64_image)
+     return pil_image
+
+ def convert_to_base64(pil_image):
+     base64_image = processing_utils.encode_pil_to_base64(pil_image)
+     return base64_image
+
+ # Inference function
+ def inference(
+     control_image: Image.Image,
+     prompt: str,
+     negative_prompt: str,
+     guidance_scale: float = 8.0,
+     controlnet_conditioning_scale: float = 1.0,
+     control_guidance_start: float = 0.0,
+     control_guidance_end: float = 1.0,
+     upscaler_strength: float = 0.5,
+     seed: int = -1,
+     sampler="DPM++ Karras SDE",
+ ):
+     start_time = time.time()
+     start_time_struct = time.localtime(start_time)
+     start_time_formatted = time.strftime("%H:%M:%S", start_time_struct)
+     print(f"Inference started at {start_time_formatted}")
+
+     # Prepare control images at base (512) and upscaled (1024) resolutions
+     control_image_small = center_crop_resize(control_image)
+     control_image_large = center_crop_resize(control_image, (1024, 1024))
+
+     main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
+     my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
+     generator = torch.Generator(device="cuda").manual_seed(my_seed)
+
+     # Stage 1: text-to-image ControlNet pass at 512x512, kept as latents
+     out = main_pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         image=control_image_small,
+         guidance_scale=float(guidance_scale),
+         controlnet_conditioning_scale=float(controlnet_conditioning_scale),
+         generator=generator,
+         control_guidance_start=float(control_guidance_start),
+         control_guidance_end=float(control_guidance_end),
+         num_inference_steps=15,
+         output_type="latent"
+     )
+     # Stage 2: upscale the latents 2x, then refine with the img2img ControlNet pass
+     upscaled_latents = upscale(out, "nearest-exact", 2)
+     out_image = image_pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         control_image=control_image_large,
+         image=upscaled_latents,
+         guidance_scale=float(guidance_scale),
+         generator=generator,
+         num_inference_steps=20,
+         strength=upscaler_strength,
+         control_guidance_start=float(control_guidance_start),
+         control_guidance_end=float(control_guidance_end),
+         controlnet_conditioning_scale=float(controlnet_conditioning_scale)
+     )
+     end_time = time.time()
+     end_time_struct = time.localtime(end_time)
+     end_time_formatted = time.strftime("%H:%M:%S", end_time_struct)
+     print(f"Inference ended at {end_time_formatted}, taking {end_time - start_time:.1f}s")
+
+     # Save image + metadata. Disabled: `user_history` (a Hugging Face Spaces
+     # helper module) and `profile` are not defined in this file.
+     # user_history.save_image(
+     #     label=prompt,
+     #     image=out_image["images"][0],
+     #     profile=profile,
+     #     metadata={
+     #         "prompt": prompt,
+     #         "negative_prompt": negative_prompt,
+     #         "guidance_scale": guidance_scale,
+     #         "controlnet_conditioning_scale": controlnet_conditioning_scale,
+     #         "control_guidance_start": control_guidance_start,
+     #         "control_guidance_end": control_guidance_end,
+     #         "upscaler_strength": upscaler_strength,
+     #         "seed": seed,
+     #         "sampler": sampler,
+     #     },
+     # )
+
+     return out_image["images"][0], my_seed
+
+ def generate_image_from_parameters(prompt, guidance_scale, controlnet_scale, controlnet_end, upscaler_strength, seed, sampler_type, image):
+     # Save the uploaded image to a temporary file
+     temp_image_path = f"/tmp/{int(time.time())}_{image.filename}"
+     with open(temp_image_path, "wb") as temp_image:
+         temp_image.write(image.file.read())
+
+     # Open the uploaded image using PIL
+     control_image = Image.open(temp_image_path)
+
+     # Run inference (it returns the generated image and the seed actually used)
+     generated_image, _ = inference(control_image, prompt, "", guidance_scale, controlnet_scale, 0, controlnet_end, upscaler_strength, seed, sampler_type)
+
+     # Create the output directory if it doesn't exist
+     output_directory = "/home/user/app/generated_files"
+     os.makedirs(output_directory, exist_ok=True)
+
+     # Save the generated image under a unique, timestamped filename
+     filename = f"generated_image_{int(time.time())}.png"
+     output_path = os.path.join(output_directory, filename)
+     generated_image.save(output_path, format="PNG")
+
+     # Return the generated image path
+     return output_path
+
+ @app.post("/generate_image")
+ async def generate_image(
+     prompt: str = Form(...),
+     guidance_scale: float = Form(...),
+     controlnet_scale: float = Form(...),
+     controlnet_end: float = Form(...),
+     upscaler_strength: float = Form(...),
+     seed: int = Form(...),
+     sampler_type: str = Form(...),
+     image: UploadFile = File(...)
+ ):
+     try:
+         # Delegate to the helper above and send the saved PNG back to the client
+         output_path = generate_image_from_parameters(
+             prompt, guidance_scale, controlnet_scale, controlnet_end,
+             upscaler_strength, seed, sampler_type, image
+         )
+         return FileResponse(output_path, media_type="image/png")
+     except Exception as e:
+         # Handle exceptions and return an error message if something goes wrong
+         return {"error": str(e)}
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     # Get internal IP address
+     internal_ip = socket.gethostbyname(socket.gethostname())
+
+     # Get public IP address using a public API (may not work behind a router/NAT)
+     try:
+         public_ip = requests.get("https://api.ipify.org", timeout=5).text
+     except requests.RequestException:
+         public_ip = "Not Available"
+
+     print(f"Internal URL: http://{internal_ip}:8000")
+     print(f"Public URL: http://{public_ip}:8000")
+
+     # reload=True requires the app as an import string rather than an object
+     uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
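
For reference, here is a minimal client-side sketch of exercising the `/generate_image` endpoint once the server is up. The base URL, prompt, and filenames are illustrative assumptions, not part of the commit:

```python
# Hypothetical client for the /generate_image endpoint defined above.
# Assumes the server is reachable at http://localhost:8000 and that
# input.png exists locally; both are assumptions for illustration.
import requests

with open("input.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/generate_image",
        data={
            "prompt": "a medieval castle, detailed, 8k",  # illustrative prompt
            "guidance_scale": 8.0,
            "controlnet_scale": 1.0,
            "controlnet_end": 1.0,
            "upscaler_strength": 0.5,
            "seed": -1,  # -1 asks the server to pick a random seed
            "sampler_type": "DPM++ Karras SDE",
        },
        files={"image": ("input.png", f, "image/png")},
        timeout=600,  # two diffusion passes can take a while
    )

resp.raise_for_status()
with open("output.png", "wb") as out:
    out.write(resp.content)  # the endpoint returns the PNG via FileResponse
```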
requirements.txt ADDED
File without changes
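
The added requirements.txt is empty, so the dependencies have to be inferred from app.py's imports. A plausible starting point, with no version pins verified against this commit, might look like:

```
# Hypothetical dependency list inferred from app.py's imports; not part of the commit.
fastapi
python-multipart   # needed by FastAPI for Form/File uploads
uvicorn
requests
Pillow
torch
diffusers
transformers       # pulled in by the Stable Diffusion pipelines
accelerate
gradio             # only for the processing_utils base64 helpers
```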