| """
|
| PixelArtGen — Generate pixel art from text prompts.
|
|
|
| Usage:
|
| python generate.py --prompt "a red pixel art sword" --output output.png
|
| python generate.py --prompt "a blue pixel art heart" --output heart.png --temperature 0.7
|
| python generate.py --batch-prompts prompts.txt --output-dir outputs/
|
| """
|
|
|
| import os
|
| import sys
|
| import json
|
| import argparse
|
| import numpy as np
|
| import torch
|
| from pathlib import Path
|
| from PIL import Image
|
|
|
| sys.path.insert(0, str(Path(__file__).parent))
|
|
|
| from model.tokenizer import PaletteTokenizer
|
| from model.text_encoder import TextTokenizer, TextEncoder
|
| from model.pixel_decoder import PixelLMDecoder, PixelLM
|
|
|
|
|
def load_model(checkpoint_path: str, data_dir: str, device: torch.device):
    """Load a trained PixelLM model and its tokenizers from a checkpoint.

    Args:
        checkpoint_path: Path to a ``.pt`` file containing "model_state_dict"
            and, optionally, the training "args" dict.
        data_dir: Directory holding ``palette_256.npy`` and ``vocab.json``.
        device: Device the weights are mapped onto.

    Returns:
        Tuple ``(model, palette_tok, text_tok)`` with the model in eval mode.
    """
    data_dir = Path(data_dir)

    # NOTE(security): weights_only=False deserializes arbitrary pickled
    # objects — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Hyperparameters recorded at training time; the .get(...) defaults below
    # keep older checkpoints (saved without "args") loadable.
    model_args = checkpoint.get("args", {})

    palette_tok = PaletteTokenizer(palette_path=str(data_dir / "palette_256.npy"))

    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # could mis-decode non-ASCII vocabulary entries.
    with open(data_dir / "vocab.json", encoding="utf-8") as f:
        vocab = json.load(f)
    text_tok = TextTokenizer(vocab)

    d_model = model_args.get("d_model", 256)
    nhead = model_args.get("nhead", 8)
    text_layers = model_args.get("text_layers", 3)
    pixel_layers = model_args.get("pixel_layers", 6)
    dim_ff = model_args.get("dim_ff", 512)
    dropout = model_args.get("dropout", 0.1)
    max_text_len = model_args.get("max_text_len", 32)

    text_encoder = TextEncoder(
        vocab_size=text_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=text_layers,
        dim_feedforward=dim_ff,
        max_seq_len=max_text_len,
        dropout=dropout,
    )

    pixel_decoder = PixelLMDecoder(
        vocab_size=palette_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=pixel_layers,
        dim_feedforward=dim_ff,
        img_size=32,  # model is trained on a fixed 32x32 canvas
        dropout=dropout,
    )

    model = PixelLM(text_encoder, pixel_decoder).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()  # inference only: disables dropout

    return model, palette_tok, text_tok
|
|
|
|
|
def generate_pixel_art(
    model: PixelLM,
    palette_tok: PaletteTokenizer,
    text_tok: TextTokenizer,
    prompt: str,
    device: torch.device,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    scale: int = 8,
) -> Image.Image:
    """
    Generate a 32×32 pixel art image from a text prompt.

    Args:
        model: Trained PixelLM model
        palette_tok: Color palette tokenizer
        text_tok: Text tokenizer
        prompt: Text description
        device: torch device
        temperature: Sampling temperature (lower = more deterministic)
        top_k: Top-k filtering
        top_p: Nucleus sampling threshold
        scale: Upscale factor for display (8 = 256×256 output)
    Returns:
        PIL Image (32*scale × 32*scale)
    """
    # Encode the prompt and add a batch dimension of 1.
    prompt_ids = text_tok.encode(prompt).unsqueeze(0).to(device)

    # Autoregressive sampling; no gradients needed at inference time.
    with torch.no_grad():
        sampled = model.generate(
            prompt_ids,
            sos_token=palette_tok.sos_token,
            eos_token=palette_tok.eos_token,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

    # Map palette tokens back to an RGB array and wrap it as a PIL image.
    pixel_array = palette_tok.decode_tokens(sampled[0].cpu().tolist())
    image = Image.fromarray(pixel_array, "RGB")

    # Nearest-neighbour upscaling preserves the hard pixel edges.
    if scale > 1:
        image = image.resize((32 * scale, 32 * scale), Image.NEAREST)
    return image
|
|
|
|
|
def slugify_prompt(text: str, max_len: int = 30) -> str:
    """Return a filesystem-safe filename stem for *text*.

    Every non-alphanumeric character (not just spaces) becomes an
    underscore, so prompts containing ``/``, ``:`` or quotes cannot
    produce broken paths. Truncated to *max_len* characters.
    """
    return "".join(c if c.isalnum() else "_" for c in text)[:max_len]


def collect_prompts(args) -> list:
    """Resolve the prompt list: batch file > single --prompt > demo defaults."""
    if args.batch_prompts:
        with open(args.batch_prompts, encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]
    if args.prompt:
        return [args.prompt]
    # No prompt given: fall back to a small demo set.
    return [
        "a red pixel art sword",
        "a blue pixel art heart",
        "a green pixel art tree",
        "a purple pixel art gem",
    ]


def main():
    """CLI entry point: parse args, load the model, generate and save images."""
    parser = argparse.ArgumentParser(description="Generate pixel art from text")
    parser.add_argument("--prompt", type=str, help="Text prompt")
    parser.add_argument("--output", type=str, default="output.png", help="Output file")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best.pt")
    # Default data dir can be overridden via env var so the script is usable
    # off the original Windows machine without passing --data-dir every time.
    parser.add_argument(
        "--data-dir",
        type=str,
        default=os.environ.get("PIXELARTGEN_DATA", r"D:\PixelArtGen_Data\processed"),
    )
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=40)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--scale", type=int, default=8, help="Upscale factor")
    parser.add_argument("--num-samples", type=int, default=1, help="Number of images to generate")
    parser.add_argument("--batch-prompts", type=str, help="File with prompts (one per line)")
    parser.add_argument("--output-dir", type=str, default="outputs")

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    print(f"Loading model from {args.checkpoint}...")
    model, palette_tok, text_tok = load_model(args.checkpoint, args.data_dir, device)
    print(f" Model: {model.count_parameters():,} parameters")

    prompts = collect_prompts(args)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, prompt in enumerate(prompts):
        print(f"\nGenerating: \"{prompt}\"")
        for j in range(args.num_samples):
            img = generate_pixel_art(
                model, palette_tok, text_tok, prompt, device,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                scale=args.scale,
            )

            if len(prompts) == 1 and args.num_samples == 1:
                # Single image: honor --output exactly.
                out_path = args.output
            else:
                # Prefix with the prompt index so two prompts whose first 30
                # sanitized chars coincide cannot overwrite each other.
                out_path = output_dir / f"{i:03d}_{slugify_prompt(prompt)}_{j}.png"

            img.save(str(out_path))
            print(f" Saved: {out_path}")

    print("\nDone!")


if __name__ == "__main__":
    main()
|
|
|