mpatel57's picture
function definition
a440782 verified
from __future__ import annotations
import pathlib
import gradio as gr
import torch
import os
import PIL
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import cv2
import matplotlib.pyplot as plt
from typing import Any
from transformers import (
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
CLIPImageProcessor,
CLIPTokenizer,
)
from transformers import CLIPTokenizer
from src.priors.lambda_prior_transformer import (
PriorTransformer,
) # original huggingface prior transformer without time conditioning
from src.pipelines.pipeline_kandinsky_subject_prior import KandinskyPriorPipeline
from diffusers import DiffusionPipeline
from PIL import Image
import random
import spaces
# Target device and weight precision shared by every model loaded below.
__device__: str = "cuda"
__dtype__: torch.dtype = torch.float16
class Model:
    """Holds the λ-ECLIPSE prior and the Kandinsky 2.2 decoder used by the demo.

    All weights are loaded once in ``__init__`` in ``__dtype__`` precision and
    moved to ``__device__``. ``run`` turns Gradio ImageEditor values into the
    ``raw_data`` dict consumed by the prior pipeline. ``inference`` runs
    prior + decoder and returns a single generated image.
    """

    def __init__(self):
        self.device = __device__
        # Frozen CLIP text tower: eval mode, gradients disabled.
        # NOTE(review): neither text_encoder nor tokenizer is referenced by
        # run/inference in this file -- confirm before removing.
        self.text_encoder = (
            CLIPTextModelWithProjection.from_pretrained(
                "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
                projection_dim=1280,
                torch_dtype=__dtype__,
            )
            .eval()
            .requires_grad_(False)
        ).to(self.device)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
        )
        # The λ-ECLIPSE prior replaces the stock prior inside the Kandinsky
        # 2.2 prior pipeline.
        prior = PriorTransformer.from_pretrained(
            "ECLIPSE-Community/Lambda-ECLIPSE-Prior-v1.0",
            torch_dtype=__dtype__,
        )
        self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior",
            prior=prior,
            torch_dtype=__dtype__,
        ).to(self.device)
        self.pipe = DiffusionPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=__dtype__
        ).to(self.device)

    def inference(self, raw_data, seed):
        """Generate one image from ``raw_data``.

        Args:
            raw_data: dict with ``prompt``, ``subject_images`` and
                ``subject_keywords``, forwarded to the prior pipeline.
            seed: RNG seed; ``None`` draws a random one in [0, 10000000].

        Returns:
            The first image of the decoder output (presumably a PIL image,
            the diffusers default output type -- confirm).
        """
        if seed is None:
            seed = random.randint(0, 10000000)
        # Build the generator on the same device as the pipelines instead of
        # hard-coding "cuda". int() accepts float seeds such as 7.0 from
        # numeric UI widgets.
        generator = torch.Generator(device=self.device).manual_seed(int(seed))
        # The prior maps the prompt plus subject images to CLIP image
        # embeddings (positive, negative).
        image_emb, negative_image_emb = self.pipe_prior(
            raw_data=raw_data,
            generator=generator,
        ).to_tuple()
        # The decoder reuses the same generator, so a seed fixes the whole run.
        image = self.pipe(
            image_embeds=image_emb,
            negative_image_embeds=negative_image_emb,
            num_inference_steps=50,
            guidance_scale=7.5,
            generator=generator,
        ).images[0]
        return image

    def run(
        self,
        image: dict[str, PIL.Image.Image],
        keyword: str,
        image2: dict[str, PIL.Image.Image] | None,
        keyword2: str,
        text: str,
        seed: int | None = None,
    ):
        """Adapt the Gradio inputs into ``raw_data`` and call ``inference``.

        Args:
            image: ImageEditor value for the first subject. Only its
                ``"composite"`` entry is used.
            keyword: keyword naming the first subject.
            image2: optional ImageEditor value for a second subject.
            keyword2: keyword for the second subject. Empty or
                ``"no subject"`` means single-subject generation.
            text: the text prompt.
            seed: forwarded to ``inference``; ``None`` means random.

        Raises:
            gr.Error: if the first subject image is missing.
        """
        if not image or image.get("composite") is None:
            raise gr.Error("Please upload a subject image first.")
        sub_imgs = [image["composite"]]
        sub_keywords = [keyword]
        # An empty ImageEditor still yields a (truthy) dict whose composite
        # is None, so check the composite itself.
        second_img = image2.get("composite") if image2 else None
        # Add the second subject only as a complete (image, keyword) pair so
        # the two lists stay the same length. Presumably the prior pairs them
        # by position -- TODO confirm in KandinskyPriorPipeline.
        if keyword2 and keyword2 != "no subject" and second_img is not None:
            sub_imgs.append(second_img)
            sub_keywords.append(keyword2)
        raw_data = {
            "prompt": text,
            "subject_images": sub_imgs,
            "subject_keywords": sub_keywords,
        }
        return self.inference(raw_data, seed)
# Module-level singleton: weights are loaded once at import time, and
# generate_image reuses this instance for every request.
model: Model = Model()
@spaces.GPU
def generate_image(image, keyword, image2, keyword2, text, seed=None):
    """Gradio callback that generates one image on a GPU worker.

    The UI wires only five components (images, keywords, prompt) and no seed
    widget. ``seed`` therefore needs a default: Gradio fills missing trailing
    arguments from their defaults, whereas a required ``seed`` raised
    ``TypeError`` on every click. ``None`` makes the model draw a random seed.
    """
    return model.run(image, keyword, image2, keyword2, text, seed)
def create_demo():
    """Build (but do not launch) the λ-ECLIPSE Gradio Blocks UI.

    Five inputs are passed positionally to ``generate_image``: first subject
    image/keyword, optional second subject image/keyword, and the prompt. Its
    return value is shown in a single ``gr.Image``.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for queueing
        and launching it.
    """
    # Step numbering fixed (was 1,2,3,4,4-1,4-2,3,4) and the red <span> closed.
    USAGE = """## To run the demo, you should:
1. Upload your image.
2. <span style='color: red;'>**Upload a masked subject image with white blankspace or whiten out manually using brush tool.**</span>
3. Input a keyword, e.g. 'Dog'.
4. For multi-subject personalization:
    - 4-1. Upload another image.
    - 4-2. Input its keyword, e.g. 'Sunglasses'.
5. Input a proper text prompt, such as "A photo of Dog" or "A Dog wearing sunglasses". Please use the same keyword(s) in the prompt.
6. Click the Run button.
"""
    # Example images ship in ./assets next to this file.
    assets_dir = pathlib.Path(__file__).parent / "assets"

    def asset(name):
        """Return the path of a bundled example asset as a string."""
        return str(assets_dir / name)

    with gr.Blocks() as demo:
        gr.HTML(
            """<h1 style="text-align: center;"><b><i>λ-ECLIPSE</i>: Multi-Concept Personalized Text-to-Image Diffusion Models by Leveraging CLIP Latent Space</b></h1>
<h1 style='text-align: center;'><a href='https://eclipse-t2i.github.io/Lambda-ECLIPSE/'>Project Page</a> | <a href='#'>Paper</a> </h1>
<p style="text-align: center; color: red;">Please follow the instructions from here to run it locally: <a href="https://github.com/eclipse-t2i/lambda-eclipse-inference">GitHub Inference Code</a></p>
<a href="https://colab.research.google.com/drive/1VcqzXZmilntec3AsIyzCqlstEhX4Pa1o?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
"""
        )
        gr.Markdown(USAGE)
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown(
                        "Upload your first masked subject image or mask out marginal space"
                    )
                    # White-only brush so users can blank out the background.
                    image = gr.ImageEditor(
                        label="Input",
                        type="pil",
                        brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
                    )
                    keyword = gr.Text(
                        label="Keyword",
                        placeholder='e.g. "Dog", "Goofie"',
                        info="Keyword for first subject",
                    )
                gr.Markdown(
                    "For Multi-Subject generation : Upload your second masked subject image or mask out marginal space"
                )
                image2 = gr.ImageEditor(
                    label="Input",
                    type="pil",
                    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
                )
                keyword2 = gr.Text(
                    label="Keyword",
                    placeholder='e.g. "Sunglasses", "Grand Canyon"',
                    info="Keyword for second subject",
                )
                prompt = gr.Text(
                    label="Prompt",
                    placeholder='e.g. "A photo of dog", "A dog wearing sunglasses"',
                    info="Keep the keywords used previously in the prompt",
                )
                run_button = gr.Button("Run")
            with gr.Column():
                result = gr.Image(label="Result")

        # Order must match generate_image's positional parameters.
        inputs = [
            image,
            keyword,
            image2,
            keyword2,
            prompt,
        ]
        # Single-subject rows pair a blank white image with the
        # "no subject" keyword.
        gr.Examples(
            examples=[
                [asset("luffy.jpg"), "luffy", asset("white.jpg"), "no subject", "luffy holding a sword"],
                [asset("luffy.jpg"), "luffy", asset("white.jpg"), "no subject", "luffy in the living room"],
                [asset("teapot.jpg"), "teapot", asset("white.jpg"), "no subject", "teapot on a cobblestone street"],
                [asset("trex.jpg"), "trex", asset("white.jpg"), "no subject", "trex near a river"],
                [asset("cat.png"), "cat", asset("blue_sunglasses.png"), "glasses", "A cat wearing glasses on a snowy field"],
                [asset("statue.jpg"), "statue", asset("toilet.jpg"), "toilet", "statue sitting on a toilet"],
                [asset("teddy.jpg"), "teddy", asset("luffy_hat.jpg"), "hat", "a teddy wearing the hat at a beach"],
                [asset("chair.jpg"), "chair", asset("table.jpg"), "table", "a chair and table in living room"],
            ],
            inputs=inputs,
            fn=generate_image,
            outputs=result,
        )
        run_button.click(fn=generate_image, inputs=inputs, outputs=result)
    return demo
if __name__ == "__main__":
    # Build the UI, cap the pending-request queue at 20, then serve it.
    demo = create_demo()
    demo.queue(max_size=20)
    demo.launch()