import asyncio
import base64
import io
import logging
import os
import random
import socket
import sys
import tempfile
import time

import requests
import torch
import uvicorn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler("inference.log")
stream_handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

file_handler.setFormatter(formatter)
stream_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stream_handler)

app = FastAPI()

BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(
    "monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16
)
safety_model_id = "CompVis/stable-diffusion-safety-checker"
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    feature_extractor=AutoFeatureExtractor.from_pretrained(safety_model_id),
    safety_checker=StableDiffusionSafetyChecker.from_pretrained(safety_model_id, torch_dtype=torch.float16),
    torch_dtype=torch.float16,
).to("cuda")

# Reuse the already-loaded components for the img2img refinement pass.
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)

SAMPLER_MAP = {
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}
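
# The active scheduler is swapped per request from this map, e.g.:
#   main_pipe.scheduler = SAMPLER_MAP["Euler"](main_pipe.scheduler.config)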


def center_crop_resize(img, output_size=(512, 512)):
    # Crop the largest centered square, then resize to the target size.
    width, height = img.size
    new_dimension = min(width, height)
    left = (width - new_dimension) // 2
    top = (height - new_dimension) // 2
    right = left + new_dimension
    bottom = top + new_dimension

    img = img.crop((left, top, right, bottom))
    img = img.resize(output_size)
    return img


def common_upscale(samples, width, height, upscale_method, crop="disabled"):
    # Resize a latent batch (N, C, H, W); optionally center-crop to the
    # target aspect ratio first when crop == "center".
    if crop == "center":
        old_width = samples.shape[3]
        old_height = samples.shape[2]
        old_aspect = old_width / old_height
        new_aspect = width / height
        x = 0
        y = 0
        if old_aspect > new_aspect:
            x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
        s = samples[:, :, y:old_height - y, x:old_width - x]
    else:
        s = samples

    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)


def upscale(samples, upscale_method, scale_by):
    # Scale latents produced by a pipeline call made with output_type="latent".
    width = round(samples["images"].shape[3] * scale_by)
    height = round(samples["images"].shape[2] * scale_by)
    return common_upscale(samples["images"], width, height, upscale_method, "disabled")
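
# The helpers above support the two-pass flow used in `inference` below
# (argument values are illustrative, not exhaustive):
#   out = main_pipe(..., num_inference_steps=15, output_type="latent")
#   latents = upscale(out, "nearest-exact", 2)        # 2x the latent resolution
#   image = image_pipe(..., image=latents).images[0]  # refine at the higher size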


def convert_to_pil(base64_image):
    # Standard-library equivalent of Gradio's processing_utils.decode_base64_to_image;
    # accepts bare base64 or a "data:image/...;base64," URI.
    if "," in base64_image:
        base64_image = base64_image.split(",", 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(base64_image)))


def convert_to_base64(pil_image):
    # Encode a PIL image as a base64 PNG data URI.
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode("utf-8")


def inference(
    control_image: Image.Image,
    prompt: str,
    negative_prompt: str = "sexual content, racism, humans, faces",
    guidance_scale: float = 8.0,
    controlnet_conditioning_scale: float = 1.0,
    control_guidance_start: float = 0.0,
    control_guidance_end: float = 1.0,
    upscaler_strength: float = 0.5,
    seed: int = -1,
    sampler: str = "DPM++ Karras SDE",
):
    try:
        logger.debug(
            "Input Types: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, "
            "controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, "
            "upscaler_strength=%s, seed=%s, sampler=%s",
            type(control_image), type(prompt), type(negative_prompt), type(guidance_scale),
            type(controlnet_conditioning_scale), type(control_guidance_start),
            type(control_guidance_end), type(upscaler_strength), type(seed), type(sampler),
        )
        logger.debug(
            "Input Values: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, "
            "controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, "
            "upscaler_strength=%s, seed=%s, sampler=%s",
            control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale,
            control_guidance_start, control_guidance_end, upscaler_strength, seed, sampler,
        )

        start_time = time.time()
        start_time_formatted = time.strftime("%H:%M:%S", time.localtime(start_time))
        logger.info(f"Inference started at {start_time_formatted}")

        # The control image is used twice: small for the base pass, large for refinement.
        control_image_small = center_crop_resize(control_image)
        control_image_large = center_crop_resize(control_image, (1024, 1024))

        main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
        my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
        generator = torch.Generator(device="cuda").manual_seed(my_seed)

        # First pass: generate latents at the base resolution, conditioned on the control image.
        out = main_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image_small,
            guidance_scale=float(guidance_scale),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
            generator=generator,
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            num_inference_steps=15,
            output_type="latent",
        )
        # Second pass: upscale the latents 2x and refine with img2img + ControlNet.
        upscaled_latents = upscale(out, "nearest-exact", 2)
        out_image = image_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            control_image=control_image_large,
            image=upscaled_latents,
            guidance_scale=float(guidance_scale),
            generator=generator,
            num_inference_steps=20,
            strength=upscaler_strength,
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        )
        end_time = time.time()
        end_time_formatted = time.strftime("%H:%M:%S", time.localtime(end_time))
        logger.info(f"Inference ended at {end_time_formatted}, taking {end_time - start_time:.1f}s")

        logger.debug("Content of out_image: %s", out_image)
        logger.debug("Structure of out_image: %s", dir(out_image))
        generated_image = out_image["images"][0]
        logger.debug("Output Types: generated_image=%s", type(generated_image))
        return generated_image

    except Exception as e:
        logger.error("Error occurred during inference: %s", str(e))
        # Return None so callers can distinguish failure from a PIL image.
        return None
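
# Hypothetical direct use of `inference` (assumes a local "pattern.png" exists):
#   img = inference(Image.open("pattern.png"), "a medieval village, aerial view")
#   if img is not None:
#       img.save("out.png")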


def generate_image_from_parameters(prompt, guidance_scale, controlnet_scale, controlnet_end, upscaler_strength, seed, sampler_type, image):
    temp_image_path = None
    try:
        # Persist the upload to a temporary file and use it as the control image.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_image:
            temp_image.write(image.file.read())
            temp_image_path = temp_image.name

        control_image = Image.open(temp_image_path)

        generated_image = inference(control_image, prompt, "", guidance_scale, controlnet_scale, 0, controlnet_end, upscaler_strength, seed, sampler_type)
        if generated_image is None:
            return None

        output_image_io = io.BytesIO()
        generated_image.save(output_image_io, format="PNG")
        output_image_io.seek(0)
        output_image_binary = output_image_io.read()

        logger.debug("Output Values: generated_image=<binary data>")
        return output_image_binary
    except Exception as e:
        logger.error("Error occurred during image generation: %s", str(e))
        return None
    finally:
        if temp_image_path and os.path.exists(temp_image_path):
            os.remove(temp_image_path)


app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/generate_image")
async def generate_image(
    prompt: str = Form(...),
    guidance_scale: float = Form(...),
    controlnet_scale: float = Form(...),
    controlnet_end: float = Form(...),
    upscaler_strength: float = Form(...),
    seed: int = Form(...),
    sampler_type: str = Form(...),
    image: UploadFile = File(...),
):
    temp_image_path = None
    try:
        # Persist the upload to a temporary file and load it as the control image.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_image:
            temp_image.write(image.file.read())
            temp_image_path = temp_image.name

        control_image = Image.open(temp_image_path)

        generated_image = inference(control_image, prompt, "", guidance_scale, controlnet_scale, 0, controlnet_end, upscaler_strength, seed, sampler_type)
        if generated_image is None:
            return "Failed to generate image"

        output_image_io = io.BytesIO()
        generated_image.save(output_image_io, format="PNG")
        output_image_io.seek(0)

        return StreamingResponse(content=output_image_io, media_type="image/png")
    except Exception as e:
        logger.error("Error occurred during image generation: %s", str(e))
        return "Failed to generate image"
    finally:
        if temp_image_path and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
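
# Example request (assumes the server is reachable on localhost:7860 and that
# a local "pattern.png" exists):
#   curl -X POST http://localhost:7860/generate_image \
#     -F prompt="a medieval village, aerial view" \
#     -F guidance_scale=8.0 -F controlnet_scale=1.0 -F controlnet_end=1.0 \
#     -F upscaler_strength=0.5 -F seed=-1 \
#     -F sampler_type="DPM++ Karras SDE" \
#     -F image=@pattern.png --output out.png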


async def start_fastapi():
    internal_ip = socket.gethostbyname(socket.gethostname())

    try:
        public_ip = requests.get("https://api.ipify.org", timeout=5).text
    except requests.RequestException:
        public_ip = "Not Available"

    print(f"Internal URL: http://{internal_ip}:7860")
    print(f"Public URL: http://{public_ip}:7860")

    # "app:app" assumes this module is saved as app.py (required for reload=True).
    config = uvicorn.Config(app="app:app", host="0.0.0.0", port=7860, reload=True)
    server = uvicorn.Server(config)
    await server.serve()


if __name__ == "__main__":
    asyncio.run(start_fastapi())