# Source: Hugging Face file view of app.py (raw / history / blame links omitted).
# Last commit: "Update app.py" (075b52e) by pseudotheos; file size 14.8 kB.
import os
import io
import asyncio
import socket
import requests
import sys
import logging
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
from diffusers import (
DiffusionPipeline,
AutoencoderKL,
StableDiffusionControlNetPipeline,
ControlNetModel,
StableDiffusionLatentUpscalePipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionControlNetImg2ImgPipeline,
DPMSolverMultistepScheduler,
EulerDiscreteScheduler
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor, CLIPFeatureExtractor
import random
import time
import tempfile
import threading
# Module-level logger: records at DEBUG and above are written both to
# ``inference.log`` and to stdout.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler('inference.log')
stream_handler = logging.StreamHandler(sys.stdout)

# Both sinks share one format so that lines in the log file carry the
# timestamp, logger name and level just like the console output does.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
stream_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stream_handler)

# Counter used to hand out job ids; always read and incremented while
# holding ``next_id_lock``.
next_id = 0
next_id_lock = threading.Lock()
class ImageGenerationQueue:
    """FIFO of zero-argument coroutine functions executed one at a time.

    ``process_queue`` is the single consumer; ``add_task`` / ``_add_task``
    are the producers.  ``queue_size`` mirrors ``queue.qsize()`` and is only
    read or written while holding ``queue_lock``.
    """

    def __init__(self):
        self.queue = asyncio.Queue()
        self.queue_size = 0
        self.queue_lock = threading.Lock()
        self.next_id = 0
        # Loop that runs ``process_queue``; remembered so ``add_task`` can
        # schedule onto it when called from a different thread.
        self._loop = None

    def add_task(self, task):
        """Schedule *task* (a zero-argument coroutine function) for execution.

        Uses the loop recorded by ``process_queue`` when there is one, so the
        call works from threads that have no event loop of their own; falls
        back to the calling thread's loop otherwise.
        """
        loop = self._loop if self._loop is not None else asyncio.get_event_loop()
        asyncio.run_coroutine_threadsafe(self._add_task(task), loop=loop)

    async def _add_task(self, task):
        """Enqueue *task* and refresh the cached queue size."""
        await self.queue.put(task)
        with self.queue_lock:
            self.queue_size = self.queue.qsize()

    async def process_queue(self):
        """Run queued tasks forever, one after another.

        A task that raises is logged and skipped; the worker keeps running
        and ``task_done()`` is always called so ``queue.join()`` cannot hang
        on a failed task.
        """
        self._loop = asyncio.get_running_loop()
        while True:
            task = await self.queue.get()
            try:
                await task()
            except Exception:
                logging.getLogger(__name__).exception("Queued task failed")
            finally:
                with self.queue_lock:
                    self.queue_size = self.queue.qsize()
                self.queue.task_done()

    async def get_total_queue_size(self):
        """Return the cached number of tasks still waiting in the queue."""
        with self.queue_lock:
            return self.queue_size
# Web application and the (module-level) generation queue.
app = FastAPI()
queue_manager = ImageGenerationQueue()

# Model loading happens at import time and places the pipeline on "cuda".
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"

vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",
    torch_dtype=torch.float16,
)
controlnet = ControlNetModel.from_pretrained(
    "monster-labs/control_v1p_sd15_qrcode_monster",
    torch_dtype=torch.float16,
)

# Components handed to the text-to-image ControlNet pipeline.
_main_pipe_kwargs = dict(
    controlnet=controlnet,
    vae=vae,
    safety_checker=StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ),
    feature_extractor=CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    ),
    torch_dtype=torch.float16,
)
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL, **_main_pipe_kwargs
).to("cuda")

# The img2img pipeline is built from the very same component objects as
# ``main_pipe`` rather than loading the models a second time.
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
# Sampler map: sampler name accepted by ``inference`` -> factory that builds
# the scheduler from an existing scheduler config.
SAMPLER_MAP = {
    # ``use_karras_sigmas`` is the DPMSolverMultistepScheduler option that
    # switches on the Karras noise schedule; an unrecognised keyword such as
    # ``use_karras`` is not applied by ``from_config``.
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}
def center_crop_resize(img, output_size=(512, 512)):
    """Crop *img* to its largest centred square and resize it.

    Args:
        img: image object exposing ``size``, ``crop`` and ``resize``.
        output_size: ``(width, height)`` passed to ``resize``.

    Returns:
        The cropped and resized image.
    """
    width, height = img.size
    side = min(width, height)

    # (left, top, right, bottom), symmetric about the image centre.
    box = (
        (width - side) / 2,
        (height - side) / 2,
        (width + side) / 2,
        (height + side) / 2,
    )
    return img.crop(box).resize(output_size)
def common_upscale(samples, width, height, upscale_method, crop=False):
    """Resize a 4-D tensor to ``(height, width)`` on its last two axes.

    Args:
        samples: tensor indexed as ``[:, :, row, column]``.
        width: target size of the last axis.
        height: target size of the second-to-last axis.
        upscale_method: ``mode`` passed to ``torch.nn.functional.interpolate``.
        crop: when equal to ``"center"`` the input is first trimmed
            symmetrically to the target aspect ratio; any other value
            leaves the input untouched.

    Returns:
        The interpolated tensor.
    """
    if crop != "center":
        return torch.nn.functional.interpolate(
            samples, size=(height, width), mode=upscale_method
        )

    src_height, src_width = samples.shape[2], samples.shape[3]
    src_ratio = src_width / src_height
    dst_ratio = width / height

    # Margin removed from each side along whichever axis is too long.
    trim_x = trim_y = 0
    if src_ratio > dst_ratio:
        trim_x = round((src_width - src_width * (dst_ratio / src_ratio)) / 2)
    elif src_ratio < dst_ratio:
        trim_y = round((src_height - src_height * (src_ratio / dst_ratio)) / 2)

    trimmed = samples[:, :, trim_y:src_height - trim_y, trim_x:src_width - trim_x]
    return torch.nn.functional.interpolate(
        trimmed, size=(height, width), mode=upscale_method
    )
def upscale(samples, upscale_method, scale_by):
    """Scale ``samples["images"]`` by *scale_by* along its last two axes.

    Args:
        samples: mapping-like object whose ``"images"`` entry is a 4-D tensor.
        upscale_method: interpolation mode forwarded to ``common_upscale``.
        scale_by: multiplier applied to both spatial dimensions (the target
            sizes are rounded to integers).

    Returns:
        The resized tensor.
    """
    images = samples["images"]
    target_height = round(images.shape[2] * scale_by)
    target_width = round(images.shape[3] * scale_by)
    # "disabled" is not "center", so no cropping takes place.
    return common_upscale(images, target_width, target_height, upscale_method, "disabled")
#
def convert_to_pil(base64_image):
    """Decode a base64-encoded image string into a PIL image.

    Args:
        base64_image: base64 text, either bare or as a data URL such as
            ``data:image/png;base64,<payload>``.
            NOTE(review): the data-URL form is assumed from the Gradio-style
            helper name this function used to call -- confirm with callers.

    Returns:
        The image opened with ``PIL.Image.open``.
    """
    # Local import: ``base64`` is not among the module's top-level imports.
    import base64

    # Everything after the last comma is the payload; a bare base64 string
    # contains no comma and is used whole.
    payload = base64_image.rpartition(",")[2]
    pil_image = Image.open(io.BytesIO(base64.b64decode(payload)))
    return pil_image
def convert_to_base64(pil_image):
    """Encode an image as a PNG data URL.

    Args:
        pil_image: object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        ``"data:image/png;base64,"`` followed by the base64 text of the PNG.
        NOTE(review): the data-URL prefix is assumed from the Gradio-style
        helper name this function used to call -- confirm with callers.
    """
    # Local import: ``base64`` is not among the module's top-level imports.
    import base64

    png_buffer = io.BytesIO()
    pil_image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("ascii")
    base64_image = "data:image/png;base64," + encoded
    return base64_image
def inference(
    control_image: Image.Image,
    prompt: str,
    negative_prompt = "sexual content, racism, humans, faces",
    guidance_scale: float = 8.0,
    controlnet_conditioning_scale: float = 1,
    control_guidance_start: float = 1,
    control_guidance_end: float = 1,
    upscaler_strength: float = 0.5,
    seed: int = -1,
    sampler = "DPM++ Karras SDE",
):
    """Generate an image guided by *control_image* in two passes.

    Pass 1 runs ``main_pipe`` for 15 steps on a 512x512 centre crop of the
    control image and returns latents.  The latents are enlarged 2x with
    "nearest-exact" interpolation and refined by ``image_pipe`` for 20 steps
    against a 1024x1024 centre crop of the control image.

    Args:
        control_image: image used as the ControlNet conditioning input.
        prompt: positive text prompt.
        negative_prompt: negative text prompt.
        guidance_scale: guidance scale forwarded to both passes.
        controlnet_conditioning_scale: ControlNet weight for both passes.
        control_guidance_start: forwarded to both passes.
        control_guidance_end: forwarded to both passes.
        upscaler_strength: ``strength`` of the second (img2img) pass.
        seed: RNG seed; ``-1`` picks a random 32-bit seed.
        sampler: key of ``SAMPLER_MAP`` selecting the first pass's scheduler.

    Returns:
        The first generated image on success, ``None`` when the safety
        checker flags the result, or the error message as a ``str`` when an
        exception occurs (the exception itself is logged, not raised).
    """
    try:
        # Log input types and values
        logger.debug("Input Types: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, upscaler_strength=%s, seed=%s, sampler=%s",
                     type(control_image), type(prompt), type(negative_prompt), type(guidance_scale), type(controlnet_conditioning_scale),
                     type(control_guidance_start), type(control_guidance_end), type(upscaler_strength), type(seed), type(sampler))
        logger.debug("Input Values: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, upscaler_strength=%s, seed=%s, sampler=%s",
                     control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale,
                     control_guidance_start, control_guidance_end, upscaler_strength, seed, sampler)

        # Fail with a readable message instead of a bare KeyError.
        if sampler not in SAMPLER_MAP:
            raise ValueError(f"Unknown sampler {sampler!r}; expected one of {sorted(SAMPLER_MAP)}")

        start_time = time.time()
        start_time_formatted = time.strftime("%H:%M:%S", time.localtime(start_time))
        logger.info(f"Inference started at {start_time_formatted}")

        control_image_small = center_crop_resize(control_image)
        control_image_large = center_crop_resize(control_image, (1024, 1024))

        main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
        my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
        generator = torch.Generator(device="cuda").manual_seed(my_seed)

        # Pass 1: low-resolution generation, kept as latents.
        out = main_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image_small,
            guidance_scale=float(guidance_scale),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
            generator=generator,
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            num_inference_steps=15,
            output_type="latent"
        )
        upscaled_latents = upscale(out, "nearest-exact", 2)

        # Pass 2: img2img refinement of the enlarged latents.
        out_image = image_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            control_image=control_image_large,
            image=upscaled_latents,
            guidance_scale=float(guidance_scale),
            generator=generator,
            num_inference_steps=20,
            strength=float(upscaler_strength),
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale)
        )

        end_time = time.time()
        end_time_formatted = time.strftime("%H:%M:%S", time.localtime(end_time))
        logger.info("Inference ended at %s, taking %ss", end_time_formatted, end_time - start_time)
        logger.debug("Output Types: generated_image=%s", type(out_image))
        logger.debug("Content of out_image: %s", out_image)

        # The flag list may be missing when no safety checker ran.
        nsfw_flags = out_image.nsfw_content_detected
        if nsfw_flags and nsfw_flags[0]:
            logger.warning("NSFW detected. Nice try.")
            return None
        return out_image["images"][0]
    except Exception as e:
        # Log with traceback; callers receive the message text.
        logger.exception("Error occurred during inference: %s", str(e))
        return str(e)
def generate_image_from_parameters(prompt, guidance_scale, controlnet_scale, controlnet_end, upscaler_strength, seed, sampler_type, image):
    """Run ``inference`` and return the result as PNG bytes.

    Args:
        prompt: positive text prompt.
        guidance_scale: forwarded to ``inference``.
        controlnet_scale: forwarded as ``controlnet_conditioning_scale``.
        controlnet_end: forwarded as ``control_guidance_end`` (start is 0).
        upscaler_strength: forwarded to ``inference``.
        seed: forwarded to ``inference``.
        sampler_type: forwarded as ``sampler``.
        image: upload object exposing ``filename`` and ``file``.

    Returns:
        PNG-encoded ``bytes`` on success, otherwise an error message ``str``.
    """
    try:
        # Save the uploaded image to a temporary file.  basename() drops any
        # directory components a client placed in the upload's filename.
        safe_name = os.path.basename(image.filename or "upload")
        temp_image_path = f"/tmp/{int(time.time())}_{safe_name}"
        with open(temp_image_path, "wb") as temp_image:
            temp_image.write(image.file.read())

        # NOTE(review): the control image is a fixed file, not the upload
        # saved above -- confirm this is intended.
        control_image_path = "scrollwhite.png"
        control_image = Image.open(control_image_path)

        # ``inference`` returns a single value: an image, None, or an error string.
        generated_image = inference(control_image, prompt, "", guidance_scale, controlnet_scale, 0, controlnet_end, upscaler_strength, seed, sampler_type)
        if isinstance(generated_image, str):
            return generated_image
        if not isinstance(generated_image, Image.Image):
            return "Failed to generate image"

        # Encode the generated image as PNG bytes.
        output_image_io = io.BytesIO()
        generated_image.save(output_image_io, format="PNG")
        return output_image_io.getvalue()
    except Exception as e:
        # Handle exceptions and return an error message if something goes wrong
        return str(e)
# Cross-origin settings: every origin, method and header is accepted and
# credentials are allowed.  Replace the ["*"] origins entry with specific
# origins if needed.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.post("/get_image")
async def get_image(
    job_id: int = Form(...)
):
    """Return the PNG produced for *job_id*, if it exists.

    Looks for ``/tmp/{job_id}_output.png``.  Returns ``None`` when there is
    no such file, otherwise streams the file's bytes as ``image/png``.
    """
    image_path = f"/tmp/{job_id}_output.png"
    try:
        with open(image_path, "rb") as file:
            image_bytes = file.read()
    except (FileNotFoundError, IsADirectoryError):
        return None

    # The file already holds PNG data, so its bytes are streamed as they are.
    return StreamingResponse(content=io.BytesIO(image_bytes), media_type="image/png")
@app.post("/generate_image")
async def generate_image(
    prompt: str = Form(...),
    guidance_scale: float = Form(...),
    controlnet_scale: float = Form(...),
    controlnet_end: float = Form(...),
    upscaler_strength: float = Form(...),
    seed: int = Form(...),
    sampler_type: str = Form(...),
    image: UploadFile = File(...),
    background_tasks: BackgroundTasks = BackgroundTasks()
):
    """Register an image-generation job and return its id.

    The uploaded image is read immediately; generation itself runs as a
    background task that writes ``/tmp/{job_id}_output.png`` on success.

    Returns:
        A dict with ``job_id``, ``position_in_queue`` and
        ``total_queue_size``, or the string
        ``"Failed to add task to the queue"`` when registration fails.
    """
    # The counter is rebound below, so it must be declared global here.
    global next_id

    async def generate_image_task(job_id, image_bytes):
        # NOTE(review): ``inference`` is a synchronous call inside a
        # coroutine, so it presumably occupies the event loop while it
        # runs -- confirm this is acceptable.
        try:
            # Decode the upload from memory; the client-supplied filename
            # is never used to build a filesystem path.
            control_image = Image.open(io.BytesIO(image_bytes))
            control_image.load()

            generated_image = inference(control_image, prompt, "", guidance_scale, controlnet_scale, 0, controlnet_end, upscaler_strength, seed, sampler_type)
            # ``inference`` yields None or an error string when it fails.
            if not isinstance(generated_image, Image.Image):
                logger.error("Image generation failed for job %s: %s", job_id, generated_image)
                return "Failed to generate image"

            # Write to a side file and rename it into place so a reader
            # never sees a partially written PNG.
            output_image_path = f"/tmp/{job_id}_output.png"
            partial_path = f"{output_image_path}.part"
            generated_image.save(partial_path, format="PNG")
            os.replace(partial_path, output_image_path)
        except Exception as e:
            logger.error("Error occurred during image generation: %s", str(e))
            return "Failed to generate image"

    try:
        # Read the upload while the request is still being handled.
        image_bytes = await image.read()

        with next_id_lock:
            job_id = next_id
            next_id += 1

        # Pass the coroutine function and its arguments directly so the
        # framework calls and awaits it.
        background_tasks.add_task(generate_image_task, job_id, image_bytes)

        position_in_queue = queue_manager.queue.qsize()
        total_queue_size = await queue_manager.get_total_queue_size()
        return {"job_id": job_id, "position_in_queue": position_in_queue, "total_queue_size": total_queue_size}
    except Exception as e:
        logger.error("Error occurred during image generation: %s", str(e))
        return "Failed to add task to the queue"
async def start_fastapi():
    """Print the service URLs, start the queue worker and serve the app.

    Serves ``app`` with uvicorn on ``0.0.0.0:7860`` and returns when the
    server stops; the queue worker task is cancelled at that point.
    """
    # Imported here because the module's top-level imports do not include it.
    import uvicorn

    # Get internal IP address
    try:
        internal_ip = socket.gethostbyname(socket.gethostname())
    except OSError:
        internal_ip = "Not Available"

    # Get public IP address using a public API (this may not work if you are
    # behind a router/NAT).  The timeout keeps startup from hanging on it.
    try:
        public_ip = requests.get("http://api.ipify.org", timeout=5).text
    except requests.RequestException:
        public_ip = "Not Available"

    print(f"Internal URL: http://{internal_ip}:7860")
    print(f"Public URL: http://{public_ip}:7860")

    # Start processing the image generation queue on this event loop.
    queue_processing_task = asyncio.create_task(queue_manager.process_queue())

    # Serve the already-constructed ``app`` object.  An import string would
    # make uvicorn import this module a second time (repeating the model
    # loading), and ``reload`` is not acted on when ``Server.serve()`` is
    # awaited directly.
    config = uvicorn.Config(app=app, host="0.0.0.0", port=7860)
    server = uvicorn.Server(config)
    try:
        await server.serve()
    finally:
        queue_processing_task.cancel()


# Script entry point.
if __name__ == "__main__":
    # Rebind the module-level queue_manager to a fresh instance before the
    # event loop is started.
    queue_manager = ImageGenerationQueue()
    asyncio.run(start_fastapi())