# app.py by None1145 — commit 6e5841b ("Update app.py", verified), 5.64 kB.
import gradio as gr
import numpy as np
import random
import time
from optimum.intel import OVStableDiffusionXLPipeline
import torch
from diffusers import EulerDiscreteScheduler
from io import BytesIO
from PIL import Image
import base64
# Model loaded at startup; reload_model() rebinds this global on each reload.
model_id = "None1145/noobai-XL-Vpred-0.65s-openvino"

# Height/width the pipeline is currently reshaped and compiled for.
# infer() compares each request against this pair and only recompiles on change.
prev_height, prev_width = 1216, 832

# Largest signed 32-bit integer; upper bound for generated/selectable seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (pixels) for the width/height sliders.
MAX_IMAGE_SIZE = 2048
def reload_model(new_model_id: str) -> str:
    """Load ``new_model_id`` and make it the active OpenVINO SDXL pipeline.

    The pipeline is loaded uncompiled, optionally given a v-prediction Euler
    scheduler, reshaped to the size ``infer`` last compiled for
    (``prev_height`` x ``prev_width``), and then compiled. The globals ``pipe``
    and ``model_id`` are only replaced once every step succeeds. A failed
    reload therefore leaves the previously loaded model (if any) in service.

    Args:
        new_model_id: Model id accepted by
            ``OVStableDiffusionXLPipeline.from_pretrained``. Surrounding
            whitespace is ignored.

    Returns:
        A status string for the Gradio status box. It starts with
        ``"Model successfully loaded:"`` or ``"Failed to load model:"``.
    """
    global pipe, model_id
    candidate_id = (new_model_id or "").strip()
    if not candidate_id:
        return "Failed to load model: model ID is empty"
    try:
        print(f"{candidate_id}...")
        # Build into a local so a failure part-way through cannot leave a
        # half-configured pipeline in the global `pipe`.
        new_pipe = OVStableDiffusionXLPipeline.from_pretrained(candidate_id, compile=False)
        if candidate_id == "None1145/noobai-XL-Vpred-0.65s-openvino":
            # The default checkpoint is run with a v-prediction Euler scheduler
            # using zero-terminal-SNR beta rescaling.
            scheduler_args = {"prediction_type": "v_prediction", "rescale_betas_zero_snr": True}
            new_pipe.scheduler = EulerDiscreteScheduler.from_config(
                new_pipe.scheduler.config, **scheduler_args
            )
        # Match the size infer() believes is compiled, so its cache stays valid.
        new_pipe.reshape(batch_size=1, height=prev_height, width=prev_width, num_images_per_prompt=1)
        new_pipe.compile()
    except Exception as e:  # UI boundary: report the failure to the user instead of crashing.
        print(f"Failed to load {candidate_id}: {e}")
        return f"Failed to load model: {str(e)}"
    pipe = new_pipe
    model_id = candidate_id
    print(f"{model_id}!!!")
    return f"Model successfully loaded: {model_id}"
# Load the default model when the module is imported. The status string that
# reload_model() returns is discarded here.
reload_model(new_model_id=model_id)
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the currently loaded pipeline.

    If the requested size differs from the size the pipeline was last compiled
    for (``prev_width`` x ``prev_height``), the pipeline is reshaped and
    recompiled first, and that cached size is updated.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        ``(image, seed)``: the first generated image and the seed actually used.
    """
    global prev_width, prev_height
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider/API values can arrive as floats. torch.Generator.manual_seed
    # requires an int, and a float size would leak into reshape/latent shapes.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    generator = torch.Generator().manual_seed(seed)

    # Read the global once so the reshape below and the generation call
    # operate on the same pipeline object.
    active_pipe = pipe
    # NOTE(review): the app is queued with concurrency_count=2, and this mutates
    # the shared pipeline's static shape. Two concurrent requests with
    # different sizes may interfere; consider serializing if that shows up.
    if (width, height) != (prev_width, prev_height):
        active_pipe.reshape(batch_size=1, height=height, width=width, num_images_per_prompt=1)
        active_pipe.compile()
        prev_width, prev_height = width, height

    image = active_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image, seed
# Example prompts shown under the UI. The backslashes are meant to reach the
# prompt literally (presumably escaping the parentheses in the tag). A raw
# string keeps them without relying on Python's invalid-escape fallback,
# which emits a SyntaxWarning on 3.12+.
examples = [r"murasame \(senren\), senren banka"]
# Gradio UI: prompt/run controls, advanced generation settings, examples,
# and a section for swapping the served model at runtime.
with gr.Blocks() as img:
    gr.Markdown("# OpenVINO Text to Image")
    with gr.Column(elem_id="col-container"):
        # Main controls: prompt, step count and the run button.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=60,
                step=1,
                value=28,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            # Hidden (visible=False), but still listed among infer()'s inputs below.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )
            # Also an output of infer(), so it shows the seed actually used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            # Defaults (832 x 1216) equal the startup prev_width/prev_height, so
            # the first run does not trigger a reshape/recompile.
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=832,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1216,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
        gr.Examples(examples=examples, inputs=[prompt])
        # Runtime model swap via reload_model(); prefilled with the current model id.
        gr.Markdown("### Model Reload")
        with gr.Row():
            new_model_id = gr.Text(label="New Model ID", placeholder="Enter model ID", value=model_id)
            reload_button = gr.Button("Reload Model", variant="primary")
        reload_status = gr.Text(label="Status", interactive=False)
    reload_button.click(
        fn=reload_model,
        inputs=new_model_id,
        outputs=reload_status,
    )
    run_button.click(
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )
if __name__ == "__main__":
    # Enable request queuing (max 10 pending, concurrency_count=2), then serve.
    server = img.queue(max_size=10, concurrency_count=2)
    server.launch()