|
from PIL import Image |
|
from io import BytesIO |
|
import numpy as np |
|
import base64 |
|
import torch |
|
import torchvision.transforms.functional as F |
|
from S2I import Sketch2ImagePipeline |
|
|
|
|
|
|
|
class Sketch2ImageController:
    """Turn user sketches into images with a lazily built ``Sketch2ImagePipeline``.

    Attributes:
        gr: Object passed to the constructor, stored as-is
            (presumably the ``gradio`` module -- confirm with caller).
        style_list: Ordered list of ``{"name", "prompt"}`` style presets; each
            ``prompt`` is a template containing a ``{prompt}`` placeholder.
        styles: Mapping of style name -> prompt template, built from ``style_list``.
        STYLE_NAMES: Style names in ``style_list`` order.
        DEFAULT_STYLE_NAME: Name of the preset selected by default.
        MAX_SEED: Largest value representable as a signed 32-bit integer.
        pipe: The current pipeline, or ``None`` until ``load_pipeline`` runs.
        zero_options: Options the current ``pipe`` was built for.
    """

    def __init__(self, gr):
        super().__init__()
        self.gr = gr

        self.style_list = [
            {"name": "Comic",
             "prompt": "comic {prompt} . graphic illustration, comic art, graphic novel art, vibrant, highly detailed"},
            {"name": "Cinematic",
             "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy"},
            {"name": "3D Model",
             "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting"},
            {"name": "Anime",
             "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed"},
            {"name": "Digital Art",
             "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed"},
            {"name": "Photographic",
             "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed"},
            {"name": "Pixel art",
             "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics"},
            {"name": "Fantasy art",
             "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"},
            {"name": "Neonpunk",
             "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional"},
            {"name": "Manga",
             "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style"},
        ]

        self.styles = {style["name"]: style["prompt"] for style in self.style_list}
        self.STYLE_NAMES = list(self.styles.keys())
        self.DEFAULT_STYLE_NAME = "Fantasy art"
        self.MAX_SEED = np.iinfo(np.int32).max

        # The pipeline is built lazily on the first call to ``load_pipeline``.
        self.pipe = None
        self.zero_options = None

    def load_pipeline(self, zero_options):
        """Build a new pipeline if none exists or ``zero_options`` changed.

        Note that ``Sketch2ImagePipeline`` is constructed without arguments;
        ``zero_options`` only decides *whether* a rebuild happens.

        Args:
            zero_options: Options value compared with ``!=`` against the one
                the current pipeline was built for.
        """
        if self.pipe is None or zero_options != self.zero_options:
            self.pipe = Sketch2ImagePipeline()
            self.zero_options = zero_options

    @staticmethod
    def pil_image_to_data_uri(img, format="PNG"):
        """Encode an image as a base64 ``data:`` URI.

        Args:
            img: Object with a PIL-style ``save(fp, format=...)`` method.
            format: Image format passed to ``save``; its lower-cased form is
                used as the MIME subtype.

        Returns:
            A string of the form ``data:image/<format>;base64,<payload>``.
        """
        buffered = BytesIO()
        img.save(buffered, format=format)
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/{format.lower()};base64,{img_str}"

    def artwork(self, options, image, prompt, prompt_template, style_name, seed, val_r, faster, model_name, type_flag, prompt_quality):
        """Generate an image from a sketch.

        Args:
            options: Forwarded to ``load_pipeline``; a change triggers a rebuild.
            image: Mapping with a ``"composite"`` entry holding the sketch
                (presumably a Gradio ImageEditor value -- confirm with caller).
            prompt: Forwarded to ``generate``.
            prompt_template: Forwarded to ``generate`` unchanged.
            style_name: Not used by this method. NOTE(review): styles are
                presumably applied by the caller via ``prompt_template``.
            seed: Passed to ``torch.manual_seed`` before sampling the noise map.
            val_r: Forwarded to ``generate`` as ``r``.
            faster: Forwarded to ``generate`` as ``half_model``.
            model_name: Forwarded to ``generate``.
            type_flag: If truthy, ``image["composite"]`` is used directly (it
                must support ``.convert``); otherwise only the last channel of
                its array form is kept (the alpha channel for RGBA input).
            prompt_quality: Forwarded to ``generate``.

        Returns:
            The generated image as a ``PIL.Image.Image``.
        """
        self.load_pipeline(zero_options=options)

        if type_flag:
            img = image["composite"]
        else:
            # Keep only the last channel; for an RGBA canvas this is alpha,
            # i.e. presumably where strokes were drawn.
            img = Image.fromarray(np.array(image["composite"])[:, :, -1])

        img = img.convert("RGB").resize((512, 512))

        # Binarize the sketch: to_tensor scales to [0, 1], so each pixel
        # becomes 1.0 above the 0.5 threshold and 0.0 otherwise.
        # NOTE(review): `.cuda()` requires a CUDA device to be available.
        image_t = F.to_tensor(img) > 0.5
        c_t = image_t.unsqueeze(0).cuda().float()

        torch.manual_seed(seed)
        _, _, H, W = c_t.shape
        # 4 channels at 1/8 resolution -- presumably the pipeline's latent
        # space; confirm against Sketch2ImagePipeline.
        noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)

        with torch.no_grad():
            output_image = self.pipe.generate(c_t, prompt, prompt_quality, prompt_template, r=val_r, noise_map=noise, half_model=faster, model_name=model_name)

        # Map the (presumably [-1, 1]) output to [0, 1] and clamp:
        # to_pil_image multiplies float tensors by 255 and casts to uint8
        # without clipping, so out-of-range values would otherwise wrap around.
        output = (output_image[0].cpu() * 0.5 + 0.5).clamp(0, 1)
        return F.to_pil_image(output)