text2tag-llm / genimage.py
John6666's picture
Super-squash branch 'main' using huggingface_hub
6bba96a verified
raw
history blame
No virus
2.24 kB
import spaces
def load_pipeline(model_id="John6666/rae-diffusion-xl-v2-sdxl-spo-dpo-turbo"):
    """Load the diffusion pipeline and move it to the best available device.

    Args:
        model_id: Repository id (or local path) passed to
            ``DiffusionPipeline.from_pretrained``. Defaults to the model this
            module was originally written for, so existing no-argument calls
            behave as before on GPU.

    Returns:
        The loaded pipeline, placed on ``cuda:0`` when CUDA is available and
        on the CPU otherwise.
    """
    from diffusers import DiffusionPipeline
    import torch

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # float16 weights are only requested on the GPU: on the CPU many kernels
    # have no (or only a very slow) half-precision implementation, so the CPU
    # fallback uses float32 instead of failing or crawling.
    dtype = torch.float16 if use_cuda else torch.float32

    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        # NOTE(review): assumed to be the diffusers community long-prompt-
        # weighting SDXL pipeline; generate_image relies on its call
        # signature (prompt_2, negative_prompt_2, clip_skip) — confirm.
        custom_pipeline="lpw_stable_diffusion_xl",
        torch_dtype=dtype,
    )
    pipe.to(device)
    return pipe
def save_image(image, metadata, output_dir):
    """Write *image* to *output_dir* as a PNG carrying *metadata*; return its path.

    The file name is a fresh UUID4 with a ``.png`` suffix, and *output_dir* is
    created first if it does not exist. *metadata* is serialized with
    ``json.dumps`` and embedded as a PNG text chunk under the key
    ``"metadata"``.

    Args:
        image: Object exposing a PIL-style ``save(path, format, pnginfo=...)``.
        metadata: JSON-serializable value to embed in the PNG.
        output_dir: Directory that receives the file.

    Returns:
        The path of the written file (``output_dir`` joined with the name).
    """
    import json
    import os
    import uuid
    from PIL import PngImagePlugin

    target_name = f"{uuid.uuid4()}.png"
    os.makedirs(output_dir, exist_ok=True)
    target_path = os.path.join(output_dir, target_name)

    png_info = PngImagePlugin.PngInfo()
    png_info.add_text("metadata", json.dumps(metadata))

    image.save(target_path, "PNG", pnginfo=png_info)
    return target_path
# Load the pipeline once when the module is imported; generate_image reads
# this module-level ``pipe`` on every call instead of reloading the model.
pipe = load_pipeline()
@spaces.GPU
def generate_image(prompt, neg_prompt):
    """Generate images for *prompt* with the module-level ``pipe`` and save them.

    Fixed style strings are sent as the secondary prompts (``prompt_2`` and
    ``negative_prompt_2``). Every resulting image is written to ``./outputs``
    via ``save_image``, with the generation settings embedded as metadata.

    Args:
        prompt: Primary positive prompt.
        neg_prompt: Primary negative prompt.

    Returns:
        A list of saved PNG file paths. The list is empty when the pipeline
        yields no images, or when generation or saving raises; such failures
        are logged rather than propagated.
    """
    import logging

    # Single source of truth for the generation settings, so the metadata
    # stored in each PNG always matches what is actually passed to ``pipe``.
    width = 1024
    height = 1024
    guidance_scale = 2
    num_inference_steps = 16

    metadata = {
        "prompt": prompt,
        "negative_prompt": neg_prompt,
        "resolution": f"{width} x {height}",
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        # NOTE(review): no scheduler is configured in this file, so this label
        # is not checked against the pipeline's actual sampler — confirm.
        "sampler": "LCM",
    }
    try:
        images = pipe(
            prompt=prompt,
            prompt_2="anime artwork, anime style, vibrant, studio anime, highly detailed, masterpiece, best quality, very aesthetic, absurdres",
            negative_prompt=neg_prompt,
            # NOTE(review): the tail "photo, deformed, disfigured, low contrast"
            # appears twice. It is kept verbatim because repetition may alter
            # prompt weighting and thus outputs — confirm whether intended.
            negative_prompt_2="lowres, (bad), text, error, fewer, extra, missing, worst quality, jpeg artifacts, low quality, watermark, unfinished, displeasing, oldest, early, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], photo, deformed, disfigured, low contrast, photo, deformed, disfigured, low contrast",
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            output_type="pil",
            clip_skip=1,
        ).images
        # Always return a list, so the caller gets [] rather than None when
        # the pipeline produced nothing.
        return [save_image(image, metadata, "./outputs") for image in images or []]
    except Exception:
        # Best-effort by design: the caller still receives [] on failure, but
        # the traceback is recorded instead of being silently discarded.
        logging.getLogger(__name__).exception("Image generation failed")
        return []