# Page header captured together with the file (not program code): "Spaces: Runtime error".
# Module bootstrap.
# The statements are kept in their original order: HF_HOME is assigned before
# the Hugging Face libraries are imported (presumably so the cache location
# takes effect at import time -- confirm before reordering anything here).
import os
import spaces

os.environ['HF_HOME'] = os.path.join(os.path.dirname(__file__), 'hf_download')
# Optional access token; None when the 'hf_token' variable is not set.
HF_TOKEN = os.environ.get('hf_token')

import uuid
import time
import torch
import numpy as np
import gradio as gr
import tempfile

# Directory under the system temp dir where rendered images are written.
gradio_temp_dir = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_temp_dir, exist_ok=True)

from threading import Thread

# Phi3 Hijack
from transformers.models.phi3.modeling_phi3 import Phi3PreTrainedModel

Phi3PreTrainedModel._supports_sdpa = True

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from lib_omost.pipeline import StableDiffusionXLOmostPipeline
from chat_interface import ChatInterface
from transformers.generation.stopping_criteria import StoppingCriteriaList

import lib_omost.canvas as omost_canvas
# SDXL

sdxl_name = 'SG161222/RealVisXL_V4.0'
# sdxl_name = 'stabilityai/stable-diffusion-xl-base-1.0'

# NOTE(review): every model below is loaded with device_map="auto"; whether
# the installed transformers/diffusers versions accept that for each of these
# classes is not visible from this file -- confirm.

tokenizer = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(sdxl_name, subfolder="tokenizer_2")

text_encoder = CLIPTextModel.from_pretrained(
    sdxl_name,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
    variant="fp16",
    device_map="auto",
)
text_encoder_2 = CLIPTextModel.from_pretrained(
    sdxl_name,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
    variant="fp16",
    device_map="auto",
)

# bfloat16 vae: the "fp16" weight variant is requested with torch_dtype=torch.bfloat16.
vae = AutoencoderKL.from_pretrained(
    sdxl_name,
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    variant="fp16",
    device_map="auto",
)

unet = UNet2DConditionModel.from_pretrained(
    sdxl_name,
    subfolder="unet",
    torch_dtype=torch.float16,
    variant="fp16",
    device_map="auto",
)

# Both networks get a fresh AttnProcessor2_0 instance.
unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())

pipeline = StableDiffusionXLOmostPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    unet=unet,
    # We completely give up diffusers sampling system and use A1111's method
    scheduler=None,
)
# LLM

# Commented-out alternatives (the original comments called the variable `model_name`):
# llm_name = 'lllyasviel/omost-phi-3-mini-128k'
# llm_name = 'lllyasviel/omost-dolphin-2.9-llama3-8b'
llm_name = 'lllyasviel/omost-llama-3-8b'

# Model and tokenizer are fetched with the same (possibly None) access token.
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name, torch_dtype="auto", token=HF_TOKEN, device_map="auto", trust_remote_code=True)

llm_tokenizer = AutoTokenizer.from_pretrained(llm_name, token=HF_TOKEN)
def pytorch2numpy(imgs):
    """Convert each tensor in *imgs* to a uint8 NumPy array.

    For every item the first axis is moved to the last position, values are
    mapped with ``x * 127.5 + 127.5``, clipped to ``[0, 255]`` and cast to
    ``np.uint8`` (the cast truncates the fractional part).

    Returns a list with one array per item of *imgs*.
    """
    def to_uint8(tensor):
        scaled = tensor.movedim(0, -1) * 127.5 + 127.5
        array = scaled.detach().float().cpu().numpy()
        return array.clip(0, 255).astype(np.uint8)

    return [to_uint8(tensor) for tensor in imgs]
def numpy2pytorch(imgs):
    """Stack the arrays in *imgs* into one float tensor.

    The arrays are stacked along a new leading axis, converted to float,
    mapped with ``x / 127.5 - 1.0`` and the last axis is moved to position 1.
    """
    batch = np.stack(imgs, axis=0)
    tensor = torch.from_numpy(batch).float()
    tensor = tensor / 127.5 - 1.0
    return tensor.movedim(-1, 1)
def resize_without_crop(image, target_width, target_height):
    """Resize *image* to exactly ``(target_width, target_height)`` with LANCZOS.

    *image* is passed to ``Image.fromarray``; the resized picture is returned
    as a new NumPy array. Nothing is cropped, so the aspect ratio changes when
    the target ratio differs from the source ratio.
    """
    target_size = (target_width, target_height)
    resized = Image.fromarray(image).resize(target_size, Image.LANCZOS)
    return np.array(resized)
def chat_fn(message: str, history: list, seed: int, temperature: float, top_p: float, max_new_tokens: int):
    """Stream the LLM reply to *message* given the previous chat *history*.

    This is a generator, so no return annotation of ``str`` applies.

    Args:
        message: New user message.
        history: Iterable of ``(user, assistant)`` pairs. Only pairs where both
            entries are non-empty strings are replayed to the model. ``None``
            is treated as an empty history.
        seed: Seed applied to NumPy's and PyTorch's global RNGs.
        temperature: Sampling temperature; ``0`` switches to ``do_sample=False``.
        top_p: Nucleus-sampling value forwarded to ``generate``.
        max_new_tokens: Generation length limit forwarded to ``generate``.

    Yields:
        ``(text_so_far, None)`` tuples, where ``text_so_far`` is the
        concatenation of every chunk streamed so far.
    """
    print('Chat begin:', message)
    time_stamp = time.time()

    np.random.seed(int(seed))
    torch.manual_seed(int(seed))

    conversation = [{"role": "system", "content": omost_canvas.system_prompt}]

    for user, assistant in history or []:
        # Skip entries that are not plain, non-empty text (e.g. image tuples or None).
        if isinstance(user, str) and isinstance(assistant, str) and user and assistant:
            conversation.append({"role": "user", "content": user})
            conversation.append({"role": "assistant", "content": assistant})

    conversation.append({"role": "user", "content": message})

    input_ids = llm_tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True).to(llm_model.device)

    streamer = TextIteratorStreamer(llm_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    def interactive_stopping_criteria(*args, **kwargs) -> bool:
        # Generation stops once `user_interrupted` is set on the streamer.
        # Nothing in this function sets the flag (it yields None in the
        # interrupter slot), so this only fires if external code sets it.
        if getattr(streamer, 'user_interrupted', False):
            print('User stopped generation:', message)
            return True
        return False

    stopping_criteria = StoppingCriteriaList([interactive_stopping_criteria])

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
        max_new_tokens=max_new_tokens,
        # Temperature 0 means greedy decoding instead of sampling.
        do_sample=temperature != 0,
        temperature=temperature,
        top_p=top_p,
    )

    # generate() runs in a background thread and feeds the streamer, which is
    # consumed below.
    Thread(target=llm_model.generate, kwargs=generate_kwargs).start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs), None

    print(f'Chat end at {time.time() - time_stamp:.2f} seconds:', message)
def post_chat(history):
    """Parse the last assistant message into canvas outputs after a chat turn.

    Args:
        history: Chat history as ``(user, assistant)`` pairs, or ``None``/empty.

    Returns:
        A 3-tuple ``(canvas_outputs, render_update, undo_update)``:
        ``canvas_outputs`` is the result of ``Canvas.process()`` or ``None``
        when parsing failed or there was no history; ``render_update`` sets
        ``visible`` to whether ``canvas_outputs`` is available; ``undo_update``
        sets ``interactive`` to whether any history is left.
    """
    canvas_outputs = None

    try:
        if history:
            # Keep only pure-text turns; other entries (e.g. image tuples) are dropped.
            history = [
                (user, assistant) for user, assistant in history
                if isinstance(user, str) and isinstance(assistant, str)
            ]
            last_assistant = history[-1][1] if len(history) > 0 else None
            canvas = omost_canvas.Canvas.from_bot_response(last_assistant)
            canvas_outputs = canvas.process()
    except Exception as e:
        # Best effort: a reply that is not a valid canvas only hides the render button.
        print('Last assistant response is not valid canvas:', e)

    # bool() instead of len(): `history` may be None here, and len(None) would
    # raise outside the try block above.
    has_history = bool(history)

    return (
        canvas_outputs,
        gr.update(visible=canvas_outputs is not None),
        gr.update(interactive=has_history),
    )
def diffusion_fn(chatbot, canvas_outputs, num_samples, seed, image_width, image_height,
                 highres_scale, steps, cfg, highres_steps, highres_denoise, negative_prompt):
    """Render images for *canvas_outputs* and append them to the chat history.

    Args:
        chatbot: Current chat history; ``None`` is treated as empty. It is not
            mutated.
        canvas_outputs: Canvas data passed to ``pipeline.all_conds_from_canvas``.
        num_samples: Number of images to generate (cast to ``int``).
        seed: Seed for the torch generator (cast to ``int``).
        image_width, image_height: Requested size; rounded down to a multiple of 64.
        highres_scale: When greater than ``1.0 + 0.05`` a second pass is run on
            an upscaled version of the first result.
        steps: Sampling steps of the first pass.
        cfg: Guidance scale used by both passes.
        highres_steps: Sampling steps of the second pass.
        highres_denoise: ``strength`` of the second pass.
        negative_prompt: Negative prompt text.

    Returns:
        A new history list with one ``(None, (image_path, 'image'))`` entry per
        generated image; the PNG files are written to ``gradio_temp_dir``.
    """
    use_initial_latent = False
    eps = 0.05

    # UI widgets may deliver numbers as floats; tensor sizes and step counts need ints.
    num_samples = int(num_samples)
    image_width, image_height = int(image_width // 64) * 64, int(image_height // 64) * 64

    rng = torch.Generator(unet.device).manual_seed(int(seed))

    # Pure inference: no autograd graph is needed for the VAE or the sampler.
    with torch.no_grad():
        positive_cond, positive_pooler, negative_cond, negative_pooler = pipeline.all_conds_from_canvas(
            canvas_outputs, negative_prompt)

        def sample_and_decode(start_latent, strength, num_steps):
            # Run one sampling pass and decode the resulting latents to pixels.
            latents = pipeline(
                initial_latent=start_latent,
                strength=strength,
                num_inference_steps=int(num_steps),
                batch_size=num_samples,
                prompt_embeds=positive_cond,
                negative_prompt_embeds=negative_cond,
                pooled_prompt_embeds=positive_pooler,
                negative_pooled_prompt_embeds=negative_pooler,
                generator=rng,
                guidance_scale=float(cfg),
            ).images
            latents = latents.to(dtype=vae.dtype, device=vae.device) / vae.config.scaling_factor
            return vae.decode(latents).sample

        if use_initial_latent:
            # Disabled by the flag above: start from a blurred encoding of the canvas colors.
            initial_latent = torch.from_numpy(canvas_outputs['initial_latent'])[None].movedim(-1, 1) / 127.5 - 1.0
            initial_latent_blur = 40
            initial_latent = torch.nn.functional.avg_pool2d(
                torch.nn.functional.pad(initial_latent, (initial_latent_blur,) * 4, mode='reflect'),
                kernel_size=(initial_latent_blur * 2 + 1,) * 2, stride=(1, 1))
            initial_latent = torch.nn.functional.interpolate(initial_latent, (image_height, image_width))
            initial_latent = initial_latent.to(dtype=vae.dtype, device=vae.device)
            initial_latent = vae.encode(initial_latent).latent_dist.mode() * vae.config.scaling_factor
        else:
            initial_latent = torch.zeros(
                size=(num_samples, 4, image_height // 8, image_width // 8), dtype=torch.float32)

        initial_latent = initial_latent.to(dtype=unet.dtype, device=unet.device)

        pixels = sample_and_decode(initial_latent, 1.0, steps)
        H, W = pixels.shape[-2:]
        pixels = pytorch2numpy(pixels)

        if highres_scale > 1.0 + eps:
            # High-res fix: upscale in pixel space (sizes snapped to multiples
            # of 64), re-encode, then denoise again with reduced strength.
            target_width = int(round(W * highres_scale / 64.0) * 64)
            target_height = int(round(H * highres_scale / 64.0) * 64)
            pixels = [
                resize_without_crop(image=p, target_width=target_width, target_height=target_height)
                for p in pixels
            ]

            pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
            latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
            latents = latents.to(device=unet.device, dtype=unet.dtype)

            pixels = sample_and_decode(latents, highres_denoise, highres_steps)
            pixels = pytorch2numpy(pixels)

    # Work on a copy so the caller's history object is left untouched.
    chatbot = list(chatbot or [])
    for i, image_array in enumerate(pixels):
        unique_hex = uuid.uuid4().hex
        image_path = os.path.join(gradio_temp_dir, f"{unique_hex}_{i}.png")
        Image.fromarray(image_array).save(image_path)
        chatbot.append((None, (image_path, 'image')))

    return chatbot
# Custom CSS injected into the Gradio page.
css = '''
code {white-space: pre-wrap !important;}
.gradio-container {max-width: none !important;}
.outer_parent {flex: 1;}
.inner_parent {flex: 1;}
footer {display: none !important; visibility: hidden !important;}
.translucent {display: none !important; visibility: hidden !important;}
'''

from gradio.themes.utils import colors

# NOTE(review): the source reached review without indentation, so the nesting
# of the layout containers below is reconstructed -- confirm against the
# deployed layout.
with gr.Blocks(
    fill_height=True,
    css=css,
    theme=gr.themes.Default(primary_hue=colors.blue, secondary_hue=colors.cyan, neutral_hue=colors.gray),
) as demo:
    with gr.Row(elem_classes='outer_parent'):
        # Left column: chat controls and generation settings.
        with gr.Column(scale=25):
            with gr.Row():
                clear_btn = gr.Button("➕ New Chat", variant="secondary", size="sm", min_width=60)
                retry_btn = gr.Button("Retry", variant="secondary", size="sm", min_width=60, visible=False)
                undo_btn = gr.Button("✏️️ Edit Last Input", variant="secondary", size="sm", min_width=60, interactive=False)

            seed = gr.Number(label="Random Seed", value=123456, precision=0)

            with gr.Accordion(open=True, label='Language Model'):
                with gr.Group():
                    with gr.Row():
                        temperature = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.6, label="Temperature")
                        top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Top P")
                    max_new_tokens = gr.Slider(minimum=128, maximum=4096, step=1, value=4096, label="Max New Tokens")

            with gr.Accordion(open=True, label='Image Diffusion Model'):
                with gr.Group():
                    with gr.Row():
                        image_width = gr.Slider(label="Image Width", minimum=256, maximum=2048, value=896, step=64)
                        image_height = gr.Slider(label="Image Height", minimum=256, maximum=2048, value=1152, step=64)

                    with gr.Row():
                        num_samples = gr.Slider(label="Image Number", minimum=1, maximum=12, value=1, step=1)
                        steps = gr.Slider(label="Sampling Steps", minimum=1, maximum=100, value=25, step=1)

            with gr.Accordion(open=False, label='Advanced'):
                cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=5.0, step=0.01)
                highres_scale = gr.Slider(label="HR-fix Scale (\"1\" is disabled)", minimum=1.0, maximum=2.0, value=1.0, step=0.01)
                highres_steps = gr.Slider(label="Highres Fix Steps", minimum=1, maximum=100, value=20, step=1)
                highres_denoise = gr.Slider(label="Highres Fix Denoise", minimum=0.1, maximum=1.0, value=0.4, step=0.01)
                n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')

            # Hidden at start; post_chat's second output updates its visibility.
            render_button = gr.Button("Render the Image!", size='lg', variant="primary", visible=False)

            examples = gr.Dataset(
                samples=[
                    ['generate an image of the fierce battle of warriors and the dragon'],
                    ['change the dragon to a dinosaur'],
                ],
                components=[gr.Textbox(visible=False)],
                label='Quick Prompts',
            )

            with gr.Row():
                gr.Markdown("Omost: converting LLM's coding capability to image compositing capability.")

            with gr.Row():
                gr.Markdown("Local version (8GB VRAM): https://github.com/lllyasviel/Omost")

            # with gr.Row():
            #     gr.Markdown("Hint: You can [duplicate this space](https://huggingface.co/spaces/lllyasviel/Omost?duplicate=true) to your private account to bypass the waiting queue.")

        # Right column: the chat itself.
        with gr.Column(scale=75, elem_classes='inner_parent'):
            canvas_state = gr.State(None)
            chatbot = gr.Chatbot(label='Omost', scale=1, show_copy_button=True, layout="panel", render=False)
            chatInterface = ChatInterface(
                fn=chat_fn,
                post_fn=post_chat,
                post_fn_kwargs=dict(inputs=[chatbot], outputs=[canvas_state, render_button, undo_btn]),
                pre_fn=lambda: gr.update(visible=False),
                pre_fn_kwargs=dict(outputs=[render_button]),
                chatbot=chatbot,
                retry_btn=retry_btn,
                undo_btn=undo_btn,
                clear_btn=clear_btn,
                additional_inputs=[seed, temperature, top_p, max_new_tokens],
                examples=examples,
                show_stop_button=False,
            )

    # Render the images, then copy the updated chatbot value into the
    # interface's chatbot_state.
    render_button.click(
        fn=diffusion_fn,
        inputs=[
            chatInterface.chatbot, canvas_state,
            num_samples, seed, image_width, image_height, highres_scale,
            steps, cfg, highres_steps, highres_denoise, n_prompt,
        ],
        outputs=[chatInterface.chatbot],
    ).then(
        fn=lambda x: x,
        inputs=[chatInterface.chatbot],
        outputs=[chatInterface.chatbot_state],
    )
if __name__ == "__main__":
    # Script entry point: call queue() on the app, then launch the result.
    queued_app = demo.queue()
    queued_app.launch()