|
import os |
|
import spaces |
|
|
|
# Point the Hugging Face cache at an `hf_download` folder next to this script.
# This is assigned before the transformers/diffusers imports further down,
# presumably so those libraries pick the location up when they are imported.
os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')

# Optional access token read from the `hf_token` environment variable
# (None when the variable is not set).
HF_TOKEN = os.environ.get('hf_token')
|
|
|
import uuid |
|
import time |
|
import torch |
|
import numpy as np |
|
import gradio as gr |
|
import tempfile |
|
|
|
# Folder (inside the system temp directory) where diffusion_fn saves the
# rendered PNG files; created up front so later writes cannot fail on a
# missing directory.
gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_temp_dir, exist_ok=True)
|
|
|
from threading import Thread |
|
|
|
|
|
from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel |
|
|
|
# Monkey-patch the Phi-3 base class so it advertises `_supports_sdpa = True`.
# Presumably this lets transformers choose its scaled-dot-product attention
# path for Phi-3 checkpoints - TODO confirm against the installed version.
# NOTE(review): the LLM loaded below is a Llama-3 checkpoint, so this patch may
# not affect the model actually used here.
Phi3PreTrainedModel._supports_sdpa = True
|
|
|
from PIL import Image |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
from diffusers import AutoencoderKL, UNet2DConditionModel |
|
from diffusers.models.attention_processor import AttnProcessor2_0 |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
from lib_omost.pipeline import StableDiffusionXLOmostPipeline |
|
from chat_interface import ChatInterface |
|
from transformers.generation.stopping_criteria import StoppingCriteriaList |
|
|
|
import lib_omost.canvas as omost_canvas |
|
|
|
|
|
|
|
|
|
# Hub repository of the SDXL checkpoint; every sub-model below is loaded from
# its own subfolder of this repository.
sdxl_name = 'SG161222/RealVisXL_V4.0'

# The two CLIP tokenizers used by SDXL.
tokenizer = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer_2")

# The two CLIP text encoders, loaded from the fp16 weight variant.
text_encoder = CLIPTextModel.from_pretrained(
    sdxl_name,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
    variant="fp16",
    device_map="auto",
)
text_encoder_2 = CLIPTextModel.from_pretrained(
    sdxl_name,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
    variant="fp16",
    device_map="auto",
)

# The VAE reads the fp16 weight files but is held in bfloat16, unlike the
# other sub-models.
vae = AutoencoderKL.from_pretrained(
    sdxl_name,
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    variant="fp16",
    device_map="auto",
)

unet = UNet2DConditionModel.from_pretrained(
    sdxl_name,
    subfolder="unet",
    torch_dtype=torch.float16,
    variant="fp16",
    device_map="auto",
)

# Install the AttnProcessor2_0 attention implementation on both networks.
unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())

# Assemble the Omost SDXL pipeline from the parts above; no scheduler is
# supplied here (scheduler=None).
pipeline = StableDiffusionXLOmostPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    unet=unet,
    scheduler=None,
)
|
|
|
|
|
|
|
|
|
# Hub repository of the language model that writes the canvas programs.
llm_name = 'lllyasviel/omost-llama-3-8b'

# Causal LM; dtype and device placement are left to the library ("auto").
# NOTE: trust_remote_code=True permits code shipped in the model repository
# to run while loading.
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype="auto",
    token=HF_TOKEN,
    device_map="auto",
    trust_remote_code=True,
)

# Matching tokenizer, fetched with the same (optional) access token.
llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, token=HF_TOKEN)
|
|
|
|
|
@torch.inference_mode()
def pytorch2numpy(imgs):
    """Convert an iterable of image tensors into a list of uint8 NumPy arrays.

    For every item the first axis is moved to the last position, values are
    mapped with ``v * 127.5 + 127.5``, clipped to ``[0, 255]`` and cast to
    ``uint8``.

    Args:
        imgs: Iterable of tensors (for example a batched tensor, iterated
            along its first dimension).

    Returns:
        A list holding one NumPy array per input item.
    """
    def _to_uint8(tensor):
        # Same operation order as always: rescale on the tensor first, then
        # move to CPU float and let NumPy clip and truncate.
        rescaled = tensor.movedim(0, -1) * 127.5 + 127.5
        return rescaled.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)

    return [_to_uint8(item) for item in imgs]
|
|
|
|
|
@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack image arrays into one float tensor scaled with ``v / 127.5 - 1.0``.

    Args:
        imgs: Sequence of equally shaped NumPy arrays.

    Returns:
        A float tensor whose first axis indexes the images and whose former
        last axis has been moved to position 1.
    """
    stacked = np.stack(imgs, axis=0)
    scaled = torch.from_numpy(stacked).float() / 127.5 - 1.0
    return scaled.movedim(-1, 1)
|
|
|
|
|
def resize_without_crop(image, target_width, target_height):
    """Resize an image array to exactly ``target_width`` x ``target_height``.

    The array is converted to a PIL image, resampled with the LANCZOS filter
    (no cropping or padding is applied) and converted back to a NumPy array.
    """
    target_size = (target_width, target_height)
    resized = Image.fromarray(image).resize(target_size, Image.LANCZOS)
    return np.array(resized)
|
|
|
|
|
@spaces.GPU(duration=120)
@torch.inference_mode()
def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int):
    """Stream the language model's reply to ``message``.

    This is a generator: after every chunk produced by the streamer it yields
    a ``(text_so_far, None)`` tuple, where ``text_so_far`` is the reply
    accumulated up to that point.

    Args:
        message: The new user message.
        history: Earlier ``(user, assistant)`` turns; turns with a ``None``
            side are skipped. ``None`` is treated as an empty history.
        seed: Seed applied to NumPy and torch before generating.
        temperature: Sampling temperature; ``0`` switches to greedy decoding.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        max_new_tokens: Upper bound on the number of generated tokens.
    """
    print('Chat begin:', message)
    time_stamp = time.time()

    np.random.seed(int(seed))
    torch.manual_seed(int(seed))

    # System prompt first, then every complete earlier turn, then the new message.
    conversation = [{"role": "system", "content": omost_canvas.system_prompt}]
    for user, assistant in history or []:
        if user is None or assistant is None:
            continue
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    input_ids = llm_tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True).to(llm_model.device)

    streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    def interactive_stopping_criteria(*args, **kwargs) -> bool:
        # Stop as soon as someone has flagged the streamer as interrupted.
        if getattr(streamer, 'user_interrupted', False):
            print('User stopped generation:', message)
            return True
        return False

    stopping_criteria = StoppingCriteriaList([interactive_stopping_criteria])

    def interrupter():
        # NOTE(review): not handed to the caller at the moment (None is yielded
        # below), so nothing sets the flag; kept so a stop button can be wired
        # back in without touching the stopping criterion.
        streamer.user_interrupted = True
        return

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
        # The value comes from a UI widget; cast it like the seed above.
        max_new_tokens=int(max_new_tokens),
        # A temperature of exactly 0 means greedy decoding instead of sampling.
        do_sample=temperature != 0,
        temperature=temperature,
        top_p=top_p,
    )

    # generate() blocks until the reply is complete, so it runs on a worker
    # thread while this generator drains the streamer.
    Thread(target=llm_model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs), None

    print(f'Chat end at {time.time() - time_stamp:.2f} seconds:', message)
|
|
|
|
|
@torch.inference_mode()
def post_chat(history):
    """Parse the last assistant reply into canvas data and update the UI.

    Args:
        history: Chat history as ``(user, assistant)`` pairs; may be empty or
            ``None``.

    Returns:
        A 3-tuple ``(canvas_outputs, render_update, undo_update)``:
        ``canvas_outputs`` is the processed canvas, or ``None`` when the last
        reply could not be parsed; ``render_update`` is visible only when a
        canvas was produced; ``undo_update`` is interactive only when the
        history is non-empty.
    """
    canvas_outputs = None

    try:
        if history:
            # Only turns where both sides are plain strings are considered.
            history = [
                (user, assistant) for user, assistant in history
                if isinstance(user, str) and isinstance(assistant, str)
            ]
            last_assistant = history[-1][1] if len(history) > 0 else None
            canvas = omost_canvas.Canvas.from_bot_response(last_assistant)
            canvas_outputs = canvas.process()
    except Exception as e:
        # Best effort: an unparsable reply only means there is nothing to render.
        print('Last assistant response is not valid canvas:', e)

    # bool() rather than len(): `history` can be None here, and len(None)
    # would raise instead of disabling the button.
    has_history = bool(history)
    return canvas_outputs, gr.update(visible=canvas_outputs is not None), gr.update(interactive=has_history)
|
|
|
|
|
@spaces.GPU
@torch.inference_mode()
def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height,
                 highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt):
    """Render images for the current canvas and append them to the chat.

    Args:
        chatbot: Current chat history; ``None`` is treated as empty.
        canvas_outputs: Processed canvas passed to ``pipeline.all_conds_from_canvas``.
        num_samples: Number of images to generate.
        seed: Seed for the torch generator used by the pipeline.
        image_width: Requested width, rounded down to a multiple of 64.
        image_height: Requested height, rounded down to a multiple of 64.
        highres_scale: Upscale factor of the second pass; the pass is skipped
            unless this exceeds ``1.0 + eps``.
        steps: Sampling steps of the first pass.
        cfg: Guidance scale used in both passes.
        highres_steps: Sampling steps of the second pass.
        highres_denoise: Denoising strength of the second pass.
        negative_prompt: Negative prompt text.

    Returns:
        A new chat-history list with one ``(None, (image_path, 'image'))``
        entry appended per generated image.
    """
    use_initial_latent = False
    eps = 0.05

    # These values arrive from UI widgets; `steps` and `cfg` were already cast
    # at their call sites, so normalise every numeric input the same way.
    num_samples = int(num_samples)
    seed = int(seed)
    steps = int(steps)
    highres_steps = int(highres_steps)
    cfg = float(cfg)

    image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64

    rng = torch.Generator(unet.device).manual_seed(seed)

    positive_cond, positive_pooler, negative_cond, negative_pooler = pipeline.all_conds_from_canvas(
        canvas_outputs, negative_prompt)

    if use_initial_latent:
        # Currently disabled (use_initial_latent is hard-coded to False):
        # box-blur the canvas colour map, resize it and encode it with the VAE.
        initial_latent = torch.from_numpy(canvas_outputs['initial_latent'])[None].movedim(-1, 1) / 127.5 - 1.0
        initial_latent_blur = 40
        initial_latent = torch.nn.functional.avg_pool2d(
            torch.nn.functional.pad(initial_latent, (initial_latent_blur,) * 4, mode='reflect'),
            kernel_size=(initial_latent_blur * 2 + 1,) * 2, stride=(1, 1))
        initial_latent = torch.nn.functional.interpolate(initial_latent, (image_height, image_width))
        initial_latent = initial_latent.to(dtype=vae.dtype, device=vae.device)
        initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config.scaling_factor
    else:
        initial_latent = torch.zeros(
            size=(num_samples, 4, image_height // 8, image_width // 8), dtype=torch.float32)

    initial_latent = initial_latent.to(dtype=unet.dtype, device=unet.device)

    # First pass at the requested resolution.
    latents = pipeline(
        initial_latent=initial_latent,
        strength=1.0,
        num_inference_steps=steps,
        batch_size=num_samples,
        prompt_embeds=positive_cond,
        negative_prompt_embeds=negative_cond,
        pooled_prompt_embeds=positive_pooler,
        negative_pooled_prompt_embeds=negative_pooler,
        generator=rng,
        guidance_scale=cfg,
    ).images

    latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
    pixels = vae.decode(latents).sample
    H, W = pixels.shape[-2:]
    pixels = pytorch2numpy(pixels)

    if highres_scale > 1.0 + eps:
        # Second pass: upscale the decoded images (target sizes snapped to a
        # multiple of 64), re-encode them and denoise again.
        target_width = int(round(W * highres_scale / 64.0) * 64)
        target_height = int(round(H * highres_scale / 64.0) * 64)
        pixels = [
            resize_without_crop(image=p, target_width=target_width, target_height=target_height)
            for p in pixels
        ]

        pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
        latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor

        latents = latents.to(device=unet.device, dtype=unet.dtype)

        latents = pipeline(
            initial_latent=latents,
            strength=highres_denoise,
            num_inference_steps=highres_steps,
            batch_size=num_samples,
            prompt_embeds=positive_cond,
            negative_prompt_embeds=negative_cond,
            pooled_prompt_embeds=positive_pooler,
            negative_pooled_prompt_embeds=negative_pooler,
            generator=rng,
            guidance_scale=cfg,
        ).images

        latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
        pixels = vae.decode(latents).sample
        pixels = pytorch2numpy(pixels)

    # Work on a copy so the caller's list is never mutated; tolerate None.
    chatbot = list(chatbot or [])
    for i, image_array in enumerate(pixels):
        unique_hex = uuid.uuid4().hex
        image_path = os.path.join(gradio_temp_dir, f"{unique_hex}_{i}.png")
        Image.fromarray(image_array).save(image_path)
        chatbot.append((None, (image_path, 'image')))

    return chatbot
|
|
|
|
|
# Stylesheet handed to gr.Blocks below. It wraps text inside <code> elements,
# removes the container's max-width, gives the `outer_parent` / `inner_parent`
# elements `flex: 1`, and hides the footer and `.translucent` elements.
css = '''
code {white-space: pre-wrap !important;}
.gradio-container {max-width: none !important;}
.outer_parent {flex: 1;}
.inner_parent {flex: 1;}
footer {display: none !important; visibility: hidden !important;}
.translucent {display: none !important; visibility: hidden !important;}
'''
|
|
|
from gradio.themes.utils import colors |
|
|
|
# Build the Gradio UI: a settings column (scale=25) next to a chat column
# (scale=75), followed by the wiring of the render button.
# NOTE(review): the nesting of the layout contexts below is reconstructed
# (the statements themselves are unchanged) - confirm that each widget sits in
# the intended Row / Group / Accordion.
with gr.Blocks(
        fill_height=True, css=css,
        theme=gr.themes.Default(primary_hue=colors.blue, secondary_hue=colors.cyan, neutral_hue=colors.gray)
) as demo:
    with gr.Row(elem_classes='outer_parent'):
        with gr.Column(scale=25):
            # Chat control buttons; they are handed to ChatInterface further down.
            with gr.Row():
                clear_btn = gr.Button("➕ New Chat", variant="secondary", size="sm", min_width=60)
                retry_btn = gr.Button("Retry", variant="secondary", size="sm", min_width=60, visible=False)
                undo_btn = gr.Button("✏️️ Edit Last Input", variant="secondary", size="sm", min_width=60, interactive=False)

            # One seed feeds both chat_fn and diffusion_fn.
            seed = gr.Number(label="Random Seed", value=123456, precision=0)

            with gr.Accordion(open=True, label='Language Model'):
                with gr.Group():
                    with gr.Row():
                        temperature = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            step=0.01,
                            value=0.6,
                            label="Temperature")
                        top_p = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=0.9,
                            label="Top P")
                    max_new_tokens = gr.Slider(
                        minimum=128,
                        maximum=4096,
                        step=1,
                        value=4096,
                        label="Max New Tokens")
            with gr.Accordion(open=True, label='Image Diffusion Model'):
                with gr.Group():
                    with gr.Row():
                        image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=896, step=64)
                        image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=1152, step=64)

                    with gr.Row():
                        num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1)
                        steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1)

                    with gr.Accordion(open=False, label='Advanced'):
                        cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=5.0, step=0.01)
                        highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, step=0.01)
                        highres_steps = gr.Slider(label="Highres Fix Steps", minimum=1, maximum=100, value=20, step=1)
                        highres_denoise = gr.Slider(label="Highres Fix Denoise", minimum=0.1, maximum=1.0, value=0.4, step=0.01)
                        n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')

            # Starts hidden; post_chat makes it visible once a canvas was parsed.
            render_button = gr.Button("Render the Image!", size='lg', variant="primary", visible=False)

            examples = gr.Dataset(
                samples=[
                    ['generate an image of the fierce battle of warriors and a dragon'],
                    ['change the dragon to a dinosaur']
                ],
                components=[gr.Textbox(visible=False)],
                label='Quick Prompts'
            )

            with gr.Row():
                gr.Markdown("Omost: converting LLM's coding capability to image compositing capability.")
            with gr.Row():
                gr.Markdown("Local version (8GB VRAM): https://github.com/lllyasviel/Omost")
            with gr.Row():
                gr.Markdown("Note that you can only occupy HuggingFace zero GPU 170 seconds every 5 minutes. If connection error out, wait 5 minites and try again.")

        with gr.Column(scale=75, elem_classes='inner_parent'):
            # Holds the canvas data returned by post_chat until it is rendered.
            canvas_state = gr.State(None)
            # Created with render=False and passed to ChatInterface below.
            chatbot = gr.Chatbot(label='Omost', scale=1, show_copy_button=True, layout="panel", render=False)
            chatInterface = ChatInterface(
                fn=chat_fn,
                post_fn=post_chat,
                post_fn_kwargs=dict(inputs=[chatbot], outputs=[canvas_state, render_button, undo_btn]),
                pre_fn=lambda: gr.update(visible=False),
                pre_fn_kwargs=dict(outputs=[render_button]),
                chatbot=chatbot,
                retry_btn=retry_btn,
                undo_btn=undo_btn,
                clear_btn=clear_btn,
                additional_inputs=[seed, temperature, top_p, max_new_tokens],
                examples=examples,
                show_stop_button=False
            )

    # Render on click, then copy the updated chatbot value into the
    # interface's chatbot_state.
    render_button.click(
        fn=diffusion_fn, inputs=[
            chatInterface.chatbot, canvas_state,
            num_samples, seed, image_width, image_height, highres_scale,
            steps, cfg, highres_steps, highres_denoise, n_prompt
        ], outputs=[chatInterface.chatbot]).then(
        fn=lambda x: x, inputs=[
            chatInterface.chatbot
        ], outputs=[chatInterface.chatbot_state])
|
|
|
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    # Enable the request queue first, then launch on the object queue() returns.
    queued_demo = demo.queue()
    queued_demo.launch()
|
|