diffusezx / app.py
sl
Steps parametrized.
7241f9b
import gradio as gr
from diffusers import DiffusionPipeline
import cv2
import torch
import numpy as np
from PIL import Image
import os
# Hugging Face Hub id of the diffusion checkpoint loaded below
# (the name suggests a ZX-Spectrum-style model; see the model card to confirm).
model = "shadowlamer/sd-zxspectrum-model-256"

# Output size in pixels: 256x192 is the ZX Spectrum screen, i.e. 32x24 cells
# of 8x8 pixels, which is what the .scr encoder in generate() walks over.
image_width, image_height = 256, 192

# Directory where generate() writes (and later reuses) the .scr files.
samples_dir = "/tmp"

# Pipeline is loaded once at import time. The safety checker is disabled;
# requires_safety_checker=False presumably silences the library's warning
# about running without one.
pipe = DiffusionPipeline.from_pretrained(
    model,
    safety_checker=None,
    requires_safety_checker=False,
)
# Borrowed from here: https://stackoverflow.com/a/73667318
def quantize_to_palette(_image, _palette):
    """Snap every pixel of *_image* to its nearest colour in *_palette*.

    A 1-nearest-neighbour classifier is trained with one sample per palette
    row (labelled by the row index). Each pixel is then classified, and the
    winning palette colour is substituted.

    Args:
        _image: numpy array whose last axis has 3 channels, in BGR order
            (the result is converted BGR -> RGB before being returned).
        _palette: (K, 3) numpy array of colours in the same channel order.
            Values are assumed to lie in 0..255, since they are cast to uint8.

    Returns:
        PIL.Image in RGB order with the same shape as *_image*.
    """
    pixels = _image.reshape(-1, 3).astype(np.float32)
    references = _palette.astype(np.float32)
    knn = cv2.ml.KNearest_create()
    knn.train(references, cv2.ml.ROW_SAMPLE, np.arange(len(_palette)))
    # neighbours is an (N, 1) float array of palette row indices.
    _, _, neighbours, _ = knn.findNearest(pixels, 1)
    # One vectorised gather instead of a Python loop over every pixel.
    quantized = _palette[neighbours.astype(int).ravel()].reshape(_image.shape)
    return Image.fromarray(cv2.cvtColor(quantized.astype(np.uint8), cv2.COLOR_BGR2RGB))
def collect_char_colors(image, _x, _y):
    """List the distinct colours of one 8x8 character cell, most common first.

    Args:
        image: object with a ``getpixel((x, y))`` method returning a hashable
            colour tuple (a PIL image in this file).
        _x, _y: pixel coordinates of the cell's top-left corner.

    Returns:
        List of colours, each converted to a list. They are ordered by pixel
        count, descending. Colours with equal counts keep the order in which
        they were first seen (rows top to bottom, pixels left to right).
    """
    counts = {}
    cell_pixels = (
        image.getpixel((_x + dx, _y + dy))
        for dy in range(8)
        for dx in range(8)
    )
    for colour in cell_pixels:
        counts[colour] = counts.get(colour, 0) + 1
    # sorted() is stable even with reverse=True, so ties keep first-seen order.
    ranked = sorted(counts, key=counts.__getitem__, reverse=True)
    return [list(colour) for colour in ranked]
def palette_to_attr(_palette):
    """Pack a cell's colours into a ZX Spectrum attribute byte.

    Colour roles:
        - Entry 0 of *_palette* becomes PAPER.
        - Entry 1, if present, becomes INK.
        - Further entries are ignored.

    A colour channel counts as "on" when it is > 0. BRIGHT (0x40) is always
    set for a non-empty palette; an empty palette yields 0x00.

    Args:
        _palette: sequence of [R, G, B] colours, most important first.

    Returns:
        The attribute byte as an int.
    """
    if len(_palette) == 0:
        return 0x00
    # Per-field masks indexed by channel (R, G, B).
    field_masks = (
        (0x10, 0x20, 0x08),  # PAPER field
        (0x02, 0x04, 0x01),  # INK field
    )
    attr = 0x40  # BRIGHT
    for colour, masks in zip(_palette[:2], field_masks):
        for channel, mask in enumerate(masks):
            if colour[channel] > 0:
                attr |= mask
    return attr
def _scr_path(prompt, seed, steps):
    """Return the cache path of the .scr file for one (prompt, seed, steps) request.

    Handling of the untrusted prompt (it comes straight from the web UI):
        - Every character except ASCII letters, digits, '_' and '-' is
          replaced with '_', which removes path separators and dots.
        - The result is truncated to 64 characters.
        - A short hash of the full prompt is appended.

    This keeps the file inside samples_dir, stays under filename length
    limits, and stops different prompts that slugify identically from sharing
    a cache entry. *steps* is part of the key because it changes the image.
    """
    import hashlib
    import re

    slug = re.sub(r"[^\w-]", "_", prompt, flags=re.ASCII)[:64]
    digest = hashlib.sha256(prompt.encode("utf-8", "surrogatepass")).hexdigest()[:12]
    return os.path.join(samples_dir, f"{slug}_{seed}_{steps}_{digest}.scr")


def _encode_scr(image):
    """Encode a palette-quantised image as ZX Spectrum screen memory (.scr bytes).

    Output layout:
        - A 0x1800-byte bitmap in the Spectrum's interleaved line order.
        - Then 0x300 attribute bytes, one per 8x8 cell.

    Within a cell, the most frequent colour is PAPER. Every other pixel is
    drawn as INK, so cells with more than two colours lose detail.
    Assumes *image* is at least image_width x image_height pixels and that
    getpixel() returns colour tuples — TODO confirm if callers change.
    """
    bitmap = bytearray(0x1800)
    attrs = bytearray([0b00111000] * 0x300)
    for y in range(0, image_height, 8):
        for x in range(0, image_width, 8):
            px, py = x // 8, y // 8
            cell_palette = collect_char_colors(image, x, y)
            # Address of the cell's top pixel line: third of the screen
            # (0x800 bytes each) plus character row within that third
            # (32 bytes each) plus column. Each next pixel line is 0x100 further on.
            byte_index = (py // 8) * 0x800 + (py % 8) * 32 + px
            for cy in range(8):
                byte = 0
                for cx in range(8):
                    byte <<= 1
                    if list(image.getpixel((x + cx, y + cy))) != cell_palette[0]:
                        byte |= 1
                bitmap[byte_index] = byte
                byte_index += 0x100
            attrs[py * 32 + px] = palette_to_attr(cell_palette)
    return bytes(bitmap + attrs)


def generate(prompt, seed, steps):
    """Generate a ZX-Spectrum-style picture for *prompt* and export it as .scr.

    Args:
        prompt: text prompt passed to the diffusion pipeline.
        seed: RNG seed; converted with int() because the UI may pass a float.
        steps: number of inference steps; converted with int().

    Returns:
        [image, path]: the quantised PIL image and the path of its .scr file
        in samples_dir. An existing file for the same prompt, seed and steps
        is reused rather than rewritten.
    """
    seed = int(seed)
    steps = int(steps)
    generator = torch.Generator("cpu").manual_seed(seed)
    raw_image = pipe(prompt, height=image_height, width=image_width,
                     num_inference_steps=steps, generator=generator).images[0]
    # All eight 0/255 channel combinations. The set is symmetric under channel
    # reversal, so it is equally valid for the BGR image below.
    spectrum_palette = np.array(
        [[0, 0, 0], [0, 0, 255], [0, 255, 0], [0, 255, 255], [255, 0, 0], [255, 0, 255], [255, 255, 0],
         [255, 255, 255]])
    # Pipeline output is RGB; quantize_to_palette expects BGR and returns RGB.
    bgr_image = np.array(raw_image)[:, :, ::-1].copy()
    image = quantize_to_palette(_image=bgr_image, _palette=spectrum_palette)
    out = _scr_path(prompt, seed, steps)
    if not os.path.exists(out):
        # Encode fully before opening the file, so an encoding failure cannot
        # leave behind a truncated file that the exists() check would keep serving.
        scr_bytes = _encode_scr(image)
        with open(out, "wb") as scr:
            scr.write(scr_bytes)
    return [image, out]
# Web UI: a prompt plus numeric seed and step count, mapped onto
# generate(prompt, seed, steps); shows the quantised image and offers the .scr file.
_examples = [[subject, 123, 20] for subject in ("Cute cat", "Solar system", "Disco ball")]
iface = gr.Interface(
    fn=generate,
    title="ZX-Spectrum inspired images generator ",
    inputs=["text", "number", "number"],
    outputs=["image", "file"],
    examples=_examples,
)
iface.launch()