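"""Gradio demo for the λ-ECLIPSE prior: subject-driven image generation with the
Kandinsky 2.2 decoder."""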
from __future__ import annotations

import os
from typing import Any

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from transformers import (
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from src.pipelines.pipeline_kandinsky_subject_prior import KandinskyPriorPipeline
from src.priors.lambda_prior_transformer import PriorTransformer


class Model:
    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Frozen CLIP text encoder and tokenizer used to build the text conditioning.
        self.text_encoder = (
            CLIPTextModelWithProjection.from_pretrained(
                "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
                projection_dim=1280,
                torch_dtype=torch.float16,
            )
            .eval()
            .requires_grad_(False)
        ).to(self.device)

        self.tokenizer = CLIPTokenizer.from_pretrained(
            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
        )

        # λ-ECLIPSE prior: maps text + subject-image features to an image embedding.
        prior = PriorTransformer.from_pretrained(
            "ECLIPSE-Community/Lambda-ECLIPSE-Prior-v1.0",
            torch_dtype=torch.float16,
        )

        self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-prior",
            prior=prior,
            torch_dtype=torch.float16,
        ).to(self.device)

        # Kandinsky 2.2 decoder: turns the predicted image embedding into pixels.
        self.pipe = DiffusionPipeline.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
        ).to(self.device)
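
        # NOTE (assumption): process_data() below uses self.image_processor and
        # self.vision_encoder, but the original snippet never defines them. This is
        # a minimal sketch that reuses the same CLIP bigG checkpoint as the text
        # encoder, loaded in full precision so the processor's float32 pixel values
        # can be fed in directly; swap in the vision encoder the prior was actually
        # trained with.
        self.image_processor = CLIPImageProcessor.from_pretrained(
            "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
        )
        self.vision_encoder = (
            CLIPVisionModelWithProjection.from_pretrained(
                "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
                projection_dim=1280,
            )
            .eval()
            .requires_grad_(False)
        ).to(self.device)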

    def inference(self, raw_data):
        """Run the λ-ECLIPSE prior on the raw subject/prompt data, then decode."""
        image_emb, negative_image_emb = self.pipe_prior(
            raw_data=raw_data,
        ).to_tuple()
        image = self.pipe(
            image_embeds=image_emb,
            negative_image_embeds=negative_image_emb,
            num_inference_steps=50,
            guidance_scale=4.0,
        ).images[0]
        return image

    def process_data(self,
                     image: PIL.Image.Image,
                     keyword: str,
                     image2: PIL.Image.Image,
                     keyword2: str,
                     text: str,
                     ) -> dict[str, Any]:
        """Build CLIP text features in which the keyword tokens are replaced by the
        corresponding subject-image embeddings."""
        print(f"keyword : {keyword}, keyword2 : {keyword2}, prompt : {text}")
        device = torch.device(self.device)
        data: dict[str, Any] = {}
        data['text'] = text

        txt = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        txt_items = {k: v.to(device) for k, v in txt.items()}
        new_feats = self.text_encoder(**txt_items)
        new_last_hidden_states = new_feats.last_hidden_state[0].cpu().numpy()

        # Debug: dump the first subject image to disk for inspection.
        plt.imshow(image)
        plt.title('image')
        plt.savefig('image_testt2.png')
        plt.show()

        # Encode the first subject image and splice its embedding into the positions
        # of the first keyword's tokens.
        mask_img = self.image_processor(image, return_tensors="pt").to(device)
        vision_feats = self.vision_encoder(**mask_img).image_embeds

        entity_tokens = self.tokenizer(keyword)["input_ids"][1:-1]
        for tid in entity_tokens:
            indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
            new_last_hidden_states[indices] = vision_feats[0].cpu().numpy()
            print(indices)

        # Optionally do the same for the second subject.
        if image2 is not None:
            mask_img2 = self.image_processor(image2, return_tensors="pt").to(device)
            vision_feats2 = self.vision_encoder(**mask_img2).image_embeds
            if keyword2 is not None:
                entity_tokens = self.tokenizer(keyword2)["input_ids"][1:-1]
                for tid in entity_tokens:
                    indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
                    new_last_hidden_states[indices] = vision_feats2[0].cpu().numpy()
                    print(indices)

        text_feats = {
            "prompt_embeds": new_feats.text_embeds.to(device),
            "text_encoder_hidden_states": torch.tensor(new_last_hidden_states).unsqueeze(0).to(device),
            "text_mask": txt_items["attention_mask"].to(device),
        }
        return text_feats

    def run(self,
            image: dict[str, PIL.Image.Image],
            keyword: str,
            image2: dict[str, PIL.Image.Image],
            keyword2: str,
            text: str,
            ):
        # gr.ImageEditor returns a dict; "composite" holds the edited (masked) image.
        sub_imgs = [image["composite"]]
        if image2 is not None and image2.get("composite") is not None:
            sub_imgs.append(image2["composite"])
        sub_keywords = [keyword]
        if keyword2:
            sub_keywords.append(keyword2)
        raw_data = {
            "prompt": text,
            "subject_images": sub_imgs,
            "subject_keywords": sub_keywords,
        }
        return self.inference(raw_data)


def create_demo():
    TITLE = '# [λ-Eclipse Demo](https://eclipse-t2i.github.io/Lambda-ECLIPSE/)'

    USAGE = '''To run the demo, you should:
1. Upload your image.
2. <span style='color: red;'>**Upload a masked subject image with white blank space, or whiten out the background manually using the brush tool.**</span>
3. Input a keyword, e.g. "Dog".
4. For multi-subject personalization:
    4-1. Upload another image.
    4-2. Input its keyword, e.g. "Sunglasses".
5. Input a text prompt, such as "A photo of Dog" or "A Dog wearing sunglasses". Please use the same keywords in the prompt.
6. Click the Run button.
'''

    model = Model()

    with gr.Blocks() as demo:
        gr.Markdown(TITLE)
        gr.Markdown(USAGE)
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr.Markdown(
                        'Upload your first masked subject image, or whiten out the area around the subject with the brush tool')
                    image = gr.ImageEditor(
                        label='Input',
                        type='pil',
                        brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
                    keyword = gr.Text(
                        label='Keyword',
                        placeholder='e.g. "Dog", "Goofie"',
                        info='Keyword for the first subject')
                    gr.Markdown(
                        'For multi-subject generation: upload your second masked subject image, or whiten out the area around the subject with the brush tool')
                    image2 = gr.ImageEditor(
                        label='Input',
                        type='pil',
                        brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
                    keyword2 = gr.Text(
                        label='Keyword',
                        placeholder='e.g. "Sunglasses", "Grand Canyon"',
                        info='Keyword for the second subject')
                    prompt = gr.Text(
                        label='Prompt',
                        placeholder='e.g. "A photo of dog", "A dog wearing sunglasses"',
                        info='Keep the keywords used previously in the prompt')

                run_button = gr.Button('Run')

            with gr.Column():
                result = gr.Image(label='Result')

        inputs = [
            image,
            keyword,
            image2,
            keyword2,
            prompt,
        ]

        gr.Examples(
            examples=[[
                os.path.join(os.path.dirname(__file__), "./assets/cat.png"),
                "cat",
                os.path.join(os.path.dirname(__file__), "./assets/blue_sunglasses.png"),
                "glasses",
                "A cat wearing glasses on a snowy field",
            ]],
            inputs=inputs,
            fn=model.run,
            outputs=result,
        )

        run_button.click(fn=model.run, inputs=inputs, outputs=result)

    return demo


if __name__ == '__main__':
    demo = create_demo()
    demo.queue(api_open=False).launch(share=True)