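# Page header captured with this file (author, last commit, size) — not code:
# Author: Ron Au
# Last commit: 69c5e48 "fix: Fix Dict types"
# raw | history | blame · 7.16 kB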
import os
import io
import re
import time
import random
import torch
from typing import Dict, Final, List, Optional, Tuple, cast
from PIL import Image, ImageDraw, ImageEnhance
from PIL.Image import Image as PILImage
from diffusers import StableDiffusionPipeline
model_id: Final = "Onodofthenorth/SD_PixelArt_SpriteSheet_Generator"

# Load the sprite-sheet checkpoint once, at import time, in float16 precision.
# Downloaded weights go to cache_dir="cache" (relative to the working
# directory), and the pipeline is moved to the CUDA device before use.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    cache_dir="cache",
).to("cuda")
# Side name -> label appended to the generation prompt for that side.
# Insertion order (front, right, back, left) is the default side order.
sprite_sides: Final = dict(
    front="PixelArtFSS",
    right="PixelArtRSS",
    back="PixelArtBSS",
    left="PixelArtLSS",
)
def torchGenerator(
    seed: Optional[int], max: int = 1024, device: str = "cuda"
) -> Tuple[torch.Generator, int]:
    """
    Create a seeded ``torch.Generator``.

    Args:
        seed: Seed to use. When None, a random seed in ``[0, max)`` is drawn.
            ``0`` is a valid seed and is used as-is.
        max: Exclusive upper bound for a randomly drawn seed. (The name
            shadows the builtin but is kept for API compatibility.)
        device: Device the generator is created on (defaults to "cuda").

    Returns:
        The seeded generator and the seed that was actually used.
    """
    # Compare against None explicitly: `seed or ...` would discard seed 0.
    if seed is None:
        seed = random.randrange(0, max)
    return torch.Generator(device).manual_seed(seed), seed
def generate(
    prompt: str,
    sfw_retries: int = 1,
    seed: Optional[int] = None,
) -> PILImage:
    """
    Generate a sprite image from a text description.

    The pipeline is run up to ``sfw_retries`` times (always at least once).
    Whenever the safety checker flags an output, a new random seed that
    differs from the one just used is drawn and generation is retried. If
    every attempt is flagged, the last flagged image is returned as-is.
    The original docstring describes this as a blank image; presumably the
    safety checker blanks flagged outputs (confirm for the diffusers version
    in use).

    Args:
        prompt: Text prompt passed to the pipeline.
        sfw_retries: Maximum number of generation attempts; values below 1
            are treated as 1.
        seed: Seed for the first attempt; a random seed is used when None.

    Returns:
        The generated 512x512 image.
    """
    # Always make at least one attempt so an image (never None) is returned.
    attempts = max(1, sfw_retries)
    generator, current_seed = torchGenerator(seed)
    image: PILImage | None = None
    for attempt in range(attempts):
        pipe_output = pipe(prompt, generator=generator, width=512, height=512)
        image = pipe_output.images[0]
        flags = pipe_output.nsfw_content_detected
        # `flags` can be None when the pipeline ran without a safety checker.
        if not flags or not flags[0]:
            break
        if attempt == attempts - 1:
            # Out of attempts: keep the flagged image rather than reseeding.
            break
        print(f"Regenerating `{prompt}` with different seed.")
        # Draw a seed different from the one actually used for this attempt.
        next_seed = current_seed
        while next_seed == current_seed:
            next_seed = random.randrange(0, 1024)
        generator, current_seed = torchGenerator(next_seed)
    return cast(PILImage, image)
def generate_sides(
    prompt: str, sfw_retries: int = 1, sides: Dict[str, str] = sprite_sides
) -> Tuple[Dict[str, PILImage], str]:
    """
    Generate one sprite image per side from a single text description.

    All sides share one randomly drawn seed. When both "left" and "right"
    are requested, only "left" is generated. "right" is then produced by
    mirroring "left" horizontally and is inserted after all generated sides,
    so it comes last in the returned dict's iteration order.

    Args:
        prompt: User description of the sprite.
        sfw_retries: Forwarded to ``generate`` for each side.
        sides: Mapping of side name to the label appended to the prompt.

    Returns:
        A ``(sprites, prompt)`` tuple where ``sprites`` maps side name to image.
    """
    print(f"Generating sprites for `{prompt}`")
    shared_seed = random.randrange(0, 1024)
    mirror_left = "left" in sides and "right" in sides
    # NOTE(review): in a plain diffusers pipeline, square brackets have no
    # de-emphasis meaning (that is a web-UI convention), so "[nsfw]" is likely
    # plain prompt text. Consider moving these terms to a negative prompt; verify.
    sprites = {
        side_name: generate(
            f"({prompt}) [nsfw] [photograph] {side_label}", sfw_retries, shared_seed
        )
        for side_name, side_label in sides.items()
        if not (mirror_left and side_name == "right")
    }
    if mirror_left:
        sprites["right"] = sprites["left"].transpose(Image.Transpose.FLIP_LEFT_RIGHT)
    return sprites, prompt
def clean_sprite(
    image: PILImage,
    size: Tuple[int, int] = (192, 192),
    sharpness: float = 1.5,
    thresh: int = 128,
    rescaling: Optional[float] = None,
) -> PILImage:
    """
    Process an image to be more sprite-like.

    Steps: sharpen, convert to RGBA, flood-fill the region connected to the
    top-left pixel with transparent white, optionally pixelate, then resize
    to ``size`` with nearest-neighbour resampling.

    Args:
        image: Source image.
        size: Final ``(width, height)`` of the returned image.
        sharpness: Factor passed to ``ImageEnhance.Sharpness.enhance``.
        thresh: Flood-fill tolerance passed to ``ImageDraw.floodfill``.
        rescaling: When a positive number, first scale the image down by this
            factor, then back up to ``size``, producing a pixelated look.
            None (or a non-positive value) skips this step.

    Returns:
        A new RGBA image of size ``size``.
    """
    width, height = image.size
    image = ImageEnhance.Sharpness(image).enhance(sharpness)
    image = image.convert("RGBA")
    # Treat the area connected to the top-left corner as background and clear it.
    ImageDraw.floodfill(image, (0, 0), (255, 255, 255, 0), thresh=thresh)
    if rescaling is not None and rescaling > 0:
        # Clamp to at least 1px: Pillow rejects zero-sized resize targets.
        downscaled = (
            max(1, int(width / rescaling)),
            max(1, int(height / rescaling)),
        )
        image = image.resize(downscaled, resample=Image.Resampling.NEAREST)
    return image.resize(size, resample=Image.Resampling.NEAREST)
def split_sprites(image: PILImage, size: Tuple[int, int] = (96, 96)) -> List[PILImage]:
    """
    Split a sprite image into its four individual frames.

    The frames are read from the horizontal band between 1/4 and 3/4 of the
    image height, as four columns of width ``image.width // 4`` each; the last
    column also takes any remainder pixels. Each frame is pasted at
    ``x = size[0] / 4, y = 0`` onto its own transparent RGBA canvas of ``size``.
    The frame is centred when the image is twice ``size``, which is how
    ``build_spritesheet`` and ``build_gifs`` call this.

    Args:
        image: Cleaned sprite image containing four frames side by side.
        size: ``(width, height)`` of each returned frame canvas.

    Returns:
        Four RGBA images of ``size``, in left-to-right order.
    """
    width, height = image.size
    frame_width = size[0]
    quarter = int(width / 4)
    top, bottom = int(height / 4), int(height * 0.75)
    # Column edges; the final edge is the full width so no pixels are dropped.
    edges = [0, quarter, quarter * 2, quarter * 3, width]
    blank = Image.new("RGBA", size, (255, 255, 255, 0))
    frames = []
    for left, right in zip(edges, edges[1:]):
        frame = image.crop((left, top, right, bottom))
        canvas = blank.copy()
        # Paste at a corner instead of a 4-tuple box. A box would require the
        # crop size to match exactly, and that raises for sizes not divisible by 4.
        canvas.paste(frame, (int(frame_width / 4), 0))
        frames.append(canvas)
    return frames
def build_spritesheet(
    images: Dict[str, PILImage],
    text: str = "sd_pixelart",
    sprite_size: Tuple[int, int] = (96, 96),
    dir: str = "output",
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[PILImage, str | None]:
    """
    Build a sprite sheet from per-side images.

    1. Clean and scale each image to twice ``sprite_size``.
    2. Split each image into individual frames.
    3. Create a canvas with one row per side and one column per frame.
    4. Paste each individual frame onto the canvas.

    Args:
        images: Mapping of side name to image; rows follow its iteration order.
        text: Label used in the output filename (unsafe characters removed).
        sprite_size: ``(width, height)`` of a single frame cell.
        dir: Output directory, created if missing when ``save`` is True.
        save: Whether to also write the sheet to ``dir`` as a PNG.
        timestamp: Filename prefix; the current Unix time when None.
        thresh: Flood-fill tolerance forwarded to ``clean_sprite``.

    Returns:
        The sprite sheet re-opened from an in-memory PNG, and the saved file
        path (None when ``save`` is False).

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("build_spritesheet requires at least one side image")
    width, height = sprite_size
    # Keep only filename-safe characters.
    text = re.sub(r"[^\w()[\]_-]", "", text)
    frames = {
        side: split_sprites(
            clean_sprite(image, (width * 2, height * 2), thresh=thresh), sprite_size
        )
        for side, image in images.items()
    }
    # Take the column count from the first side instead of a hard-coded
    # "front" key, so custom side sets work too.
    n_frames = len(next(iter(frames.values())))
    canvas = Image.new(
        "RGBA",
        (width * n_frames, height * len(frames)),
        (255, 255, 255, 0),
    )
    for row, side_frames in enumerate(frames.values()):
        for col, frame in enumerate(side_frames[:n_frames]):
            canvas.paste(frame, (col * width, row * height))

    spritesheet = io.BytesIO()
    canvas.save(spritesheet, "PNG")
    filepath = None
    if save:
        if timestamp is None:
            timestamp = int(time.time())
        os.makedirs(dir, exist_ok=True)
        filepath = os.path.join(dir, f"{timestamp}_{text}.png")
        canvas.save(filepath)
    return Image.open(spritesheet), filepath
def build_gifs(
    images: Dict[str, PILImage],
    text: str = "sd_spritesheet",
    dir: str = "output",
    duration: int | List[int] | Tuple[int, ...] = (300, 450, 300, 450),
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[Dict[str, PILImage], List[str] | None]:
    """
    Build one looping animated GIF per side from that side's frames.

    Each image is cleaned with ``clean_sprite`` defaults and split with
    ``split_sprites`` defaults. The frames are then encoded as an infinitely
    looping GIF (``loop=0``, ``disposal=3``).

    Args:
        images: Mapping of side name to source image.
        text: Label used in output filenames (unsafe characters removed).
        dir: Output directory, created if missing when ``save`` is True.
        duration: Per-frame display time(s), forwarded to the GIF encoder.
        save: Whether to also write each GIF to ``dir``.
        timestamp: Filename prefix; the current Unix time when None.
        thresh: Flood-fill tolerance forwarded to ``clean_sprite``.

    Returns:
        A mapping of side name to the animated GIF re-opened from memory, and
        the list of written file paths (None when ``save`` is False).
    """
    text = re.sub(r"[^\w()[\]_-]", "", text)
    filepaths: List[str] | None = [] if save else None
    if save:
        if timestamp is None:
            timestamp = int(time.time())
        os.makedirs(dir, exist_ok=True)
    gifs = {}
    for side, image in images.items():
        frames = split_sprites(clean_sprite(image, thresh=thresh))
        buffer = io.BytesIO()
        frames[0].save(
            buffer,
            format="GIF",
            save_all=True,
            append_images=frames[1:],
            disposal=3,
            duration=duration,
            loop=0,
        )
        encoded = buffer.getvalue()
        gifs[side] = Image.open(buffer)
        if filepaths is not None:
            filepath = os.path.join(dir, f"{timestamp}_{text}_{side}.gif")
            # Write the already-encoded bytes rather than encoding the animation twice.
            with open(filepath, "wb") as gif_file:
                gif_file.write(encoded)
            filepaths.append(filepath)
    return gifs, filepaths