import sys

sys.path.append(".")
sys.path.append("..")

from pathlib import Path
from typing import List, Optional

import gradio as gr
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from huggingface_hub import snapshot_download
from transformers import CLIPTokenizer

import constants
from checkpoint_handler import CheckpointHandler
from models.neti_clip_text_encoder import NeTICLIPTextModel
from models.xti_attention_processor import XTIAttenProc
from prompt_manager import PromptManager
from scripts.inference import run_inference

DESCRIPTION = '''
# A Neural Space-Time Representation for Text-to-Image Personalization

This is a demo for our paper: "A Neural Space-Time Representation for Text-to-Image Personalization".
Project page and code are available here.
We introduce a new text-conditioning latent space P* that is dependent on both the denoising process timestep and the U-Net layers. This space-time representation is learned implicitly via a small mapping network.
Here, you can generate images using one of the concepts trained in our paper. Simply select your concept and random seed.
You can also choose different truncation values to play with the reconstruction vs. editability of the concept.

'''

# Placeholder token for each concept. At inference time, every "*" in the user prompt is
# replaced with the concept's placeholder token.
# NOTE: the tokens below assume the '<concept_name>' convention used for the released NeTI mappers.
CONCEPT_TO_PLACEHOLDER = {
    'barn': '<barn>',
    'cat': '<cat>',
    'clock': '<clock>',
    'colorful_teapot': '<colorful_teapot>',
    'dangling_child': '<dangling_child>',
    'dog': '<dog>',
    'elephant': '<elephant>',
    'fat_stone_bird': '<fat_stone_bird>',
    'headless_statue': '<headless_statue>',
    'lecun': '<lecun>',
    'maeve': '<maeve>',
    'metal_bird': '<metal_bird>',
    'mugs_skulls': '<mugs_skulls>',
    'rainbow_cat': '<rainbow_cat>',
    'red_bowl': '<red_bowl>',
    'teddybear': '<teddybear>',
    'tortoise_plushy': '<tortoise_plushy>',
    'wooden_pot': '<wooden_pot>'
}

MODELS_PATH = Path('./trained_models')
MODELS_PATH.mkdir(parents=True, exist_ok=True)


def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> StableDiffusionPipeline:
    """Load a Stable Diffusion pipeline with the NeTI text encoder and XTI attention processor."""
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        torch_dtype=torch_dtype,
    )
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        text_encoder=text_encoder,
        tokenizer=tokenizer
    ).to("cuda")
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device)
    pipeline.unet.set_attn_processor(XTIAttenProc())
    return pipeline


def get_possible_concepts() -> List[str]:
    objects = [x for x in MODELS_PATH.iterdir() if x.is_dir()]
    return [x.name for x in objects]


def load_sd_and_all_tokens():
    """Download all trained NeTI models and load the mapper and learned token for each concept."""
    mappers = {}
    pipeline = load_stable_diffusion_model(pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4")
    print("Downloading all models from HF Hub...")
    snapshot_download(repo_id="neural-ti/NeTI", local_dir='./trained_models')
    print("Done.")
    concepts = get_possible_concepts()
    for concept in concepts:
        print(f"Loading model for concept: {concept}")
        learned_embeds_path = MODELS_PATH / concept / f"{concept}-learned_embeds.bin"
        mapper_path = MODELS_PATH / concept / f"{concept}-mapper.pt"
        train_cfg, mapper = CheckpointHandler.load_mapper(mapper_path=mapper_path)
        placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip(
            learned_embeds_path=learned_embeds_path,
            text_encoder=pipeline.text_encoder,
            tokenizer=pipeline.tokenizer
        )
        mappers[concept] = {
            "mapper": mapper,
            "placeholder_token": placeholder_token,
            "placeholder_token_id": placeholder_token_id
        }
    return mappers, pipeline


mappers, pipeline = load_sd_and_all_tokens()


def main_pipeline(concept_name: str,
                  prompt_input: str,
                  seed: int,
                  use_truncation: bool = False,
                  truncation_idx: Optional[int] = None) -> List[Image.Image]:
    # Swap in the selected concept's mapper and build per-timestep, per-layer prompt embeddings.
    pipeline.text_encoder.text_model.embeddings.set_mapper(mappers[concept_name]["mapper"])
    placeholder_token = mappers[concept_name]["placeholder_token"]
    placeholder_token_id = mappers[concept_name]["placeholder_token_id"]
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=placeholder_token,
                                   placeholder_token_id=placeholder_token_id,
                                   torch_dtype=torch.float16)
    image = run_inference(prompt=prompt_input.replace("*", CONCEPT_TO_PLACEHOLDER[concept_name]),
                          pipeline=pipeline,
                          prompt_manager=prompt_manager,
                          seeds=[int(seed)],
                          num_images_per_prompt=1,
                          truncation_idx=truncation_idx if use_truncation else None)
    return [image]
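# Minimal usage sketch (not executed by the demo): `main_pipeline` can also be called
# directly, e.g. from a notebook, once the models above have been loaded. The concept
# name, prompt, seed, and truncation index below are illustrative values only, and the
# `.save(...)` call assumes `run_inference` returns PIL images, as the type hints suggest.
#
#   images = main_pipeline(concept_name="cat",
#                          prompt_input="A photo of * on the beach",
#                          seed=42,
#                          use_truncation=True,
#                          truncation_idx=16)
#   images[0].save("cat_on_the_beach.png")
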
info="Input prompt with placeholder for concept. " "Please use * to specify the concept.") random_seed = gr.Number(value=42, label="Random seed", precision=0) use_truncation = gr.Checkbox(label="Use inference-time dropout", info="Whether to use our dropout technique when computing the concept " "embeddings.") truncation_idx = gr.Slider(8, 128, label="Truncation index", info="If using truncation, which index to truncate from. Lower numbers tend to " "result in more editable images, but at the cost of reconstruction.") run_button = gr.Button('Generate') with gr.Column(): result = gr.Gallery(label='Result') inputs = [concept, prompt, random_seed, use_truncation, truncation_idx] outputs = [result] run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs) with gr.Row(): examples = [ ["maeve", "A photo of * swimming in the ocean", 5196, True, 16], ["dangling_child", "A photo of * in Times Square", 3552126062741487430, False, 8], ["teddybear", "A photo of * at his graduation ceremony after finishing his PhD", 263, True, 32], ["red_bowl", "A * vase filled with flowers", 13491504810502930872, False, 8], ["metal_bird", "* in a comic book", 1028, True, 24], ["fat_stone_bird", "A movie poster of The Rock, featuring * about on Godzilla", 7393181316156044422, True, 64], ] gr.Examples(examples=examples, inputs=[concept, prompt, random_seed, use_truncation, truncation_idx], outputs=[result], fn=main_pipeline, cache_examples=True) demo.queue(max_size=50).launch(share=False)