# Hugging Face Hub file-page residue (kept as a comment so the module parses):
# yamildiego's picture
# test float 16
# 1ad6f34
# raw
# history blame
# No virus
# 6.04 kB
import cv2
import torch
import random
import numpy as np
from PIL import Image
from pathlib import Path
import requests
from io import BytesIO
from huggingface_hub import hf_hub_download, snapshot_download
from ip_adapter.ip_adapter import IPAdapterXL
from safetensors.torch import load_file
import os
from diffusers import (
ControlNetModel,
StableDiffusionXLControlNetPipeline,
EulerDiscreteScheduler
)
# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------

# Upper bound for random seeds (the maximum value of np.int32).
MAX_SEED = np.iinfo(np.int32).max

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision only on CUDA: many CPU kernels have no float16
# implementation, so forcing float16 on a CPU-only host breaks inference.
dtype = torch.float16 if device == "cuda" else torch.float32

# Hugging Face Hub model ids used to build the pipeline.
# (Alternative base model previously tried: "Yamer-AI/SDXL_Unstable_Diffusers".)
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet_path = "diffusers/controlnet-canny-sdxl-1.0"
class EndpointHandler:
    """Inference Endpoints handler: SDXL + IP-Adapter image generation.

    ``__init__`` builds an SDXL ControlNet (canny) pipeline, replaces its UNet
    weights with the SDXL-Lightning 2-step checkpoint and wraps it in an
    ``IPAdapterXL``. ``__call__`` currently renders a fixed test prompt
    conditioned on a fixed style image.
    """

    # Fixed test inputs used by ``__call__`` (the request payload is not
    # parsed yet).
    STYLE_IMAGE_URL = "https://i.ibb.co/yNjNksz/chrome-xaz-Ic-SNk-SY-1.jpg"
    TEST_PROMPT = "A Art Deco style artwork of tennis court with tennis balls and rackets, a modern car, (featuring Notre Dame Cathedral:1.5)"

    def __init__(self, model_dir):
        """Download the weights and assemble the IP-Adapter pipeline.

        Args:
            model_dir: Directory handed in by the endpoint runtime. It is not
                used for loading; all weights come from the Hugging Face Hub.
        """
        repo_id = "h94/IP-Adapter"
        ckpt_rel_path = os.path.join("sdxl_models", "ip-adapter_sdxl.safetensors")
        # A checkpoint relative to the working directory takes precedence
        # (original behaviour); otherwise use the copy from the Hub snapshot.
        use_local_ckpt = os.path.isfile(ckpt_rel_path)

        # Fetch only the SDXL files this handler uses instead of every file
        # in the repository.
        patterns = ["sdxl_models/image_encoder/*"]
        if not use_local_ckpt:
            patterns.append("sdxl_models/ip-adapter_sdxl.safetensors")
        local_repo_path = snapshot_download(repo_id=repo_id, allow_patterns=patterns)

        self.image_encoder_local_path = os.path.join(
            local_repo_path, "sdxl_models", "image_encoder"
        )
        self.ip_ckpt = (
            ckpt_rel_path
            if use_local_ckpt
            else os.path.join(local_repo_path, ckpt_rel_path)
        )

        self.controlnet = ControlNetModel.from_pretrained(
            controlnet_path, use_safetensors=False, torch_dtype=dtype
        ).to(device)
        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
            base_model_path,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            variant="fp16",
            add_watermarker=False,
        ).to(device)
        self.pipe.set_progress_bar_config(disable=True)
        # Euler scheduler with trailing timestep spacing, the configuration
        # paired here with the Lightning UNet loaded below.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="epsilon",
        )

        # Overwrite the base UNet weights with the SDXL-Lightning 2-step UNet.
        state_dict = load_file(
            hf_hub_download(
                "ByteDance/SDXL-Lightning", "sdxl_lightning_2step_unet.safetensors"
            )
        )
        self.pipe.unet.load_state_dict(state_dict)
        self.pipe.unet.to(device)

        # Restrict IP-Adapter injection to this single UNet attention block.
        self.ip_model = IPAdapterXL(
            self.pipe,
            self.image_encoder_local_path,
            self.ip_ckpt,
            device,
            target_blocks=["up_blocks.0.attentions.1"],
        )

    def __call__(self, data):
        """Generate one image.

        Args:
            data: Request payload from the endpoint runtime.
                NOTE(review): currently ignored; the prompt, style image and
                sampling parameters are the fixed test values on this class.

        Returns:
            The generated PIL image, or ``None`` if generation raised (the
            traceback is printed). Errors while fetching the style image
            propagate to the caller.
        """
        style_image_pil = self._load_image(self.STYLE_IMAGE_URL)
        print("Images loaded...")
        try:
            print("Creating image... (outside try block)")
            return self._create_image(
                image_pil=style_image_pil,
                prompt=self.TEST_PROMPT,
                n_prompt="",
                scale=0.3,
                guidance_scale=0.0,
                # NOTE(review): the UNet is the 2-step Lightning checkpoint;
                # 25 steps is presumably more than it was distilled for.
                num_inference_steps=25,
                seed=1109176307,
            )
        except Exception:
            # Keep the best-effort contract (return None), but don't hide why.
            import traceback

            traceback.print_exc()
            return None

    @staticmethod
    def _load_image(url, timeout=30):
        """Download ``url`` and decode it as an RGB PIL image.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the server does not answer within ``timeout``
                seconds (without a timeout ``requests`` can wait forever).
        """
        response = requests.get(url, timeout=timeout)
        # Fail with an HTTP error rather than an opaque image-decode error.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")

    def _create_image(
        self,
        image_pil,
        prompt,
        n_prompt,
        scale,
        guidance_scale,
        num_inference_steps,
        seed,
    ):
        """Run IP-Adapter generation with the ControlNet branch switched off.

        A blank white 1024x1024 control image is passed with a ControlNet
        conditioning scale of 0.0; no canny map is computed.

        Args:
            image_pil: Style reference image for the IP-Adapter.
            prompt: Text prompt.
            n_prompt: Negative prompt.
            scale: IP-Adapter image-prompt scale.
            guidance_scale: Classifier-free guidance scale.
            num_inference_steps: Number of denoising steps.
            seed: Generation seed.

        Returns:
            The first (and only) generated image.
        """
        print(f"Seed: {seed}")
        canny_map = Image.new("RGB", (1024, 1024), color=(255, 255, 255))
        print("Creating image... (inside create_image function)")
        images = self.ip_model.generate(
            pil_image=image_pil,
            prompt=prompt,
            negative_prompt=n_prompt,
            scale=scale,
            guidance_scale=guidance_scale,
            num_samples=1,
            num_inference_steps=num_inference_steps,
            seed=seed,
            image=canny_map,
            controlnet_conditioning_scale=0.0,
        )
        return images[0]