import io
import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import math
import torch
import random
import torch.nn.functional as F
import tempfile
import gradio as gr
import spaces
import httpimport
import json
from PIL import Image
from packaging import version
from PIL.PngImagePlugin import PngInfo
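# Fetch the pipeline module from a remote repository (MODULE_URL) and build the
# generation pipelines: text-to-image and image-to-image variants for two model versions.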
with httpimport.remote_repo(os.getenv("MODULE_URL")):
    import pipeline
pipe, pipe2, pipe_img2img, pipe2_img2img = pipeline.get_pipeline_initialize()
theme = gr.themes.Base(font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'])
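# Move the text-to-image pipelines onto the GPU.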
device="cuda"
pipe = pipe.to(device)
pipe2 = pipe2.to(device)
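# Default quality tags appended to every prompt, and the default negative prompt.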
PRESET_Q = "year_2022, best quality, high quality, very aesthetic"
NEGATIVE_PROMPT = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, worst quality displeasing, bad quality"
import hashlib
import base64
import hmac
import numpy as np
import pickle
import requests
import codecs
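# Send a generation request to the remote endpoint at TPU_INFERENCE_API and save the
# returned PNG to a temp file with NovelAI-style metadata. Returns (filename, seed).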
def tpu_inference_api(
    prompt: str,
    radio: str = "model-v2",
    preset: str = "year_2022, best quality, high quality, very aesthetic",
    h: int = 1216,
    w: int = 832,
    negative_prompt: str = "lowres, worst quality, displeasing, bad anatomy, text, error, extra digit, cropped, error, fewer, extra, missing, worst quality, jpeg artifacts, censored, ai-generated worst quality displeasing, bad quality",
    guidance_scale: float = 4.0,
    randomize_seed: bool = True,
    seed: int = 42,
    do_img2img: bool = False,
    init_image: Optional[str] = None,
    image2image_strength: float = 0.0,
    inference_steps: int = 25,
) -> Tuple[str, int]:
    url = os.getenv("TPU_INFERENCE_API")
    if randomize_seed:
        seed = random.randint(0, 9007199254740991)
        randomize_seed = False
    payload = {
        "prompt": prompt,
        "radio": radio,
        "preset": preset,
        "height": h,
        "width": w,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "randomize_seed": randomize_seed,
        "seed": seed,
        "do_img2img": do_img2img,
        "image2image_strength": image2image_strength,
        "init_image": init_image,
        "inference_steps": inference_steps,
    }
    response = requests.post(url, json=payload)
    if response.status_code != 200:
        raise Exception(f"Error calling API: {response.status_code} - {response.text}")
    image = Image.open(io.BytesIO(response.content))
    # Filesystem-safe filename prefix derived from the prompt and seed.
    naifix = prompt[:40].replace(":", "_").replace("\\", "_").replace("/", "_") + f" s-{seed}-"
    with tempfile.NamedTemporaryFile(prefix=naifix, suffix=".png", delete=False) as tmpfile:
        # NovelAI-style generation parameters embedded in the PNG "Comment" metadata field.
        parameters = {
            "prompt": prompt,
            "steps": inference_steps,
            "height": h,
            "width": w,
            "scale": guidance_scale,
            "uncond_scale": 0.0,
            "cfg_rescale": 0.0,
            "seed": seed,
            "n_samples": 1,
            "hide_debug_overlay": False,
            "noise_schedule": "native",
            "legacy_v3_extend": False,
            "reference_information_extracted_multiple": [],
            "reference_strength_multiple": [],
            "sampler": "k_dpmpp_2m_sde",
            "controlnet_strength": 1.0,
            "controlnet_model": None,
            "dynamic_thresholding": False,
            "dynamic_thresholding_percentile": 0.999,
            "dynamic_thresholding_mimic_scale": 10.0,
            "sm": False,
            "sm_dyn": False,
            "skip_cfg_above_sigma": 23.69030960605558,
            "skip_cfg_below_sigma": 0.0,
            "lora_unet_weights": None,
            "lora_clip_weights": None,
            "deliberate_euler_ancestral_bug": True,
            "prefer_brownian": False,
            "cfg_sched_eligibility": "enable_for_post_summer_samplers",
            "explike_fine_detail": False,
            "minimize_sigma_inf": False,
            "uncond_per_vibe": True,
            "wonky_vibe_correlation": True,
            "version": 1,
            "uc": "nsfw, lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract],{{{{chibi,doll,+_+}}}},",
        }
        metadata_params = {
            "request_type": "PromptGenerateRequest",
            "signed_hash": sign_message(json.dumps(parameters), "novelai-client"),
            **parameters,
        }
        metadata = PngInfo()
        metadata.add_text("Title", "AI generated image")
        metadata.add_text("Description", prompt)
        metadata.add_text("Software", "NovelAI")
        metadata.add_text("Source", "Stable Diffusion XL 7BCCAA2C")
        metadata.add_text("Nya", "Nya~")
        metadata.add_text("Generation time", f"1.{random.randint(1000000000, 9999999999)}")
        metadata.add_text("Comment", json.dumps(metadata_params))
        image.save(tmpfile, "png", pnginfo=metadata)
    return tmpfile.name, seed
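# HMAC-SHA512 signature (base64-encoded) used for the "signed_hash" field of the embedded metadata.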
def sign_message(message, key):
    hmac_digest = hmac.new(key.encode(), message.encode(), hashlib.sha512).digest()
    signed_hash = base64.b64encode(hmac_digest).decode()
    return signed_hash
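# Entry point wired to the Gradio UI: routes the request either to the remote TPU endpoint
# or to local GPU inference.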
def run(prompt, radio="model-v2", preset=PRESET_Q, h=1216, w=832, negative_prompt=NEGATIVE_PROMPT, guidance_scale=4.0, randomize_seed=True, seed=42, tpu_inference=False, do_img2img=False, init_image=None, image2image_resize=False, image2image_strength=0, inference_steps=25, progress=gr.Progress(track_tqdm=True)):
    if init_image is None:
        do_img2img = False
    if do_img2img and image2image_resize:
        # init_image is an np.ndarray coming from the gr.Image component.
        init_image = Image.fromarray(init_image)
        init_image = init_image.resize((w, h))
        init_image = np.array(init_image)
    if tpu_inference:
        prompt = prompt.replace("!", " ").replace("\n", " ")  # characters the remote endpoint does not support
        if do_img2img:
            # Serialize the init image so it survives the JSON round trip to the remote endpoint.
            init_image = codecs.encode(pickle.dumps(init_image, protocol=pickle.HIGHEST_PROTOCOL), "base64").decode("latin1")
            return tpu_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, do_img2img, init_image, image2image_strength, inference_steps=inference_steps)
        else:
            return tpu_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, inference_steps=inference_steps)
    return zero_inference_api(prompt, radio, preset, h, w, negative_prompt, guidance_scale, randomize_seed, seed, do_img2img, init_image, image2image_strength, inference_steps=inference_steps)
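# Local inference on the Space's GPU using the pipelines loaded above.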
@spaces.GPU
def zero_inference_api(prompt, radio="model-v2", preset=PRESET_Q, h=1216, w=832, negative_prompt=NEGATIVE_PROMPT, guidance_scale=4.0, randomize_seed=True, seed=42, do_img2img=False, init_image=None, image2image_strength=0, inference_steps=25, progress=gr.Progress(track_tqdm=True)):
    # Append the quality preset to the prompt and normalize the negative prompt.
    prompt = prompt.strip() + ", " + preset.strip()
    negative_prompt = negative_prompt.strip() if negative_prompt and negative_prompt.strip() else None
    print(f"Initial seed for prompt `{prompt}`", seed)
    if randomize_seed:
        seed = random.randint(0, 9007199254740991)
    if not prompt and not negative_prompt:
        # Unconditional generation: drop CFG entirely.
        guidance_scale = 0.0
    generator = torch.Generator(device="cuda").manual_seed(seed)
    if inference_steps > 50:
        inference_steps = 50
    if not do_img2img:
        if radio == "model-v2":
            image = pipe(prompt, height=h, width=w, negative_prompt=negative_prompt, guidance_scale=guidance_scale, guidance_rescale=0.75, generator=generator, num_inference_steps=inference_steps).images[0]
        else:
            image = pipe2(prompt, height=h, width=w, negative_prompt=negative_prompt, guidance_scale=guidance_scale, guidance_rescale=0.75, generator=generator, num_inference_steps=inference_steps).images[0]
    else:
        init_image = Image.fromarray(init_image)
        if radio == "model-v2":
            image = pipe_img2img(prompt, image=init_image, strength=image2image_strength, negative_prompt=negative_prompt, guidance_scale=guidance_scale, generator=generator, num_inference_steps=inference_steps).images[0]
        else:
            image = pipe2_img2img(prompt, image=init_image, strength=image2image_strength, negative_prompt=negative_prompt, guidance_scale=guidance_scale, generator=generator, num_inference_steps=inference_steps).images[0]
    # Filesystem-safe filename prefix derived from the prompt and seed.
    naifix = prompt[:40].replace(":", "_").replace("\\", "_").replace("/", "_") + f" s-{seed}-"
    with tempfile.NamedTemporaryFile(prefix=naifix, suffix=".png", delete=False) as tmpfile:
        # NovelAI-style generation parameters embedded in the PNG "Comment" metadata field.
        parameters = {
            "prompt": prompt,
            "steps": inference_steps,
            "height": h,
            "width": w,
            "scale": guidance_scale,
            "uncond_scale": 0.0,
            "cfg_rescale": 0.0,
            "seed": seed,
            "n_samples": 1,
            "hide_debug_overlay": False,
            "noise_schedule": "native",
            "legacy_v3_extend": False,
            "reference_information_extracted_multiple": [],
            "reference_strength_multiple": [],
            "sampler": "k_dpmpp_2m_sde",
            "controlnet_strength": 1.0,
            "controlnet_model": None,
            "dynamic_thresholding": False,
            "dynamic_thresholding_percentile": 0.999,
            "dynamic_thresholding_mimic_scale": 10.0,
            "sm": False,
            "sm_dyn": False,
            "skip_cfg_above_sigma": 23.69030960605558,
            "skip_cfg_below_sigma": 0.0,
            "lora_unet_weights": None,
            "lora_clip_weights": None,
            "deliberate_euler_ancestral_bug": True,
            "prefer_brownian": False,
            "cfg_sched_eligibility": "enable_for_post_summer_samplers",
            "explike_fine_detail": False,
            "minimize_sigma_inf": False,
            "uncond_per_vibe": True,
            "wonky_vibe_correlation": True,
            "version": 1,
            "uc": "nsfw, lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, watermark, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract], lowres, {bad}, error, fewer, extra, missing, worst quality, jpeg artifacts, bad quality, unfinished, displeasing, chromatic aberration, signature, extra digits, artistic error, username, scan, [abstract],{{{{chibi,doll,+_+}}}},",
        }
        metadata_params = {
            "request_type": "PromptGenerateRequest",
            "signed_hash": sign_message(json.dumps(parameters), "novelai-client"),
            **parameters,
        }
        metadata = PngInfo()
        metadata.add_text("Title", "AI generated image")
        metadata.add_text("Description", prompt)
        metadata.add_text("Software", "NovelAI")
        metadata.add_text("Source", "Stable Diffusion XL 7BCCAA2C")
        metadata.add_text("Nya", "Nya~")
        metadata.add_text("Generation time", f"1.{random.randint(1000000000, 9999999999)}")
        metadata.add_text("Comment", json.dumps(metadata_params))
        image.save(tmpfile, "png", pnginfo=metadata)
    return tmpfile.name, seed
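# Gradio UI.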
with gr.Blocks(theme=theme) as demo:
    gr.Markdown('''# SDXL Experiments
Just a simple demo for some SDXL model.''')
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    prompt = gr.Textbox(show_label=False, scale=5, value="1girl, rurudo", placeholder="Your prompt", info="Leave blank to test unconditional generation")
                    button = gr.Button("Generate", min_width=120)
                preset = gr.Textbox(show_label=False, scale=5, value=PRESET_Q, info="Quality presets")
            radio = gr.Radio(["model-v2-beta", "model-v2"], value="model-v2", label="Choose the inference model")
            inference_steps = gr.Slider(label="Inference Steps", value=25, minimum=4, maximum=50, step=1)
            with gr.Row():
                height = gr.Slider(label="Height", value=1216, minimum=512, maximum=2560, step=64)
                width = gr.Slider(label="Width", value=832, minimum=512, maximum=2560, step=64)
            guidance_scale = gr.Number(label="CFG Guidance Scale", info="The guidance scale for CFG; ignored if no prompt is entered (unconditional generation)", value=4.0)
            negative_prompt = gr.Textbox(label="Negative prompt", value=NEGATIVE_PROMPT, info="Only applied to the CFG part; leave blank for unconditional generation")
            seed = gr.Number(label="Seed", value=42, info="Seed for the random number generator")
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            tpu_inference = gr.Checkbox(label="TPU Inference", value=True)
            do_img2img = gr.Checkbox(label="Image to Image", value=False)
            init_image = gr.Image(label="Input Image", visible=False)
            image2image_resize = gr.Checkbox(label="Resize input image", value=False, visible=False)
            image2image_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Noising strength", value=0.7, visible=False)
        with gr.Column():
            output = gr.Image(type="filepath", interactive=False)
            gr.Examples(fn=run, examples=["mayano_top_gun_\\(umamusume\\), 1girl, rurudo", "sho (sho lwlw),[[[ohisashiburi]]],fukuro daizi,tianliang duohe fangdongye,[daidai ookami],year_2023, (wariza), depth of field, official_art"], inputs=prompt, outputs=[output, seed], cache_examples="lazy")
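    # Show the img2img controls only when "Image to Image" is checked.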
    do_img2img.change(
        fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
        inputs=[do_img2img],
        outputs=[init_image, image2image_resize, image2image_strength],
    )
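    # Run generation on button click or when the prompt is submitted.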
    gr.on(
        triggers=[
            button.click,
            prompt.submit,
        ],
        fn=run,
        inputs=[prompt, radio, preset, height, width, negative_prompt, guidance_scale, randomize_seed, seed, tpu_inference, do_img2img, init_image, image2image_resize, image2image_strength, inference_steps],
        outputs=[output, seed],
        concurrency_limit=1,
    )
if __name__ == "__main__":
    demo.launch(share=True)