# app.py — last commit e52fd17 ("Update app.py") by pseudotheos, 10.8 kB.
import base64
import io
import logging
import os
import random
import socket
import sys
import tempfile
import time

import requests
import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionLatentUpscalePipeline,
)
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import FileResponse, JSONResponse, Response
from PIL import Image
logger = logging.getLogger(__name__)
# Accept everything from DEBUG upward; the handlers below set no level of
# their own, so they emit whatever the logger lets through.
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler('inference.log')
stream_handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Use the same timestamped format for the log file and the console (stdout);
# without this the file handler would write bare messages.
file_handler.setFormatter(formatter)
stream_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(stream_handler)
app = FastAPI()

BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"

# Load the VAE and the QR-code ControlNet in half precision. The VAE is loaded
# separately and passed to the pipeline explicitly (presumably because the
# "_noVAE" base checkpoint does not bundle a usable one — confirm).
vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-mse",
    torch_dtype=torch.float16,
)
controlnet = ControlNetModel.from_pretrained(
    "monster-labs/control_v1p_sd15_qrcode_monster",
    torch_dtype=torch.float16,
)

# Text-to-image ControlNet pipeline (first pass), no safety checker, on the GPU.
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    safety_checker=None,
    torch_dtype=torch.float16,
)
main_pipe = main_pipe.to("cuda")

# Img2img ControlNet pipeline (second pass) built from main_pipe's components,
# so both pipelines share the already-loaded modules.
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
# Scheduler factories keyed by the sampler name accepted by `inference`.
# Each factory builds a fresh scheduler from an existing scheduler config.
SAMPLER_MAP = {
    # `use_karras_sigmas` is the DPMSolverMultistepScheduler option for the
    # Karras noise schedule. The previously passed `use_karras` is not a
    # constructor argument, so `from_config` dropped it and Karras sigmas
    # were never enabled despite the sampler's name.
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}
def center_crop_resize(img, output_size=(512, 512)):
    """Crop the largest centered square from ``img`` and resize it.

    Args:
        img: A PIL image.
        output_size: ``(width, height)`` of the returned image; 512x512 by default.

    Returns:
        A new PIL image of size ``output_size``.
    """
    width, height = img.size
    side = min(width, height)
    # Integer offsets keep the crop exactly `side` x `side`. With float
    # coordinates PIL rounds each edge independently (round-half-to-even),
    # which can yield a crop one pixel wider or taller than `side` (e.g. 8x3
    # -> 4x3) that the resize below would then distort.
    left = (width - side) // 2
    top = (height - side) // 2
    img = img.crop((left, top, left + side, top + side))
    return img.resize(output_size)
def common_upscale(samples, width, height, upscale_method, crop=False):
    """Resize a 4-D tensor to ``(height, width)``, optionally center-cropping first.

    Dimension 2 of ``samples`` is treated as height and dimension 3 as width.
    When ``crop == "center"``, the input is first trimmed symmetrically along
    the axis that is too long for the target aspect ratio. Any other ``crop``
    value resizes the whole tensor. ``upscale_method`` is passed as ``mode`` to
    ``torch.nn.functional.interpolate``.
    """
    source = samples
    if crop == "center":
        src_h = samples.shape[2]
        src_w = samples.shape[3]
        src_ratio = src_w / src_h
        dst_ratio = width / height
        off_x = off_y = 0
        if src_ratio > dst_ratio:
            # Source is too wide: trim equal amounts from left and right.
            off_x = round((src_w - src_w * (dst_ratio / src_ratio)) / 2)
        elif src_ratio < dst_ratio:
            # Source is too tall: trim equal amounts from top and bottom.
            off_y = round((src_h - src_h * (src_ratio / dst_ratio)) / 2)
        source = samples[:, :, off_y:src_h - off_y, off_x:src_w - off_x]
    return torch.nn.functional.interpolate(source, size=(height, width), mode=upscale_method)
def upscale(samples, upscale_method, scale_by):
    """Scale the ``"images"`` tensor of ``samples`` by ``scale_by`` without cropping.

    ``samples["images"]`` is read as a 4-D tensor whose dimensions 2 and 3 are
    height and width. The new size is each dimension times ``scale_by``,
    rounded to the nearest integer. Returns the resized tensor.
    """
    latents = samples["images"]
    new_height = round(latents.shape[2] * scale_by)
    new_width = round(latents.shape[3] * scale_by)
    return common_upscale(latents, new_width, new_height, upscale_method, "disabled")
def convert_to_pil(base64_image):
    """Decode a base64-encoded image into a PIL image.

    Accepts either a bare base64 string or a data URL such as
    ``"data:image/png;base64,<payload>"``; everything up to the last comma
    is discarded.

    Returns:
        The decoded PIL image.

    Raises:
        binascii.Error: if the payload is not valid base64.
        PIL.UnidentifiedImageError: if the bytes are not a recognized image.
    """
    # Decoded locally: the previously used `processing_utils` helper is not
    # imported anywhere in this module, so every call raised NameError.
    payload = base64_image.rsplit(",", 1)[-1]
    return Image.open(io.BytesIO(base64.b64decode(payload)))
def convert_to_base64(pil_image):
    """Encode a PIL image as a PNG data URL.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"``.
    """
    # Encoded locally: the previously used `processing_utils` helper is not
    # imported anywhere in this module, so every call raised NameError.
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return "data:image/png;base64," + encoded
def inference(
    control_image: Image.Image,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float = 8.0,
    controlnet_conditioning_scale: float = 1,
    control_guidance_start: float = 1,
    control_guidance_end: float = 1,
    upscaler_strength: float = 0.5,
    seed: int = -1,
    sampler="DPM++ Karras SDE",
):
    """Run the two-pass ControlNet pipeline and return the generated image.

    Pass 1: ``main_pipe`` produces latents (15 steps) conditioned on a 512x512
    center crop of ``control_image``. Those latents are upscaled 2x
    ("nearest-exact") and refined by ``image_pipe`` (img2img, 20 steps)
    conditioned on a 1024x1024 center crop.

    Args:
        control_image: Control image; it is center-cropped to a square.
        prompt: Text prompt used for both passes.
        negative_prompt: Negative prompt used for both passes.
        guidance_scale: Classifier-free guidance scale for both passes.
        controlnet_conditioning_scale: ControlNet output weight for both passes.
        control_guidance_start: Fraction of steps at which ControlNet starts applying.
        control_guidance_end: Fraction of steps at which ControlNet stops applying.
        upscaler_strength: img2img ``strength`` for the second pass.
        seed: RNG seed; ``-1`` picks a random seed.
        sampler: Key into ``SAMPLER_MAP``; applied to both passes.

    Returns:
        The first image of the second pass's output (a PIL image with the
        pipeline's default output type).

    Raises:
        KeyError: if ``sampler`` is not a key of ``SAMPLER_MAP``.
        Exception: any error from preprocessing or the pipelines is logged
            and re-raised (previously it was returned as a string, which
            callers then tried to use as an image).
    """
    logger.debug("Input Types: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, upscaler_strength=%s, seed=%s, sampler=%s",
                 type(control_image), type(prompt), type(negative_prompt), type(guidance_scale), type(controlnet_conditioning_scale),
                 type(control_guidance_start), type(control_guidance_end), type(upscaler_strength), type(seed), type(sampler))
    logger.debug("Input Values: control_image=%s, prompt=%s, negative_prompt=%s, guidance_scale=%s, controlnet_conditioning_scale=%s, control_guidance_start=%s, control_guidance_end=%s, upscaler_strength=%s, seed=%s, sampler=%s",
                 control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale,
                 control_guidance_start, control_guidance_end, upscaler_strength, seed, sampler)

    start_time = time.time()
    logger.info("Inference started at %s", time.strftime("%H:%M:%S", time.localtime(start_time)))

    try:
        control_image_small = center_crop_resize(control_image)
        control_image_large = center_crop_resize(control_image, (1024, 1024))

        # image_pipe was constructed from main_pipe.components at import time
        # and holds its own scheduler reference, so replacing only
        # main_pipe.scheduler left the second pass on the scheduler it was
        # built with. Swap both so the chosen sampler applies end to end.
        # NOTE(review): this mutates module-level pipelines per request; it
        # is only safe while requests are processed one at a time.
        make_scheduler = SAMPLER_MAP[sampler]
        main_pipe.scheduler = make_scheduler(main_pipe.scheduler.config)
        image_pipe.scheduler = make_scheduler(image_pipe.scheduler.config)

        my_seed = random.randint(0, 2**32 - 1) if seed == -1 else seed
        logger.info("Using seed %s", my_seed)
        # The same generator is used for both passes, so the second pass
        # continues from the RNG state left by the first.
        generator = torch.Generator(device="cuda").manual_seed(my_seed)

        out = main_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=control_image_small,
            guidance_scale=float(guidance_scale),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
            generator=generator,
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            num_inference_steps=15,
            output_type="latent",
        )
        upscaled_latents = upscale(out, "nearest-exact", 2)
        out_image = image_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            control_image=control_image_large,
            image=upscaled_latents,
            guidance_scale=float(guidance_scale),
            generator=generator,
            num_inference_steps=20,
            strength=float(upscaler_strength),
            control_guidance_start=float(control_guidance_start),
            control_guidance_end=float(control_guidance_end),
            controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        )
        generated_image = out_image["images"][0]
    except Exception as e:
        logger.exception("Error occurred during inference: %s", str(e))
        raise

    end_time = time.time()
    logger.info("Inference ended at %s, taking %ss",
                time.strftime("%H:%M:%S", time.localtime(end_time)), end_time - start_time)
    logger.debug("Output Types: generated_image=%s", type(generated_image))
    return generated_image
def generate_image_from_parameters(prompt, guidance_scale, controlnet_scale, controlnet_end, upscaler_strength, seed, sampler_type, image):
    """Run ``inference`` on an uploaded control image and return PNG bytes.

    Args:
        prompt: Text prompt (no negative prompt is used).
        guidance_scale: Classifier-free guidance scale.
        controlnet_scale: ControlNet conditioning scale.
        controlnet_end: Fraction of steps at which ControlNet stops applying
            (it always starts at 0).
        upscaler_strength: img2img strength for the second pass.
        seed: RNG seed; ``-1`` picks a random one.
        sampler_type: Key into ``SAMPLER_MAP``.
        image: Upload object exposing a binary ``.file`` (e.g. FastAPI
            ``UploadFile``).

    Returns:
        PNG-encoded image bytes on success, or the error message string on
        failure.
    """
    try:
        # Decode the upload in memory instead of writing it to /tmp under the
        # client-supplied filename, which could escape /tmp (path traversal)
        # and was never cleaned up.
        control_image = Image.open(io.BytesIO(image.file.read()))

        # inference returns a single image; unpacking it into four values
        # (as before) always raised.
        generated_image = inference(control_image, prompt, "", guidance_scale, controlnet_scale, 0, controlnet_end, upscaler_strength, seed, sampler_type)

        output_image_io = io.BytesIO()
        generated_image.save(output_image_io, format="PNG")
        logger.debug("Output Values: generated_image=<binary data>")
        return output_image_io.getvalue()
    except Exception as e:
        logger.exception("Error occurred while generating image: %s", e)
        return str(e)
@app.post("/generate_image")
async def generate_image(
    prompt: str = Form(...),
    guidance_scale: float = Form(...),
    controlnet_scale: float = Form(...),
    controlnet_end: float = Form(...),
    upscaler_strength: float = Form(...),
    seed: int = Form(...),
    sampler_type: str = Form(...),
    image: UploadFile = File(...)
):
    """HTTP endpoint: generate an image from an uploaded control image.

    All parameters arrive as multipart form fields; see ``inference`` for
    their meaning (``controlnet_end`` is the ControlNet stop fraction, and
    ControlNet always starts at 0).

    Returns:
        ``image/png`` bytes shown inline on success. On failure, an HTTP 500
        whose JSON body is the error message string.
    """
    # NOTE(review): inference is blocking GPU work running inside an async
    # handler, so it stalls the event loop for the duration of a request.
    try:
        # Decode the upload in memory instead of writing it to /tmp under the
        # client-supplied filename (path traversal, never cleaned up).
        control_image = Image.open(io.BytesIO(await image.read()))

        # inference returns a single image, not a 4-tuple.
        generated_image = inference(control_image, prompt, "", guidance_scale, controlnet_scale, 0, controlnet_end, upscaler_strength, seed, sampler_type)

        img_bytes = io.BytesIO()
        generated_image.save(img_bytes, format='PNG')
        # FileResponse requires a filesystem path, not an in-memory buffer,
        # so return the encoded bytes directly.
        return Response(
            content=img_bytes.getvalue(),
            media_type='image/png',
            headers={'Content-Disposition': 'inline; filename=generated_image.png'},
        )
    except Exception as e:
        logger.exception("Error occurred in /generate_image: %s", e)
        # Same JSON-string body as before, but with an error status instead of 200.
        return JSONResponse(content=str(e), status_code=500)
if __name__ == "__main__":
    import uvicorn

    port = 7860

    # Address of this host on the local network.
    try:
        internal_ip = socket.gethostbyname(socket.gethostname())
    except OSError:
        internal_ip = "Not Available"

    # Public address from an external lookup service. The port may still be
    # unreachable at that address behind a router/NAT. The timeout keeps a
    # slow or blocked lookup from hanging startup.
    try:
        public_ip = requests.get("http://api.ipify.org", timeout=5).text
    except requests.RequestException:
        public_ip = "Not Available"

    print(f"Internal URL: http://{internal_ip}:{port}")
    print(f"Public URL: http://{public_ip}:{port}")

    # reload=True was removed: uvicorn only supports reload when the app is
    # given as an import string, and with an app object it exits without
    # serving. Reloading would also re-load all models on every change.
    uvicorn.run(app, host="0.0.0.0", port=port)