|
import os |
|
import io |
|
import asyncio |
|
import socket |
|
import requests |
|
import sys |
|
import logging |
|
from fastapi import FastAPI, File, UploadFile, Form, BackgroundTasks |
|
from fastapi.responses import FileResponse, StreamingResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from PIL import Image |
|
import torch |
|
from diffusers import ( |
|
DiffusionPipeline, |
|
AutoencoderKL, |
|
StableDiffusionControlNetPipeline, |
|
ControlNetModel, |
|
StableDiffusionLatentUpscalePipeline, |
|
StableDiffusionImg2ImgPipeline, |
|
StableDiffusionControlNetImg2ImgPipeline, |
|
DPMSolverMultistepScheduler, |
|
EulerDiscreteScheduler |
|
) |
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker |
|
from transformers import AutoFeatureExtractor, CLIPFeatureExtractor |
|
import random |
|
import time |
|
import tempfile |
|
import threading |
|
|
|
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Log to both inference.log and stdout.
file_handler = logging.FileHandler('inference.log')
stream_handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Apply the same formatter to both handlers so the log file also gets
# timestamps, logger names and levels.
file_handler.setFormatter(formatter)
stream_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stream_handler)
|
|
|
# Next job id handed out by /generate_image; read-and-increment it only while
# holding next_id_lock.
next_id, next_id_lock = 0, threading.Lock()
|
|
|
class ImageGenerationQueue:
    """FIFO of zero-argument async callables, processed one at a time.

    ``queue_size`` mirrors ``queue.qsize()`` and is only written while holding
    ``queue_lock``.
    """

    def __init__(self):
        self.queue = asyncio.Queue()
        self.queue_size = 0
        self.queue_lock = threading.Lock()
        # NOTE(review): not read anywhere in this class.
        self.next_id = 0

    def add_task(self, task, loop=None):
        """Schedule ``task`` to be put on the queue from any thread.

        Args:
            task: zero-argument callable returning an awaitable.
            loop: event loop that owns ``self.queue``. Defaults to
                ``asyncio.get_event_loop()``. Pass the serving loop explicitly
                when calling from a thread other than the loop's own.

        Returns:
            The ``concurrent.futures.Future`` returned by
            ``asyncio.run_coroutine_threadsafe``.
        """
        if loop is None:
            loop = asyncio.get_event_loop()
        return asyncio.run_coroutine_threadsafe(self._add_task(task), loop)

    async def _add_task(self, task):
        """Enqueue ``task`` and refresh the cached queue size."""
        await self.queue.put(task)
        with self.queue_lock:
            self.queue_size = self.queue.qsize()

    async def process_queue(self):
        """Run queued tasks forever, one at a time, in FIFO order.

        A task that raises is logged and skipped so the worker keeps running.
        ``task_done()`` is always called, so ``queue.join()`` cannot hang on a
        failed task.
        """
        while True:
            task = await self.queue.get()
            try:
                await task()
            except Exception:
                logging.getLogger(__name__).exception("Queued image generation task failed")
            finally:
                with self.queue_lock:
                    self.queue_size = self.queue.qsize()
                self.queue.task_done()

    async def get_total_queue_size(self):
        """Return the cached number of tasks waiting in the queue."""
        with self.queue_lock:
            return self.queue_size
|
|
|
# FastAPI application, plus the module-level job queue manager read by
# /generate_image and started by start_fastapi().
app = FastAPI()
queue_manager = ImageGenerationQueue()
|
|
|
# Base checkpoint. The "_noVAE" name suggests it ships without a usable VAE,
# which is why one is passed explicitly below (assumption; confirm on the hub).
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"

# All model weights are downloaded/loaded at import time.
# The VAE and the QR-code-monster ControlNet are loaded in fp16.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16)

# First-pass text-to-image ControlNet pipeline with a safety checker, moved to
# the GPU.
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
    feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
    torch_dtype=torch.float16,
).to("cuda")

# Img2img ControlNet pipeline built from the same component objects, so no
# weights are loaded twice. inference() uses it for the 2x latent refine pass.
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
|
|
|
|
|
# User-facing sampler name -> factory that builds a scheduler from an existing
# scheduler config.
SAMPLER_MAP = {
    # DPMSolverMultistepScheduler's Karras switch is `use_karras_sigmas`.
    # Unknown keys such as `use_karras` are ignored by from_config, so the name
    # must be exact for the "Karras" sampler to use Karras sigmas.
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"),
    # The config passed in may come from a previous DPM++ Karras scheduler.
    # Pin Euler to plain sigmas so it does not inherit `use_karras_sigmas=True`.
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config, use_karras_sigmas=False),
}
|
|
|
def center_crop_resize(img, output_size=(512, 512)):
    """Cut the largest centred square out of ``img`` and resize it.

    Args:
        img: PIL image (anything exposing ``size``, ``crop`` and ``resize``).
        output_size: ``(width, height)`` of the returned image.

    Returns:
        The cropped, resized image.
    """
    w, h = img.size
    side = min(w, h)
    # (left, top, right, bottom); the half-offsets may be fractional.
    box = ((w - side) / 2, (h - side) / 2, (w + side) / 2, (h + side) / 2)
    return img.crop(box).resize(output_size)
|
|
|
def common_upscale(samples, width, height, upscale_method, crop=False):
    """Resize a 4-D tensor to ``(height, width)``, optionally centre-cropping first.

    Args:
        samples: tensor whose dims 2 and 3 are height and width.
        width: target width.
        height: target height.
        upscale_method: ``mode`` forwarded to ``torch.nn.functional.interpolate``.
        crop: only the string ``"center"`` enables cropping. Any other value
            (including the default ``False``) resizes the whole tensor.

    When cropping, the axis that is too long for the target aspect ratio is
    trimmed symmetrically before resizing.
    """
    src = samples
    if crop == "center":
        h_old, w_old = samples.shape[2], samples.shape[3]
        ratio_old = w_old / h_old
        ratio_new = width / height
        dx = dy = 0
        if ratio_old > ratio_new:
            dx = round((w_old - w_old * (ratio_new / ratio_old)) / 2)
        elif ratio_old < ratio_new:
            dy = round((h_old - h_old * (ratio_old / ratio_new)) / 2)
        src = samples[:, :, dy:h_old - dy, dx:w_old - dx]
    return torch.nn.functional.interpolate(src, size=(height, width), mode=upscale_method)
|
|
|
def upscale(samples, upscale_method, scale_by):
    """Scale ``samples["images"]`` by ``scale_by`` along its height and width.

    In this file ``samples`` is the output of ``main_pipe(..., output_type="latent")``.
    The return value is the resized tensor itself, not a container.
    """
    latents = samples["images"]
    new_h = round(latents.shape[2] * scale_by)
    new_w = round(latents.shape[3] * scale_by)
    return common_upscale(latents, new_w, new_h, upscale_method, "disabled")
|
|
|
|
|
|
|
def convert_to_pil(base64_image):
    """Decode a base64 image string into a PIL image.

    Accepts either raw base64 or a data URL
    (``data:image/...;base64,<payload>``). Implemented with ``base64``/``io``
    and PIL only; this module imports no ``processing_utils`` helper.
    """
    import base64

    # Everything after the last comma is the payload; a raw string has no comma.
    payload = base64_image.rsplit(",", 1)[-1]
    return Image.open(io.BytesIO(base64.b64decode(payload)))


def convert_to_base64(pil_image):
    """Encode ``pil_image`` as a PNG data URL (``data:image/png;base64,...``)."""
    import base64

    with io.BytesIO() as buffer:
        pil_image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return "data:image/png;base64," + encoded
|
|
|
def inference(
    control_image: Image.Image,
    prompt: str,
    negative_prompt = "sexual content, racism, humans, faces",
    guidance_scale: float = 8.0,
    controlnet_conditioning_scale: float = 1,
    control_guidance_start: float = 1,
    control_guidance_end: float = 1,
    upscaler_strength: float = 0.5,
    seed: int = -1,
    sampler = "DPM++ Karras SDE",
):
    """Generate an image with two ControlNet passes.

    Pass 1: the control image is centre-cropped to 512x512 and ``main_pipe``
    runs 15 steps, returning latents. Pass 2: those latents are upscaled 2x
    ("nearest-exact") and refined by ``image_pipe`` for 20 img2img steps
    against a 1024x1024 crop of the control image.

    Args:
        control_image: ControlNet conditioning image.
        prompt: text prompt for both passes.
        negative_prompt: negative prompt for both passes.
        guidance_scale: classifier-free guidance scale for both passes.
        controlnet_conditioning_scale: ControlNet weight for both passes.
        control_guidance_start: ControlNet start fraction for both passes.
        control_guidance_end: ControlNet end fraction for both passes.
        upscaler_strength: img2img ``strength`` of the refine pass.
        seed: RNG seed; ``-1`` draws a random seed in ``[0, 2**32 - 1]``.
        sampler: key into ``SAMPLER_MAP``; applied to both pipelines.

    Returns:
        The first image of the refine pass on success, ``None`` if the safety
        checker flags it, or the exception message (``str``) on any error.
    """
    try:
        logger.debug("Input Types: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, upscaler_strength=%s, seed=%s, sampler=%s",
                     type(control_image), type(prompt), type(negative_prompt), type(guidance_scale), type(controlnet_conditioning_scale),
                     type(control_guidance_start), type(control_guidance_end), type(upscaler_strength), type(seed), type(sampler))
        logger.debug("Input Values: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, upscaler_strength=%s, seed=%s, sampler=%s",
                     control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale,
                     control_guidance_start, control_guidance_end, upscaler_strength, seed, sampler)

        start_time = time.time()
        start_time_formatted = time.strftime("%H:%M:%S", time.localtime(start_time))
        logger.info(f"Inference started at {start_time_formatted}")

        control_image_small = center_crop_resize(control_image)
        control_image_large = center_crop_resize(control_image, (1024, 1024))

        # image_pipe was constructed from main_pipe.components, so it holds its
        # own reference to the original scheduler. Reassigning only
        # main_pipe.scheduler would leave the refine pass on the default
        # sampler, so install the chosen sampler on both pipelines.
        main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
        image_pipe.scheduler = SAMPLER_MAP[sampler](image_pipe.scheduler.config)

        my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
        generator = torch.Generator(device="cuda").manual_seed(my_seed)

        # Pass 1: low-resolution generation, kept as latents for upscaling.
        out = main_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image_small,
            guidance_scale=float(guidance_scale),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
            generator=generator,
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            num_inference_steps=15,
            output_type="latent"
        )
        upscaled_latents = upscale(out, "nearest-exact", 2)

        # Pass 2: refine the upscaled latents at full resolution.
        out_image = image_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            control_image=control_image_large,
            image=upscaled_latents,
            guidance_scale=float(guidance_scale),
            generator=generator,
            num_inference_steps=20,
            strength=float(upscaler_strength),
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale)
        )

        end_time = time.time()
        end_time_formatted = time.strftime("%H:%M:%S", time.localtime(end_time))
        logger.info(f"Inference ended at {end_time_formatted}, taking {end_time-start_time}s")
        logger.debug("Content of out_image: %s", out_image)
        logger.debug("Structure of out_image: %s", dir(out_image))

        if out_image.nsfw_content_detected[0]:
            logger.warning("NSFW detected. Nice try.")
            return None
        return out_image["images"][0]

    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception("Error occurred during inference: %s", str(e))
        return str(e)
|
|
|
|
|
def generate_image_from_parameters(prompt, guidance_scale, controlnet_scale, controlnet_end, upscaler_strength, seed, sampler_type, image):
    """Synchronously run ``inference`` and return the result as PNG bytes.

    Args:
        prompt: forwarded to ``inference``.
        guidance_scale: forwarded to ``inference``.
        controlnet_scale: forwarded as ``controlnet_conditioning_scale``.
        controlnet_end: forwarded as ``control_guidance_end``.
        upscaler_strength: forwarded to ``inference``.
        seed: forwarded to ``inference``.
        sampler_type: forwarded as ``sampler``.
        image: upload object with ``filename`` and a readable ``file``. Its
            bytes are written under ``/tmp``.

    The negative prompt is ``""`` and the control guidance start is ``0``.

    Returns:
        PNG bytes on success, otherwise an error message string.
    """
    try:
        # basename() stops a client-supplied filename from adding path segments.
        safe_name = os.path.basename(image.filename or "upload")
        temp_image_path = f"/tmp/{int(time.time())}_{safe_name}"
        with open(temp_image_path, "wb") as temp_image:
            temp_image.write(image.file.read())

        # NOTE(review): the control image is this fixed file, not the upload
        # saved above (the /generate_image endpoint uses the upload instead).
        # Confirm which is intended.
        control_image_path = "scrollwhite.png"
        with Image.open(control_image_path) as control_image:
            # inference() returns one value: the image, None (NSFW) or an
            # error string. It is not a 4-tuple.
            generated_image = inference(control_image, prompt, "", guidance_scale, controlnet_scale, 0, controlnet_end, upscaler_strength, seed, sampler_type)

        if not isinstance(generated_image, Image.Image):
            return "Failed to generate image" if generated_image is None else str(generated_image)

        with io.BytesIO() as output_image_io:
            generated_image.save(output_image_io, format="PNG")
            return output_image_io.getvalue()

    except Exception as e:
        return str(e)
|
|
|
|
|
# Let browser clients on any origin call the API.
# NOTE(review): combining a wildcard origin with allow_credentials=True makes
# Starlette reflect the caller's Origin on credentialed requests. Confirm
# credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
|
|
|
@app.post("/get_image")
async def get_image(
    job_id: int = Form(...)
):
    """Return the PNG produced for ``job_id``, or ``None`` if it does not exist yet.

    Reads ``/tmp/{job_id}_output.png`` and streams its bytes back as
    ``image/png``.
    """
    image_path = f"/tmp/{job_id}_output.png"
    if not os.path.isfile(image_path):
        return None

    with open(image_path, "rb") as file:
        generated_image = file.read()

    # The file already contains encoded PNG bytes; stream them unchanged.
    # A bytes object has no .save(), so re-encoding is neither needed nor
    # possible.
    return StreamingResponse(content=io.BytesIO(generated_image), media_type="image/png")
|
|
|
|
|
@app.post("/generate_image")
async def generate_image(
    prompt: str = Form(...),
    guidance_scale: float = Form(...),
    controlnet_scale: float = Form(...),
    controlnet_end: float = Form(...),
    upscaler_strength: float = Form(...),
    seed: int = Form(...),
    sampler_type: str = Form(...),
    image: UploadFile = File(...),
    background_tasks: BackgroundTasks = BackgroundTasks()
):
    """Schedule an image-generation job and return its id immediately.

    The uploaded ``image`` is the ControlNet control image. The job runs as a
    background task after the response is sent. It writes its result to
    ``/tmp/{job_id}_output.png``, the path ``/get_image`` reads.

    Returns:
        ``{"job_id", "position_in_queue", "total_queue_size"}`` on success, or
        the string ``"Failed to add task to the queue"`` on failure.
    """
    # next_id is reassigned below, so it must be declared global here.
    # Otherwise Python treats it as a local and the first read raises
    # UnboundLocalError.
    global next_id

    async def generate_image_task(job_id, image_bytes):
        """Run inference for one job and save the PNG where /get_image looks."""
        try:
            control_image = Image.open(io.BytesIO(image_bytes))

            # NOTE(review): inference() is synchronous; running it here blocks
            # the event loop (including /get_image polling) until it returns.
            generated_image = inference(control_image, prompt, "", guidance_scale, controlnet_scale, 0, controlnet_end, upscaler_strength, seed, sampler_type)

            # inference() returns None on an NSFW hit and an error string on
            # failure; only a real image is saved.
            if not isinstance(generated_image, Image.Image):
                logger.error("Image generation for job %s produced no image: %s", job_id, generated_image)
                return "Failed to generate image"

            output_image_path = f"/tmp/{job_id}_output.png"
            # Write to a side file and rename, so /get_image never serves a
            # half-written PNG.
            partial_path = output_image_path + ".part"
            generated_image.save(partial_path, format="PNG")
            os.replace(partial_path, output_image_path)

        except Exception as e:
            logger.exception("Error occurred during image generation: %s", str(e))
            return "Failed to generate image"

    try:
        # Read the upload now: the request's file may already be closed by the
        # time the background task runs.
        image_bytes = await image.read()

        with next_id_lock:
            job_id = next_id
            next_id += 1

        # Hand Starlette the coroutine function plus its arguments so it awaits
        # the task. A lambda wrapper would run synchronously and drop the
        # coroutine unawaited.
        background_tasks.add_task(generate_image_task, job_id, image_bytes)

        position_in_queue = queue_manager.queue.qsize()
        total_queue_size = await queue_manager.get_total_queue_size()

        return {"job_id": job_id, "position_in_queue": position_in_queue, "total_queue_size": total_queue_size}

    except Exception as e:
        logger.exception("Error occurred during image generation: %s", str(e))
        return "Failed to add task to the queue"
|
|
|
async def start_fastapi():
    """Print the server URLs, start the queue worker, and serve ``app`` on port 7860."""
    # uvicorn is used below but is not imported at module level.
    import uvicorn

    internal_ip = socket.gethostbyname(socket.gethostname())

    try:
        # Bound the lookup: with no timeout, requests can block startup indefinitely.
        public_ip = requests.get("http://api.ipify.org", timeout=5).text
    except requests.RequestException:
        public_ip = "Not Available"

    print(f"Internal URL: http://{internal_ip}:7860")
    print(f"Public URL: http://{public_ip}:7860")

    # Keep a reference to the worker task while the server runs.
    queue_processing_task = asyncio.create_task(queue_manager.process_queue())

    # Serve this module's app object directly. An import string like "app:app"
    # can import the file a second time under another module name, which would
    # reload every model and use a different queue_manager. `reload` is only
    # implemented by uvicorn.run(), not Server.serve(), so it is omitted
    # (NOTE(review): confirm against the uvicorn version in use).
    config = uvicorn.Config(app=app, host="0.0.0.0", port=7860)
    server = uvicorn.Server(config)
    try:
        await server.serve()
    finally:
        queue_processing_task.cancel()
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: rebind the module-level queue manager, then run the
    # server inside a fresh event loop.
    # NOTE(review): this replaces the instance already created at import time.
    # It is still constructed before asyncio.run() creates the serving loop;
    # confirm the rebinding is needed.
    queue_manager = ImageGenerationQueue()
    asyncio.run(start_fastapi())
|
|