import asyncio
import io
import logging
import os
import random
import socket
import sys
import tempfile
import time

import requests
import torch
from diffusers import (
    DiffusionPipeline,
    AutoencoderKL,
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    StableDiffusionLatentUpscalePipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from PIL import Image
from transformers import AutoFeatureExtractor

logger = logging.getLogger(__name__)
# Accept every level from DEBUG upward; the handlers below decide where records go.
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler("inference.log")
stream_handler = logging.StreamHandler(sys.stdout)

formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
# Use the same format for the log file and the console (previously only the console was formatted).
file_handler.setFormatter(formatter)
stream_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stream_handler)

app = FastAPI()

BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"

# Shared components for the pipelines, loaded in float16 to match the pipeline dtype.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(
    "monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16
)

safety_model_id = "CompVis/stable-diffusion-safety-checker"

# The pipeline's component slots are named `feature_extractor` and `safety_checker`. Other
# keyword names (e.g. `safety_feature_extractor`) are not components and are ignored.
# Components passed in pre-built appear to be used as-is rather than cast by `torch_dtype`.
# So the safety checker is loaded in float16 explicitly; the pipeline feeds it
# half-precision inputs.
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    feature_extractor=AutoFeatureExtractor.from_pretrained(safety_model_id),
    safety_checker=StableDiffusionSafetyChecker.from_pretrained(
        safety_model_id, torch_dtype=torch.float16
    ),
    torch_dtype=torch.float16,
).to("cuda")
# Second-stage (refine/upscale) pipeline that shares every loaded component with main_pipe.
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)

# Sampler name (as sent by clients) -> factory building a scheduler from an existing config.
# `use_karras_sigmas` is the DPMSolverMultistepScheduler option. The previous `use_karras`
# kwarg is not a scheduler option, so Karras sigmas were never actually enabled.
SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}


def center_crop_resize(img, output_size=(512, 512)):
    """Crop the largest centered square out of a PIL image and resize it.

    Args:
        img: source PIL image.
        output_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A new PIL image of size ``output_size``.
    """
    width, height = img.size
    side = min(width, height)

    # Crop box is centered; PIL accepts the fractional coordinates produced by the halving.
    left = (width - side) / 2
    top = (height - side) / 2
    right = (width + side) / 2
    bottom = (height + side) / 2

    img = img.crop((left, top, right, bottom))
    return img.resize(output_size)


def common_upscale(samples, width, height, upscale_method, crop=False):
    """Resize a 4-D tensor to ``(height, width)`` with ``torch.nn.functional.interpolate``.

    Args:
        samples: tensor indexed as ``[..., ..., H, W]`` (dims 2 and 3 are spatial).
        width, height: target spatial size.
        upscale_method: interpolation ``mode`` (e.g. ``"nearest-exact"``).
        crop: when exactly ``"center"``, first trim the input to the target aspect ratio;
            any other value resizes without cropping.
    """
    if crop == "center":
        old_width = samples.shape[3]
        old_height = samples.shape[2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = 0
        y = 0
        if old_aspect > new_aspect:
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples[:, :, y:old_height - y, x:old_width - x]
    else:
        s = samples
    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)


def upscale(samples, upscale_method, scale_by):
    """Scale the ``"images"`` tensor of a pipeline output by ``scale_by`` (no cropping).

    Args:
        samples: mapping-like pipeline output whose ``"images"`` entry is a 4-D tensor.
        upscale_method: interpolation mode forwarded to ``common_upscale``.
        scale_by: multiplicative factor applied to both spatial dims (rounded).

    Returns:
        The resized tensor.
    """
    images = samples["images"]
    width = round(images.shape[3] * scale_by)
    height = round(images.shape[2] * scale_by)
    return common_upscale(images, width, height, upscale_method, "disabled")


def convert_to_pil(base64_image):
    """Decode a base64 string (optionally a ``data:image/...;base64,`` URL) into a PIL image.

    The previous body called an undefined ``processing_utils`` and always raised NameError.
    """
    import base64

    # Strip a data-URL header if present; only the payload after the comma is base64.
    if base64_image.lstrip().startswith("data:") and "," in base64_image:
        base64_image = base64_image.split(",", 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(base64_image)))


def convert_to_base64(pil_image):
    """Encode a PIL image as a PNG ``data:image/png;base64,...`` URL string."""
    import base64

    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode("ascii")


def inference(
    control_image: Image.Image,
    prompt: str,
    negative_prompt="sexual content, racism, humans, faces",
    guidance_scale: float = 8.0,
    controlnet_conditioning_scale: float = 1,
    control_guidance_start: float = 1,
    control_guidance_end: float = 1,
    upscaler_strength: float = 0.5,
    seed: int = -1,
    sampler="DPM++ Karras SDE",
):
    """Generate an image in two passes: a 512px ControlNet pass, then an img2img refine pass.

    Pass 1 runs ``main_pipe`` on a 512x512 center crop of ``control_image`` and keeps
    latents. Those latents are upscaled 2x and refined by ``image_pipe``, which is
    conditioned on a 1024x1024 crop of the same control image.

    Args:
        control_image: ControlNet conditioning image.
        prompt / negative_prompt: text prompts for both passes.
        guidance_scale, controlnet_conditioning_scale, control_guidance_start,
        control_guidance_end: forwarded (as floats) to both passes.
        upscaler_strength: img2img ``strength`` for the refine pass.
        seed: RNG seed; ``-1`` picks a random 32-bit seed.
        sampler: key of ``SAMPLER_MAP`` used for the first pass.

    Returns:
        ``out_image["images"][0]`` from the refine pass (presumably a PIL image, since no
        ``output_type`` is requested there). On any exception, the error message string;
        callers must check the type.
    """
    try:
        logger.debug(
            "Input Types: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, upscaler_strength=%s, seed=%s, sampler=%s",
            type(control_image), type(prompt), type(negative_prompt), type(guidance_scale),
            type(controlnet_conditioning_scale), type(control_guidance_start),
            type(control_guidance_end), type(upscaler_strength), type(seed), type(sampler),
        )
        logger.debug(
            "Input Values: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, upscaler_strength=%s, seed=%s, sampler=%s",
            control_image, prompt, negative_prompt, guidance_scale,
            controlnet_conditioning_scale, control_guidance_start,
            control_guidance_end, upscaler_strength, seed, sampler,
        )

        start_time = time.time()
        logger.info("Inference started at %s", time.strftime("%H:%M:%S", time.localtime(start_time)))

        control_image_small = center_crop_resize(control_image)
        control_image_large = center_crop_resize(control_image, (1024, 1024))

        # NOTE(review): only main_pipe's scheduler is replaced. image_pipe keeps the scheduler
        # object it was built with, so `sampler` affects the first pass only — confirm intended.
        main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)

        my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
        logger.info("Using seed %s", my_seed)
        generator = torch.Generator(device="cuda").manual_seed(my_seed)

        out = main_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image_small,
            guidance_scale=float(guidance_scale),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
            generator=generator,
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            num_inference_steps=15,
            output_type="latent",
        )

        upscaled_latents = upscale(out, "nearest-exact", 2)

        out_image = image_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            control_image=control_image_large,
            image=upscaled_latents,
            guidance_scale=float(guidance_scale),
            generator=generator,
            num_inference_steps=20,
            strength=upscaler_strength,
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        )

        end_time = time.time()
        logger.info(
            "Inference ended at %s, taking %ss",
            time.strftime("%H:%M:%S", time.localtime(end_time)),
            end_time - start_time,
        )

        generated_image = out_image["images"][0]
        logger.debug("Output Types: generated_image=%s", type(generated_image))
        logger.debug("Content of out_image: %s", out_image)
        return generated_image
    except Exception as e:
        # Keep the string-return contract, but record the full traceback for debugging.
        logger.exception("Error occurred during inference: %s", str(e))
        return str(e)


def generate_image_from_parameters(prompt, guidance_scale, controlnet_scale, controlnet_end, upscaler_strength, seed, sampler_type, image):
    """Run ``inference`` with the fixed ``scrollwhite.png`` control image and return PNG bytes.

    Args:
        image: uploaded file object exposing ``.filename`` and ``.file``. It is saved to
            /tmp but NOT used for generation (see note below).

    Returns:
        PNG-encoded bytes on success, otherwise an error message string.
    """
    try:
        # NOTE(review): the upload is persisted but never read back. The control image is the
        # fixed file below. Confirm whether the upload was meant to be the control image.
        # basename() keeps a client-supplied filename from escaping /tmp.
        safe_name = os.path.basename(image.filename or "upload")
        temp_image_path = f"/tmp/{int(time.time())}_{safe_name}"
        with open(temp_image_path, "wb") as temp_image:
            temp_image.write(image.file.read())

        control_image_path = "scrollwhite.png"
        with Image.open(control_image_path) as control_image:
            # inference() returns a single value. The old 4-way tuple unpack always raised.
            generated_image = inference(
                control_image, prompt, "", guidance_scale, controlnet_scale, 0,
                controlnet_end, upscaler_strength, seed, sampler_type,
            )

        if isinstance(generated_image, str):
            # inference() reports failures by returning the error message.
            return generated_image

        output_image_io = io.BytesIO()
        generated_image.save(output_image_io, format="PNG")
        logger.debug("Output Values: generated_image=")
        return output_image_io.getvalue()
    except Exception as e:
        logger.exception("Error occurred during image generation: %s", str(e))
        return str(e)


# NOTE(review): "*" origins together with allow_credentials=True lets any site make
# credentialed requests. Restrict origins or disable credentials if cookies/auth are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/generate_image")
async def generate_image(
    prompt: str = Form(...),
    guidance_scale: float = Form(...),
    controlnet_scale: float = Form(...),
    controlnet_end: float = Form(...),
    upscaler_strength: float = Form(...),
    seed: int = Form(...),
    sampler_type: str = Form(...),
    image: UploadFile = File(...),
):
    """Generate an image conditioned on the uploaded control image and stream it as PNG.

    Responds with HTTP 500 and the JSON string "Failed to generate image" on any failure.
    Previously failures were reported with HTTP 200.
    """
    from fastapi.responses import JSONResponse

    try:
        # Decode the upload in memory. Writing it to /tmp under the client-supplied filename
        # was unsafe and leaked one file per request.
        upload_bytes = await image.read()
        control_image = Image.open(io.BytesIO(upload_bytes))

        # inference() is synchronous GPU work. Calling it directly blocks the event loop,
        # which also serializes access to the shared main_pipe/image_pipe.
        generated_image = inference(
            control_image, prompt, "", guidance_scale, controlnet_scale, 0,
            controlnet_end, upscaler_strength, seed, sampler_type,
        )
        if not isinstance(generated_image, Image.Image):
            # inference() returns an error message string (not None) when it fails.
            logger.error("Inference failed: %s", generated_image)
            return JSONResponse(status_code=500, content="Failed to generate image")

        output_image_io = io.BytesIO()
        generated_image.save(output_image_io, format="PNG")
        output_image_io.seek(0)
        return StreamingResponse(content=output_image_io, media_type="image/png")
    except Exception as e:
        logger.exception("Error occurred during image generation: %s", str(e))
        return JSONResponse(status_code=500, content="Failed to generate image")


async def start_fastapi():
    """Print internal/public URLs, then serve ``app`` with uvicorn on 0.0.0.0:7860."""
    # uvicorn was referenced but never imported, so this function raised NameError.
    import uvicorn

    try:
        internal_ip = socket.gethostbyname(socket.gethostname())
    except OSError:
        internal_ip = "Not Available"

    # The public IP comes from an external service and is informational only. Bound the
    # wait so a slow or blocked network cannot hang startup.
    try:
        public_ip = requests.get("http://api.ipify.org", timeout=5).text
    except requests.RequestException:
        public_ip = "Not Available"

    print(f"Internal URL: http://{internal_ip}:7860")
    print(f"Public URL: http://{public_ip}:7860")

    # Serve the already-built app object with uvicorn. Passing the "app:app" import string
    # made uvicorn import the module again and load every model a second time.
    # `reload` is omitted: Server.serve() does not implement reloading.
    config = uvicorn.Config(app=app, host="0.0.0.0", port=7860)
    server = uvicorn.Server(config)
    await server.serve()


if __name__ == "__main__":
    asyncio.run(start_fastapi())