Spaces:
Runtime error
Runtime error
| import os | |
| import io | |
| import re | |
| import time | |
| import random | |
| import torch | |
| from typing import Dict, Final, List, Optional, Tuple, cast | |
| from PIL import Image, ImageDraw, ImageEnhance | |
| from PIL.Image import Image as PILImage | |
| from diffusers import StableDiffusionPipeline | |
# Sprite-sheet fine-tuned Stable Diffusion checkpoint, loaded once at import time.
model_id: Final = "Onodofthenorth/SD_PixelArt_SpriteSheet_Generator"

# Prefer the GPU in half precision. Fall back to float32 on CPU (float16 is
# generally not practical there) so the module can still load without CUDA.
device: Final = "cuda" if torch.cuda.is_available() else "cpu"
dtype: Final = torch.float16 if device == "cuda" else torch.float32

# Weights are cached under ./cache to avoid re-downloading on restart.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=dtype, cache_dir="cache"
)
pipe = pipe.to(device)
# Sprite facing -> prompt token that selects that facing (presumably the
# trigger words the checkpoint was fine-tuned with — confirm on the model card).
# Insertion order is the order in which generate_sides walks the sides.
sprite_sides: Final = dict(
    front="PixelArtFSS",
    right="PixelArtRSS",
    back="PixelArtBSS",
    left="PixelArtLSS",
)
def torchGenerator(
    seed: Optional[int], max: int = 1024, device: Optional[str] = None
) -> Tuple[torch.Generator, int]:
    """
    Create a seeded torch random generator.

    Args:
        seed: Seed to use. When None, a random seed in [0, max) is drawn.
            0 is a valid seed and is used as-is.
        max: Exclusive upper bound for a randomly drawn seed. (The name
            shadows the builtin but is kept for backward compatibility.)
        device: Device for the generator. Defaults to "cuda" when CUDA is
            available, otherwise "cpu".

    Returns:
        (generator, seed): the seeded generator and the seed actually used.
    """
    if seed is None:
        seed = random.randrange(0, max)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.Generator(device).manual_seed(seed), seed
def generate(
    prompt: str,
    sfw_retries: int = 1,
    seed: Optional[int] = None,
) -> PILImage:
    """
    Generate a 512x512 sprite image from a text description.

    The pipeline runs up to `sfw_retries` times, and always at least once.
    While the pipeline flags the result as NSFW and attempts remain, a new
    seed, different from the one just used, is drawn and generation is
    retried.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        sfw_retries: Maximum number of generation attempts. Values below 1
            are treated as 1.
        seed: Seed for the first attempt. A random one is drawn when None.

    Returns:
        The first image not flagged as NSFW, otherwise the last image
        generated. NOTE(review): the pipeline's safety checker presumably
        blanks flagged images — confirm for the configured pipeline.
    """
    attempts = max(1, sfw_retries)
    generator, current_seed = torchGenerator(seed)
    image: PILImage | None = None
    for attempt in range(attempts):
        pipe_output = pipe(prompt, generator=generator, width=512, height=512)
        image = pipe_output.images[0]
        # nsfw_content_detected can be None when no safety checker is configured.
        flags = pipe_output.nsfw_content_detected
        if not flags or not flags[0]:
            break
        if attempt == attempts - 1:
            # Out of attempts: keep the last image instead of announcing a
            # retry that will never happen.
            break
        print(f"Regenerating `{prompt}` with different seed.")
        new_seed = current_seed
        while new_seed == current_seed:
            new_seed = random.randrange(0, 1024)
        generator, current_seed = torchGenerator(new_seed)
    return cast(PILImage, image)
def generate_sides(
    prompt: str, sfw_retries: int = 1, sides: Dict[str, str] = sprite_sides
) -> Tuple[Dict[str, PILImage], str]:
    """
    Generate one sprite image per requested side from a text description.

    Every side is generated with the same random seed. When both "left" and
    "right" are requested, only "left" is generated; "right" is its horizontal
    mirror and is therefore inserted last in the returned dict.

    Args:
        prompt: Text description of the character.
        sfw_retries: Forwarded to generate() for each side.
        sides: Side name -> prompt token appended for that side.

    Returns:
        (sprites, prompt): images keyed by side, and the unmodified prompt.
    """
    print(f"Generating sprites for `{prompt}`")
    shared_seed = random.randrange(0, 1024)
    mirror_right = "left" in sides and "right" in sides
    # NOTE(review): the plain diffusers pipeline does not interpret
    # A1111-style "[...]" de-emphasis, so "nsfw"/"photograph" likely reach the
    # prompt verbatim; a negative_prompt may have been intended — confirm.
    to_generate = [
        (name, token)
        for name, token in sides.items()
        if not (mirror_right and name == "right")
    ]
    result: Dict[str, PILImage] = {
        name: generate(f"({prompt}) [nsfw] [photograph] {token}", sfw_retries, shared_seed)
        for name, token in to_generate
    }
    if mirror_right:
        result["right"] = result["left"].transpose(Image.Transpose.FLIP_LEFT_RIGHT)
    return result, prompt
def clean_sprite(
    image: PILImage,
    size: Tuple[int, int] = (192, 192),
    sharpness: float = 1.5,
    thresh: int = 128,
    rescaling: Optional[float] = None,
) -> PILImage:
    """
    Process an image to look more sprite-like.

    Steps:
    1. Sharpen the image.
    2. Convert it to RGBA.
    3. Flood-fill the region connected to the top-left pixel with transparent
       white (background removal).
    4. Optionally pixelate.
    5. Resize to `size` with nearest-neighbour sampling.

    Args:
        image: Source image.
        size: Output (width, height).
        sharpness: Factor for ImageEnhance.Sharpness (1.0 leaves it unchanged).
        thresh: Colour tolerance for the background flood fill.
        rescaling: When a positive number, the image is first scaled down by
            this factor (each side clamped to at least 1 px) and then scaled up
            to `size`, coarsening the pixel grid. None, booleans and
            non-positive values skip this step.

    Returns:
        A new RGBA image of `size`.
    """
    width, height = image.size
    image = ImageEnhance.Sharpness(image).enhance(sharpness)
    image = image.convert("RGBA")
    ImageDraw.floodfill(image, (0, 0), (255, 255, 255, 0), thresh=thresh)
    # bool is an int subclass; exclude it so rescaling=True is not read as 1.
    if (
        isinstance(rescaling, (int, float))
        and not isinstance(rescaling, bool)
        and rescaling > 0
    ):
        # Clamp to 1 px so a large factor cannot request an empty image.
        downscaled = (max(1, int(width / rescaling)), max(1, int(height / rescaling)))
        image = image.resize(downscaled, resample=Image.Resampling.NEAREST)
    return image.resize(size, resample=Image.Resampling.NEAREST)
def split_sprites(image: PILImage, size: Tuple[int, int] = (96, 96)) -> List[PILImage]:
    """
    Split a sprite image into its four animation frames.

    The source is treated as a horizontal strip of four equal-width frames
    inside the middle half of the image vertically (25% to 75% of its height).
    Each frame is pasted, centred, onto its own transparent RGBA canvas of
    `size`. When the source is exactly twice `size` in each dimension, a frame
    fills the full canvas height and the middle half of its width.

    Args:
        image: Source sprite image.
        size: (width, height) of each output frame canvas.

    Returns:
        Four RGBA images of `size`, ordered left to right.
    """
    width, height = image.size
    w, h = size
    frame_w = int(width / 4)
    # The band is derived from the source height, not from the target size.
    top, bottom = int(height / 4), int(height * 0.75)
    # The last frame extends to the right edge and absorbs any remainder.
    lefts = [0, frame_w, frame_w * 2, frame_w * 3]
    rights = [frame_w, frame_w * 2, frame_w * 3, width]
    blank = Image.new("RGBA", size, (255, 255, 255, 0))
    frames: List[PILImage] = []
    for left, right in zip(lefts, rights):
        frame = image.crop((left, top, right, bottom))
        fw, fh = frame.size
        canvas = blank.copy()
        # Paste by upper-left corner so an off-by-one frame size cannot fail.
        canvas.paste(frame, ((w - fw) // 2, (h - fh) // 2))
        frames.append(canvas)
    return frames
def build_spritesheet(
    images: Dict[str, PILImage],
    text: str = "sd_pixelart",
    sprite_size: Tuple[int, int] = (96, 96),
    dir: str = "output",
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[PILImage, str | None]:
    """
    Build a sprite sheet from per-side images.

    1. Clean and scale each side image to twice `sprite_size`.
    2. Split each cleaned image into individual frames of `sprite_size`.
    3. Create a transparent canvas with one column per frame and one row per
       side. Rows follow the iteration order of `images`.
    4. Paste every frame onto the canvas.

    Args:
        images: Side name -> source image.
        text: Label for the saved filename. Characters other than word
            characters, parentheses, square brackets, "_" and "-" are removed.
        sprite_size: (width, height) of one frame cell.
        dir: Output directory, created if missing when `save` is True.
        save: Also write the sheet to ``{dir}/{timestamp}_{text}.png``.
        timestamp: Filename prefix. Defaults to the current Unix time.
        thresh: Background flood-fill tolerance passed to clean_sprite.

    Returns:
        (sheet, filepath): the sheet decoded from an in-memory PNG, and the
        saved path, or None when `save` is False.

    Raises:
        ValueError: If `images` is empty.
    """
    if not images:
        raise ValueError("images must contain at least one side")
    width, height = sprite_size
    text = re.sub(r"[^\w()[\]_-]", "", text)
    frames: Dict[str, List[PILImage]] = {}
    for side, image in images.items():
        cleaned = clean_sprite(image, (width * 2, height * 2), thresh=thresh)
        frames[side] = split_sprites(cleaned, sprite_size)
    # Size the grid from the sides themselves so a "front" side is not required.
    columns = max(len(side_frames) for side_frames in frames.values())
    canvas = Image.new(
        "RGBA",
        (width * columns, height * len(frames)),
        (255, 255, 255, 0),
    )
    for row, side_frames in enumerate(frames.values()):
        for col, frame in enumerate(side_frames):
            canvas.paste(
                frame,
                (col * width, row * height, (col + 1) * width, (row + 1) * height),
            )
    spritesheet = io.BytesIO()
    canvas.save(spritesheet, "PNG")
    filepath = None
    if save:
        if timestamp is None:
            timestamp = int(time.time())
        if dir:
            os.makedirs(dir, exist_ok=True)
        filepath = os.path.join(dir, f"{timestamp}_{text}.png")
        canvas.save(filepath)
    return Image.open(spritesheet), filepath
def build_gifs(
    images: Dict[str, PILImage],
    text: str = "sd_spritesheet",
    dir: str = "output",
    duration: int | List[int] | Tuple[int, ...] = (300, 450, 300, 450),
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[Dict[str, PILImage], List[str] | None]:
    """
    Build one infinitely looping animated GIF per side.

    Each side image is cleaned with clean_sprite at its default output size,
    split into frames with split_sprites at its default frame size, and
    encoded as a GIF in memory.

    Args:
        images: Side name -> source image.
        text: Label for saved filenames. Characters other than word
            characters, parentheses, square brackets, "_" and "-" are removed.
        dir: Output directory, created if missing when `save` is True.
        duration: Frame display time in milliseconds. Either a single value
            or one value per frame.
        save: Also write each GIF to ``{dir}/{timestamp}_{text}_{side}.gif``.
        timestamp: Filename prefix. Defaults to the current Unix time.
        thresh: Background flood-fill tolerance passed to clean_sprite.

    Returns:
        (gifs, filepaths): animated GIF images keyed by side, and the saved
        paths in side order, or None when `save` is False.
    """
    text = re.sub(r"[^\w()[\]_-]", "", text)
    gifs: Dict[str, PILImage] = {}
    filepaths: List[str] | None = [] if save else None
    if save:
        if timestamp is None:
            timestamp = int(time.time())
        if dir:
            os.makedirs(dir, exist_ok=True)
    for side, image in images.items():
        frames = split_sprites(clean_sprite(image, thresh=thresh))
        gif = io.BytesIO()
        frames[0].save(
            gif,
            format="GIF",
            save_all=True,
            append_images=frames[1:],
            disposal=3,  # restore to previous content between frames
            duration=duration,
            loop=0,  # loop forever
        )
        gifs[side] = Image.open(gif)
        if filepaths is not None:
            filepath = os.path.join(dir, f"{timestamp}_{text}_{side}.gif")
            # Write the already-encoded bytes instead of encoding a second time.
            with open(filepath, "wb") as out:
                out.write(gif.getvalue())
            filepaths.append(filepath)
    return gifs, filepaths