|
import spaces |
|
import os |
|
import json |
|
import time |
|
import copy |
|
import numpy as np |
|
import torch |
|
import random |
|
from diffusers import AutoPipelineForText2Image, StableDiffusionPipeline,DiffusionPipeline, StableDiffusionXLPipeline, AutoencoderKL, AutoencoderTiny, UNet2DConditionModel |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
from pathlib import Path |
|
from diffusers import EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DPMSolverSDEScheduler |
|
from diffusers.models.attention_processor import AttnProcessor2_0 |
|
from cryptography.hazmat.primitives.asymmetric import rsa, padding |
|
from cryptography.hazmat.primitives import serialization, hashes |
|
from cryptography.hazmat.backends import default_backend |
|
from cryptography.hazmat.primitives.asymmetric import utils |
|
import base64 |
|
import json |
|
import ipown |
|
import jwt |
|
import glob |
|
import traceback |
|
from insightface.app import FaceAnalysis |
|
import cv2 |
|
import re |
|
import gradio as gr |
|
import uuid |
|
from PIL import Image |
|
# Upper bound for randomly drawn seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Runtime configuration taken from the process environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
VAR_PUBLIC_KEY = os.environ.get("PUBLIC_KEY")

# Hugging Face dataset repository used as the default source for downloads.
DATASET_ID = "nsfwalex/checkpoint_n_lora"
|
# Base configuration shared by every scheduler built in this module.
scheduler_config = dict(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    set_alpha_to_one=False,
    steps_offset=1,
    prediction_type="epsilon",
)
|
# (display name, scheduler class, extra from_config kwargs), in display order.
# NOTE(review): `use_2m` / `use_2s` are forwarded to from_config as-is; confirm
# that the scheduler classes actually recognise these options.
_SAMPLER_SPECS = (
    ("Euler a", EulerAncestralDiscreteScheduler, {}),
    ("DPM++ SDE Karras", DPMSolverSDEScheduler, {"use_karras_sigmas": True}),
    ("DPM2 a", DPMSolverMultistepScheduler, {}),
    ("DPM++ SDE", DPMSolverSDEScheduler, {}),
    ("DPM++ 2M SDE", DPMSolverSDEScheduler, {"use_2m": True}),
    ("DPM++ 2S a", DPMSolverMultistepScheduler, {"use_2s": True}),
)

# Sampler name -> scheduler instance, each built once from scheduler_config
# when the module is imported.
samplers = {
    label: scheduler_cls.from_config(scheduler_config, **extra)
    for label, scheduler_cls, extra in _SAMPLER_SPECS
}
|
|
|
class AuthHelper:
    """Verify request tokens (JWTs) against the PEM public key from the environment."""

    def load_public_key_from_file(self):
        """Parse the PEM public key held in ``VAR_PUBLIC_KEY``.

        Despite the name, the key text comes from the ``PUBLIC_KEY``
        environment variable, not from a file.

        :return: The loaded public key object.
        :raises ValueError: If the ``PUBLIC_KEY`` environment variable is unset
            or empty.
        """
        if not VAR_PUBLIC_KEY:
            # Fail with an explicit message instead of an AttributeError on None.
            raise ValueError("PUBLIC_KEY environment variable is not set")
        public_key_bytes = VAR_PUBLIC_KEY.encode('utf-8')
        public_key = serialization.load_pem_public_key(
            public_key_bytes,
            backend=default_backend()
        )
        return public_key

    def __init__(self):
        """Load the public key once so every verification reuses it."""
        self.public_key = self.load_public_key_from_file()

    def decode_jwt(self, token, algorithms=None):
        """
        Decode and verify a JWT using the instance's public key.

        :param token: The JWT string to decode.
        :param algorithms: List of acceptable algorithms (default is ["RS256"]).
        :return: The decoded JWT payload if verification is successful.
        :raises: Exception if verification fails (logged, then re-raised).
        """
        if algorithms is None:
            algorithms = ["RS256"]
        try:
            decoded_payload = jwt.decode(
                token,
                self.public_key,
                algorithms=algorithms,
                options={"verify_signature": True}
            )
            return decoded_payload
        except Exception as e:
            print("Invalid token:", e)
            raise

    def check_auth(self, request, token):
        """Check that ``token`` was issued for the client making ``request``.

        The token's ``auth`` claim must equal the MD5 hex digest of the
        concatenated client IP, ``Host``, ``Referer`` and ``User-Agent``
        values of the request.

        :param request: Incoming request object exposing ``query_params``,
            ``client.host`` and ``headers``; a falsy value skips the check.
        :param token: The JWT string carrying the ``auth`` claim.
        :return: ``True`` when the request is accepted.
        :raises Exception: If the token is invalid, lacks the ``auth`` claim,
            or the digest does not match.
        """
        # Imported here so the names are always resolvable from the method:
        # an import placed in the class body is not visible inside methods.
        import hashlib
        import hmac

        # SECURITY NOTE(review): a missing request and a hard-coded query
        # passkey both bypass authentication entirely. Kept for compatibility;
        # consider moving the passkey to configuration or removing the bypass.
        if not request or request.query_params.get("_skip_token_passkey", "") == "nsfwaisio_125687":
            return True

        # request.client can be absent; treat that as an empty IP so the
        # digest check fails cleanly instead of raising AttributeError.
        sip = request.client.host if request.client else ""
        shost = request.headers.get("Host", "")
        sreferer = request.headers.get("Referer", "")
        suseragent = request.headers.get("User-Agent", "")

        print(sip, shost, sreferer, suseragent)

        jwt_data = self.decode_jwt(token)
        jwt_auth = jwt_data.get("auth", "")

        if not jwt_auth:
            raise Exception("Missing auth field in token")

        auth_string = f"{sip}{shost}{sreferer}{suseragent}"
        calculated_md5 = hashlib.md5(auth_string.encode('utf-8')).hexdigest()

        # NOTE(review): this logs the expected digest next to the supplied one.
        print(f"Calculated MD5: {calculated_md5}, JWT Auth: {jwt_auth}")

        # Constant-time comparison; a non-string claim can never match.
        if isinstance(jwt_auth, str) and hmac.compare_digest(
            calculated_md5.encode("utf-8"), jwt_auth.encode("utf-8")
        ):
            return True

        raise Exception("Invalid authentication")
|
|
|
class InferenceManager:
    """Hold one base diffusion pipeline together with its LoRA adapters.

    The instance reads a JSON model config, loads the base pipeline it
    describes, preloads the LoRAs flagged for preloading, and can then hand
    out a pipeline with a requested LoRA set and sampler applied.
    """

    def __init__(self, config_path="config.json", ext_model_pathes=None):
        """Load the config at ``config_path`` and build the base pipeline.

        :param config_path: Path of the JSON model configuration. It must
            contain ``model_version``; ``model_id`` is read when the base
            model is loaded.
        :param ext_model_pathes: Optional mapping of auxiliary model names to
            local file paths (``"ip-adapter-faceid-sdxl"`` is the only key
            read by this class). Defaults to an empty mapping.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        self.cfg = cfg
        # A fresh dict per instance; a `{}` default would be shared by all instances.
        self.ext_model_pathes = {} if ext_model_pathes is None else ext_model_pathes

        lora_options_path = cfg.get("loras", "")
        self.model_version = cfg["model_version"]
        self.lora_load_options = self.load_json(lora_options_path)
        self.lora_models = self.load_index_file("index.json")
        self.preloaded_loras = []
        # Must exist before load_base_model(), which may replace it.
        self.ip_adapter_faceid_pipeline = None
        self.base_model_pipeline = self.load_base_model()

        self.preload_loras()

    def load_json(self, filepath):
        """Load a JSON file into a dictionary; return ``{}`` if the path is empty or missing."""
        if filepath and os.path.exists(filepath):
            with open(filepath, "r", encoding="utf-8") as f:
                return json.load(f)
        return {}

    def load_index_file(self, index_file):
        """Download ``index_file`` via ``download_from_hf`` and return its parsed JSON.

        Returns ``{}`` when the download fails.
        """
        index_path = download_from_hf(index_file)
        if index_path:
            with open(index_path, "r", encoding="utf-8") as f:
                return json.load(f)
        return {}

    @spaces.GPU(duration=40)
    def compile_onediff(self):
        """Compile the base pipeline, run one warm-up generation and cache the result.

        NOTE(review): ``load_pipe``, ``compile_pipe`` and ``save_pipe`` are not
        defined or imported in the visible part of this file, so calling this
        method will fail with ``NameError`` unless they are provided elsewhere.
        """
        self.base_model_pipeline.to("cuda")
        pipe = self.base_model_pipeline

        load_pipe(pipe, dir="cached_pipe")
        print("Start oneflow compiling...")
        start_compile = time.time()
        pipe = compile_pipe(pipe)

        # Warm-up run; its duration is included in the reported compile time.
        image = pipe(
            prompt="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
            height=512,
            width=512,
            num_inference_steps=10,
            output_type="pil",
        ).images
        image[0].save("test_image.png")
        compile_time = time.time() - start_compile

        save_pipe(pipe, dir="cached_pipe")
        self.base_model_pipeline = pipe
        print(f"OneDiff compile in {compile_time}s")

    def load_base_model(self):
        """Load the base model and return the pipeline.

        Downloads the snapshot named by ``cfg["model_id"]``, builds a
        ``StableDiffusionPipeline`` for model version ``"1.5"`` or a
        ``DiffusionPipeline`` with the ``lpw_stable_diffusion_xl`` custom
        pipeline otherwise, applies ``clip_skip`` and, when configured,
        creates the IP-Adapter FaceID wrapper for "pony"/"xl" models.

        :return: The loaded pipeline.
        """
        start = time.time()
        cfg = self.cfg

        model_version = self.model_version
        ckpt_dir = snapshot_download(repo_id=cfg["model_id"], local_files_only=False)

        if model_version == "1.5":
            vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
            pipe = StableDiffusionPipeline.from_pretrained(
                ckpt_dir, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True
            )
        else:
            use_vae = cfg.get("vae", "")
            # NOTE(review): the trailing `or True` forces the checkpoint's own
            # VAE, so the "tae"/custom branches below are currently
            # unreachable and cfg["vae"] is effectively ignored. Left as-is to
            # preserve behavior; remove `or True` to re-enable them.
            if not use_vae or True:
                vae = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.bfloat16)
            elif use_vae == "tae":
                vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.bfloat16)
            else:
                vae = AutoencoderTiny.from_pretrained(use_vae, torch_dtype=torch.bfloat16)
            print(ckpt_dir)
            pipe = DiffusionPipeline.from_pretrained(
                ckpt_dir,
                vae=vae,
                torch_dtype=torch.bfloat16,
                use_safetensors=True,
                custom_pipeline="lpw_stable_diffusion_xl",
            )

        # clip_skip=1 leaves the text encoder untouched; each extra step drops
        # one hidden layer from its configured depth.
        clip_skip = cfg.get("clip_skip", 1)
        pipe.text_encoder.config.num_hidden_layers -= (clip_skip - 1)

        load_time = round(time.time() - start, 2)
        print(f"Base model loaded in {load_time}s")

        if cfg.get("load_ip_adapter_faceid", False):
            if model_version in ("pony", "xl"):
                ip_ckpt = self.ext_model_pathes.get("ip-adapter-faceid-sdxl", "")
                if ip_ckpt:
                    print("loading ip adapter model...")
                    self.ip_adapter_faceid_pipeline = ipown.IPAdapterFaceIDXL(
                        pipe, ip_ckpt, 'cuda', torch_dtype=torch.bfloat16
                    )
                else:
                    print("ip-adapter-faceid-sdxl not found, skip")

        return pipe

    def preload_loras(self):
        """Preload all LoRAs marked as 'preload=True' and store for later use.

        Failures are logged per LoRA and do not abort the remaining ones.
        """
        # Tolerate a missing/failed index download instead of raising KeyError.
        index_entries = self.lora_models.get("lora", [])
        for lora_name, lora_info in self.lora_load_options.items():
            try:
                start = time.time()

                lora_index_info = next((l for l in index_entries if l['name'] == lora_name), None)
                if not lora_index_info:
                    raise ValueError(f"LoRA {lora_name} not found in index.json.")

                # Skipped both for incompatible base models and for LoRAs not flagged for preload.
                if self.model_version not in lora_info['base_model'] or not lora_info.get('preload', False):
                    print(f"Skipping {lora_name} as it's not compatible with the current model version.")
                    continue

                weight_path = download_from_hf(lora_index_info['path'], local_dir=None)
                if not weight_path:
                    raise ValueError(f"Failed to download LoRA weights for {lora_name}")
                load_time = round(time.time() - start, 2)
                print(f"Downloaded {lora_name} in {load_time}s")
                self.base_model_pipeline.load_lora_weights(
                    weight_path,
                    weight_name=lora_index_info["path"],
                    adapter_name=lora_name
                )

                # Only preload-flagged LoRAs reach this point (see the skip above).
                self.preloaded_loras.append({
                    "name": lora_name,
                    "weight": lora_info.get("weight", 1.0)
                })
                load_time = round(time.time() - start, 2)
                print(f"Preloaded LoRA {lora_name} with weight {lora_info.get('weight', 1.0)} in {load_time}s.")
            except Exception as e:
                print(f"Lora {lora_name} not loaded, skipping... {e}")

    def build_pipeline_with_lora(self, lora_list, sampler=None, new_pipeline=False):
        """Build the pipeline with specific LoRAs, loading any that are not preloaded.

        :param lora_list: Names of the LoRAs requested in addition to the
            preloaded ones. Unknown or incompatible names are ignored.
        :param sampler: Key into the module-level ``samplers`` mapping; falls
            back to ``cfg["sampler"]`` and then ``"Euler a"`` when falsy.
        :param new_pipeline: When true, work on a deep copy of the base
            pipeline; otherwise the shared base pipeline is modified in place.
        :return: The pipeline with adapters and scheduler set.
        :raises KeyError: If ``sampler`` is not a known sampler name.
        """
        start = time.time()
        if new_pipeline:
            temp_pipeline = copy.deepcopy(self.base_model_pipeline)
        else:
            temp_pipeline = self.base_model_pipeline
        copy_time = round(time.time() - start, 2)
        print(f"pipeline copied in {copy_time}s")

        # Tolerate a missing/failed index download instead of raising KeyError.
        index_entries = self.lora_models.get("lora", [])
        dynamic_loras = []

        for lora_name in lora_list:
            if any(l['name'] == lora_name for l in self.preloaded_loras):
                continue
            lora_info = next((l for l in index_entries if l['name'] == lora_name), None)
            if lora_info and self.model_version in lora_info["attr"].get("base_model", []):
                dynamic_loras.append({
                    "name": lora_name,
                    "filename": lora_info["path"],
                    "scale": 1.0
                })

        all_loras = [
            {"name": x["name"], "scale": x["weight"], "preloaded": True}
            for x in self.preloaded_loras
        ] + dynamic_loras
        set_lora_weights(temp_pipeline, all_loras, False)

        build_time = round(time.time() - start, 2)
        print(f"Pipeline built with LoRAs in {build_time}s.")
        if not sampler:
            sampler = self.cfg.get("sampler", "Euler a")

        # The scheduler objects in `samplers` are module-level instances, so
        # the same object is assigned to every pipeline that picks this name.
        temp_pipeline.scheduler = samplers[sampler]

        return temp_pipeline

    def release(self, temp_pipeline):
        """Release the deepcopied pipeline to recycle memory.

        ``del`` only drops this method's local reference; the pipeline is
        freed only once the caller holds no other reference to it. The CUDA
        cache is emptied regardless.
        """
        del temp_pipeline
        torch.cuda.empty_cache()
        print("Memory released and cache cleared.")
|
|
|
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return ``seed`` unchanged, or a random value in ``[0, MAX_SEED]`` when ``randomize_seed`` is set."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
|
|
|
|
|
# Case-insensitive pattern for child-related terms. There are no word
# boundaries, so terms also match inside longer words (e.g. "son" in "person").
_CHILD_TERMS_PATTERN = (
    r'(child|children|kid|kids|baby|babies|toddler|infant|juvenile|minor|underage|preteen|adolescent|youngster|youth|son|daughter|young|kindergarten|preschool|'
    r'([1-9]|1[0-7])[\s_\-|\.\,]*year(s)?[\s_\-|\.\,]*old|'
    r'little|small|tiny|short|young|new[\s_\-|\.\,]*born[\s_\-|\.\,]*(boy|girl|male|man|bro|brother|sis|sister))'
)
child_related_regex = re.compile(_CHILD_TERMS_PATTERN, re.IGNORECASE)


def remove_child_related_content(prompt):
    """Delete every match of ``child_related_regex`` from ``prompt`` and strip outer whitespace."""
    return child_related_regex.sub('', prompt).strip()


def contains_child_related_content(prompt):
    """Return True when ``child_related_regex`` matches anywhere in ``prompt``, else False."""
    return child_related_regex.search(prompt) is not None
|
|
|
def save_image(img):
    """Save ``img`` as a WEBP file under ``./tmp/`` and return it re-read from disk.

    :param img: Image object providing ``convert`` (a PIL image in this file's usage).
    :return: Tuple ``(image, path)`` where ``image`` is a copy of the image
        loaded back from the written file and ``path`` is the file's location.
    """
    path = "./tmp/"

    # exist_ok avoids the check-then-create race when several requests save
    # their first image at the same time.
    os.makedirs(path, exist_ok=True)

    # Random UUID file name so saves never collide.
    unique_name = os.path.join(path, str(uuid.uuid4()) + ".webp")

    webp_img = img.convert("RGB")
    webp_img.save(unique_name, "WEBP", quality=90)

    # Re-open so the returned image reflects what was actually written; copy()
    # detaches it from the file handle closed by the context manager.
    with Image.open(unique_name) as webp_file:
        webp_image = webp_file.copy()

    return webp_image, unique_name
|
|
|
class ModelManager:
    """Registry of ``InferenceManager`` instances, one per ``*.model.json`` config file."""

    def __init__(self, model_directory):
        """
        Initialize the ModelManager by scanning all `.model.json` files in the given directory.

        Also downloads the IP-Adapter FaceID SDXL checkpoint, whose local path
        is handed to every InferenceManager.

        :param model_directory: The directory to scan for model config files (e.g., "/path/to/models").
        """
        print("downloading models")
        print("loading face analysis...")
        # The face analyser is created lazily on the first FaceID generation.
        self.app = None

        self.ext_model_pathes = {
            "ip-adapter-faceid-sdxl": hf_hub_download(
                repo_id="h94/IP-Adapter-FaceID",
                filename="ip-adapter-faceid_sdxl.bin",
                repo_type="model",
            )
        }

        self.models = {}
        self.ext_models = {}
        self.model_directory = model_directory
        self.load_models()

    def load_instant_x(self):
        """Download the InstantID and antelopev2 files and initialise a face analyser.

        The InstantID files are written below ``./checkpoints``
        (``ip-adapter.bin`` and ``ControlNetModel/``) and the antelopev2 ONNX
        files below ``./models``. The analyser created here is not stored on
        the instance.
        """
        hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
        hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
        hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
        os.makedirs("./models", exist_ok=True)
        download_from_hf("models/antelopev2/1k3d68.onnx", local_dir="./models")
        download_from_hf("models/antelopev2/2d106det.onnx", local_dir="./models")
        download_from_hf("models/antelopev2/genderage.onnx", local_dir="./models")
        download_from_hf("models/antelopev2/glintr100.onnx", local_dir="./models")
        download_from_hf("models/antelopev2/scrfd_10g_bnkps.onnx", local_dir="./models")

        app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        app.prepare(ctx_id=0, det_size=(640, 640))

    def load_models(self):
        """
        Scan the model directory for `.model.json` files and initialize InferenceManager instances for each one.

        A config that fails to initialise is logged and skipped.
        """
        model_files = glob.glob(os.path.join(self.model_directory, "*.model.json"))
        if not model_files:
            print(f"No model configuration files found in {self.model_directory}")
            return

        for file_path in model_files:
            # "name.model.json" -> "name.model" -> "name"
            model_name = self.get_model_name_from_url(file_path).split(".")[0]
            print(f"Initializing model: {model_name} from {file_path}")
            try:
                self.models[model_name] = InferenceManager(
                    config_path=file_path, ext_model_pathes=self.ext_model_pathes
                )
            except Exception as e:
                print(traceback.format_exc())
                print(f"Failed to initialize model {model_name} from {file_path}: {e}")

    def get_model_name_from_url(self, url):
        """
        Extract the model name from the config file path (filename without its last extension).

        :param url: The file path of the configuration file.
        :return: The model name (file name without its last extension).
        """
        filename = os.path.basename(url)
        model_name, _ = os.path.splitext(filename)
        return model_name

    def get_model_pipeline(self, model_id, lora_list, sampler=None, new_pipeline=False):
        """
        Build the pipeline with specific LoRAs for a model.

        :param model_id: The model ID (the model name extracted from the config URL).
        :param lora_list: List of LoRAs to be applied to the model pipeline.
        :param sampler: The sampler to be used for the pipeline.
        :param new_pipeline: Flag to indicate whether to create a new pipeline or reuse the existing one.
        :return: The built pipeline with LoRAs applied, or None if the model is unknown or building fails.
        """
        model = self.models.get(model_id)
        if not model:
            print(f"Model {model_id} not found.")
            return None
        try:
            print(f"Building pipeline with LoRAs for model {model_id}...")
            return model.build_pipeline_with_lora(lora_list, sampler, new_pipeline)
        except Exception as e:
            print(traceback.format_exc())
            print(f"Failed to build pipeline for model {model_id}: {e}")
            return None

    def release_model(self, model_id):
        """
        Release resources and clear memory for a specific model.

        :param model_id: The model ID (the model name extracted from the config URL).
        """
        model = self.models.get(model_id)
        if not model:
            print(f"Model {model_id} not found.")
            return
        try:
            print(f"Releasing model {model_id}...")
            model.release(model.base_model_pipeline)
        except Exception as e:
            print(f"Failed to release model {model_id}: {e}")

    @spaces.GPU(duration=40)
    def generate_with_faceid(self, model_id, inference_params, progress=gr.Progress(track_tqdm=True)):
        """Generate images conditioned on the face(s) found in the supplied images.

        :param model_id: Key of a loaded model that has an IP-Adapter FaceID pipeline.
        :param inference_params: Mapping with ``prompt`` and ``images``
            (required) plus optional ``negative_prompt``, ``steps``,
            ``guidance_scale``, ``width``, ``height``, ``likeness_strength``,
            ``face_strength`` and ``seed``; missing values fall back to the
            model config or built-in defaults.
        :param progress: Gradio progress tracker.
        :return: List of generated images as re-read from their saved WEBP files.
        :raises Exception: If the model is unknown or lacks FaceID support, the
            prompt or face images are missing, or no face is detected.
        """
        torch.cuda.empty_cache()
        model = self.models.get(model_id)
        if not model:
            raise Exception(f"invalid model_id {model_id}")
        if not model.ip_adapter_faceid_pipeline:
            raise Exception("model does not support ip adapter")
        ip_model = model.ip_adapter_faceid_pipeline
        cfg = model.cfg
        p = inference_params.get("prompt")
        negative_prompt = inference_params.get("negative_prompt", cfg.get("negative_prompt", ""))
        steps = inference_params.get("steps", cfg.get("inference_steps", 30))
        guidance_scale = inference_params.get("guidance_scale", cfg.get("guidance_scale", 7))
        width = inference_params.get("width", cfg.get("width", 512))
        height = inference_params.get("height", cfg.get("height", 512))
        images = inference_params.get("images", [])
        likeness_strength = inference_params.get("likeness_strength", 0.4)
        face_strength = inference_params.get("face_strength", 0.1)
        seed = inference_params.get("seed", 0)

        if p is None:
            # Fail early with a clear message instead of a TypeError in re.sub.
            raise Exception("prompt not provided")
        if not images:
            raise Exception("face images not provided")
        start = time.time()
        ip_model.pipe.to("cuda")
        if not self.app:
            self.app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider'])
            self.app.prepare(ctx_id=0, det_size=(512, 512))
        print("extracting face...")
        faceid_all_embeds = []
        for image in images:
            faces = self.app.get(image)
            if not faces:
                # Previously an IndexError; skip images without a detectable face.
                print("no face detected in one of the provided images, skipping it")
                continue
            # Only the first detected face of each image is used.
            faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
            faceid_all_embeds.append(faceid_embed)

        if not faceid_all_embeds:
            raise Exception("no face detected in the provided images")

        average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)

        print("start inference...")
        # A falsy seed (0 or missing) is replaced by a random one.
        seed = seed or int(randomize_seed_fn(seed, True))
        p = remove_child_related_content(p)
        prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", p)

        print(f"generate: p={p}, np={negative_prompt}, steps={steps}, guidance_scale={guidance_scale}, size={width},{height}, seed={seed}")
        print(f"device: embedding={average_embedding.device}, ip_model={ip_model.pipe.device}, pipe={model.base_model_pipeline.device}")
        # NOTE(review): `guidance_scale` receives `face_strength`, while the
        # `guidance_scale` and `seed` values logged above are not passed to
        # generate(). Confirm this is intended before changing it.
        images = ip_model.generate(
            prompt=prompt_str,
            negative_prompt=negative_prompt,
            faceid_embeds=average_embedding,
            scale=likeness_strength,
            width=width,
            height=height,
            guidance_scale=face_strength,
            num_inference_steps=steps,
            num_images_per_prompt=1,
        ).images
        cost = round(time.time() - start, 2)
        print(f"inference done in {cost}s")
        images = [save_image(img) for img in images]
        image_paths = [i[1] for i in images]
        print(prompt_str, image_paths)
        return [i[0] for i in images]

    @spaces.GPU(duration=40)
    def generate(self, model_id, inference_params, progress=gr.Progress(track_tqdm=True)):
        """Generate images from a text prompt with the selected model.

        :param model_id: Key of a loaded model.
        :param inference_params: Mapping with ``prompt`` (required) plus
            optional ``negative_prompt``, ``steps``, ``guidance_scale``,
            ``width``, ``height``, ``sampler``, ``loras`` and ``seed``;
            missing values fall back to the model config or built-in defaults.
        :param progress: Gradio progress tracker.
        :return: List of generated images as re-read from their saved WEBP files.
        :raises Exception: If the model is unknown or the prompt is missing.
        """
        def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
            # Switches guidance off at a configured fraction of the steps by
            # keeping only the last chunk of each embedding tensor.
            # NOTE(review): this callback is not passed to the pipe() call
            # below, so it currently has no effect.
            cfg_disabling_at = cfg.get('cfg_disabling_rate', 0.75)
            if step_index == int(pipe.num_timesteps * cfg_disabling_at):
                callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
                callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
                callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
                pipe._guidance_scale = 0.0

            return callback_kwargs

        model = self.models.get(model_id)
        if not model:
            raise Exception(f"invalid model_id {model_id}")

        cfg = model.cfg
        p = inference_params.get("prompt")
        negative_prompt = inference_params.get("negative_prompt", cfg.get("negative_prompt", ""))
        steps = inference_params.get("steps", cfg.get("inference_steps", 30))
        guidance_scale = inference_params.get("guidance_scale", cfg.get("guidance_scale", 7))
        width = inference_params.get("width", cfg.get("width", 512))
        height = inference_params.get("height", cfg.get("height", 512))
        sampler = inference_params.get("sampler", cfg.get("sampler", ""))
        lora_list = inference_params.get("loras", [])
        seed = inference_params.get("seed", 0)

        if p is None:
            # Fail before any GPU work instead of a TypeError in re.sub.
            raise Exception("prompt not provided")

        pipe = model.build_pipeline_with_lora(lora_list, sampler)

        start = time.time()
        pipe.to("cuda")
        print("start inference...")
        # A falsy seed (0 or missing) is replaced by a random one.
        seed = seed or int(randomize_seed_fn(seed, True))
        # A falsy guidance scale (0 or None) falls back to the config value.
        guidance_scale = guidance_scale or cfg.get("guidance_scale", 7.5)
        p = remove_child_related_content(p)
        prompt_str = cfg.get("prompt", "{prompt}").replace("{prompt}", p)
        generator = torch.Generator(pipe.device).manual_seed(seed)
        print(f"generate: p={p}, np={negative_prompt}, steps={steps}, guidance_scale={guidance_scale}, size={width},{height}, seed={seed}")
        images = pipe(
            prompt=prompt_str,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
            num_images_per_prompt=1,
            output_type="pil",
        ).images
        cost = round(time.time() - start, 2)
        print(f"inference done in {cost}s")
        images = [save_image(img) for img in images]
        image_paths = [i[1] for i in images]
        print(prompt_str, image_paths)
        return [i[0] for i in images]
|
|
|
|
|
def download_from_hf(filename, local_dir=None, repo_id=DATASET_ID, repo_type="dataset"):
    """Download one file from a Hugging Face repository and return its local path.

    :param filename: Path of the file inside the repository.
    :param local_dir: Optional directory to place the file in; passed through
        to ``hf_hub_download``.
    :param repo_id: Repository to download from (defaults to ``DATASET_ID``).
    :param repo_type: Repository type (defaults to ``"dataset"``).
    :return: Local file path, or ``None`` if the download failed (the error is
        printed, not raised).
    """
    try:
        # Honour the repo_id / repo_type arguments; they were previously
        # ignored in favour of the hard-coded defaults.
        return hf_hub_download(
            filename=filename,
            repo_id=repo_id,
            repo_type=repo_type,
            revision="main",
            local_dir=local_dir,
            local_files_only=False,
        )
    except Exception as e:
        print(f"Failed to load {filename} from Hugging Face: {str(e)}")
        return None
|
|
|
|
|
|
|
def _loaded_adapter_names(pipe):
    """Return the adapter names ``pipe`` reports as loaded (empty set if it cannot report them)."""
    list_adapters = getattr(pipe, "get_list_adapters", None)
    if list_adapters is None:
        return set()
    return {name for names in list_adapters().values() for name in names}


def set_lora_weights(pipe, lorajson: list[dict], fuse=False):
    """Activate the given LoRA adapters on ``pipe``, loading missing ones first.

    :param pipe: Pipeline exposing ``load_lora_weights``, ``set_adapters`` and
        ``fuse_lora``.
    :param lorajson: List of dicts with ``name`` and ``scale``; entries
        without ``preloaded=True`` must also carry ``filename`` (the path
        passed to ``download_from_hf``). Falsy entries, non-dicts and entries
        named ``"None"`` or without a name are ignored.
    :param fuse: When true, fuse the activated adapters into the model weights.
    :return: None. Nothing is done when no usable entry is given.
    :raises Exception: Any failure is logged and re-raised wrapped as
        ``"External LoRA Error: ..."``.
    """
    try:
        if not lorajson or not isinstance(lorajson, list):
            return

        a_list = []
        w_list = []
        already_loaded = None  # looked up lazily, only if a non-preloaded entry appears
        for d in lorajson:
            if not d or not isinstance(d, dict) or not d.get("name") or d["name"] == "None":
                continue

            k = d["name"]
            if not d.get("preloaded", False):
                if already_loaded is None:
                    already_loaded = _loaded_adapter_names(pipe)
                # Loading the same adapter name twice on a shared pipeline is
                # rejected by the loader, so only load what is not there yet.
                if k not in already_loaded:
                    start = time.time()
                    weight_path = download_from_hf(d['filename'], local_dir=None)
                    if not weight_path:
                        # Do not queue an adapter that was never loaded;
                        # set_adapters would fail later with an obscure error.
                        raise ValueError(f"Failed to download LoRA weights for {k}")
                    pipe.load_lora_weights(weight_path, weight_name=d['filename'], adapter_name=k)
                    already_loaded.add(k)

                    load_time = round(time.time() - start, 2)
                    print(f"LoRA {k} loaded in {load_time}s.")

            a_list.append(k)
            w_list.append(d["scale"])

        if not a_list:
            return

        start = time.time()
        pipe.set_adapters(a_list, adapter_weights=w_list)
        if fuse:
            pipe.fuse_lora(adapter_names=a_list, lora_scale=1.0)
            fuse_time = round(time.time() - start, 2)
            print(f"LoRAs fused in {fuse_time}s.")
    except Exception as e:
        print(f"External LoRA Error: {e}")
        raise Exception(f"External LoRA Error: {e}") from e
|
|