ARX / handler.py
niofrequency's picture
Update handler.py
137b45f verified
Raw
History Blame Contribute Delete
3.15 kB
import torch
import base64
import io
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline, DPMSolverMultistepScheduler
import os
class EndpointHandler():
    """Inference-endpoint handler running SDXL img2img guided by an IP-Adapter image.

    The handler is constructed once with the directory that holds the model
    weights and is then called once per request with the request payload.
    Request-level problems (missing or undecodable images, malformed numeric
    parameters) are reported as ``{"error": "<message>"}`` instead of being
    raised, so every bad request is answered in the same shape.
    """

    # Both the init image and the IP-Adapter reference image are resized to
    # this (width, height) before being handed to the pipeline.
    _IMAGE_SIZE = (1024, 1024)

    def __init__(self, path=""):
        """Load the pipeline, scheduler and IP-Adapter and move them to CUDA.

        Args:
            path: Directory containing ``juggernaut.safetensors``.
        """
        print("Loading ARX Elite Pipeline (Juggernaut XL)...")

        # 1. Point directly to the single-file checkpoint.
        model_path = os.path.join(path, "juggernaut.safetensors")

        # 2. Load the core pipeline in half precision.
        # NOTE(review): safety_checker=None is kept from the original; confirm
        # the SDXL pipeline actually consumes this argument.
        self.pipe = StableDiffusionXLImg2ImgPipeline.from_single_file(
            model_path,
            torch_dtype=torch.float16,
            use_safetensors=True,
            safety_checker=None,
        )

        # 3. Swap the scheduler for DPM-Solver multistep with Karras sigmas.
        # algorithm_type="sde-dpmsolver++" selects the SDE variant, i.e. the
        # sampler commonly labelled "DPM++ 2M SDE Karras".
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config,
            use_karras_sigmas=True,
            algorithm_type="sde-dpmsolver++",
        )

        # 4. Load the IP-Adapter weights used for the reference image.
        self.pipe.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="sdxl_models",
            weight_name="ip-adapter_sdxl.bin",
        )

        self.pipe.to("cuda")
        print("ARX Elite Ready.")

    def decode_base64_image(self, image_string):
        """Decode a base64 string (optionally a data URI) into an RGB PIL image.

        Args:
            image_string: Base64 payload, with or without a
                ``data:...;base64,`` style prefix.

        Returns:
            The decoded image converted to RGB mode.

        Raises:
            ValueError: If the base64 payload is malformed
                (``binascii.Error`` is a ``ValueError`` subclass).
            OSError: If the decoded bytes cannot be opened as an image.
        """
        if "," in image_string:
            # Drop the data-URI header; the payload is everything after the
            # first comma.
            image_string = image_string.split(",", 1)[1]
        image_bytes = base64.b64decode(image_string)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def encode_image_base64(self, image):
        """Serialize an image as PNG and return it as a base64 UTF-8 string.

        Args:
            image: Object exposing ``save(fp, format=...)`` (a PIL image).

        Returns:
            Base64 text of the PNG bytes.
        """
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def __call__(self, data):
        """Run one img2img generation for a request payload.

        Args:
            data: Request dict. Parameters are read from ``data["inputs"]``
                when that key exists, otherwise from ``data`` itself.
                Recognised keys: ``prompt``, ``negative_prompt``,
                ``strength``, ``guidance_scale``, ``steps``,
                ``ip_adapter_scale``, ``init_image`` and
                ``ip_adapter_image`` (both base64 strings, required).

        Returns:
            ``{"image": <base64 PNG>}`` on success, or
            ``{"error": <message>}`` when the request is invalid.
        """
        inputs = data.pop("inputs", data)
        if not isinstance(inputs, dict):
            # e.g. {"inputs": "some text"}: there is nothing to read
            # parameters from, so answer with an error instead of crashing
            # on .get().
            return {"error": "Invalid inputs: expected an object with image fields."}

        prompt = inputs.get("prompt", "masterpiece, best quality, highly detailed, photorealistic")
        # Realism enforcers are part of the default negative prompt.
        negative_prompt = inputs.get("negative_prompt", "blurry, lowres, bad anatomy, worst quality, ugly, deformed eyes, cartoon, illustration, painting, 3d render")

        try:
            strength = float(inputs.get("strength", 0.55))
            guidance_scale = float(inputs.get("guidance_scale", 7.5))
            num_inference_steps = int(inputs.get("steps", 30))
            ip_adapter_scale = float(inputs.get("ip_adapter_scale", 0.75))
        except (TypeError, ValueError) as exc:
            return {"error": f"Invalid numeric parameter: {exc}"}

        # Reject values that cannot describe a valid img2img run before any
        # GPU work starts (the comparison is also False for NaN).
        if not 0.0 <= strength <= 1.0:
            return {"error": "strength must be between 0 and 1."}
        if num_inference_steps < 1:
            return {"error": "steps must be a positive integer."}

        init_image_b64 = inputs.get("init_image")
        ip_adapter_image_b64 = inputs.get("ip_adapter_image")
        if not init_image_b64 or not ip_adapter_image_b64:
            return {"error": "Missing image inputs."}

        try:
            init_image = self.decode_base64_image(init_image_b64).resize(self._IMAGE_SIZE)
            ip_image = self.decode_base64_image(ip_adapter_image_b64).resize(self._IMAGE_SIZE)
        except (TypeError, ValueError, OSError) as exc:
            # Malformed base64, non-string payloads, or bytes that are not a
            # readable image.
            return {"error": f"Invalid image input: {exc}"}

        self.pipe.set_ip_adapter_scale(ip_adapter_scale)

        # Generate!
        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=init_image,
            ip_adapter_image=ip_image,
            strength=strength,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        ).images[0]
        return {"image": self.encode_image_base64(result)}