# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
import copy
import string
import random
from typing import Optional
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
DiffusionPipeline,
LCMScheduler,
)
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont
from fastchat.model import load_model, get_conversation_template
from transformers import AutoTokenizer, AutoModelForCausalLM
from cog import BasePredictor, Input, Path, BaseModel
alphabet = (
string.digits
+ string.ascii_lowercase
+ string.ascii_uppercase
+ string.punctuation
+ " "
) # len(alphabet) = 95
"""alphabet
0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~
"""
font_layout = ImageFont.truetype("./Arial.ttf", 16)
class ModelOutput(BaseModel):
output_images: list[Path]
composed_prompt: str
class Predictor(BasePredictor):
def setup(self) -> None:
"""Load the model into memory to make running multiple predictions efficient"""
cache_dir = "model_cache"
local_files_only = True # assumes the weights are already saved in cache_dir; set to False to allow downloading
self.m1_model_path = "JingyeChen22/textdiffuser2_layout_planner"
self.m1_tokenizer = AutoTokenizer.from_pretrained(
self.m1_model_path,
use_fast=False,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
self.m1_model = AutoModelForCausalLM.from_pretrained(
self.m1_model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
cache_dir=cache_dir,
local_files_only=local_files_only,
).cuda()
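# The layout planner (m1) is a causal language model that, given the image caption,
# proposes where each keyword should be drawn; the diffusion components loaded below
# then render the image conditioned on those layout tokens.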
self.text_encoder = (
CLIPTextModel.from_pretrained(
"JingyeChen22/textdiffuser2-full-ft",
subfolder="text_encoder",
cache_dir=cache_dir,
local_files_only=local_files_only,
)
.cuda()
.half()
)
self.tokenizer = CLIPTokenizer.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="tokenizer",
cache_dir=cache_dir,
local_files_only=local_files_only,
)
#### additional tokens are introduced, including coordinate tokens and character tokens
print("***************")
print(f"tokenizer size: {len(self.tokenizer)}")
for i in range(520):
self.tokenizer.add_tokens(["l" + str(i)]) # left
self.tokenizer.add_tokens(["t" + str(i)]) # top
self.tokenizer.add_tokens(["r" + str(i)]) # width
self.tokenizer.add_tokens(["b" + str(i)]) # height
for c in alphabet:
self.tokenizer.add_tokens([f"[{c}]"])
print(f"new tokenizer size: {len(self.tokenizer)}")
print("***************")
self.vae = (
AutoencoderKL.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="vae",
cache_dir=cache_dir,
local_files_only=local_files_only,
)
.half()
.cuda()
)
self.unet = (
UNet2DConditionModel.from_pretrained(
"JingyeChen22/textdiffuser2-full-ft",
subfolder="unet",
cache_dir=cache_dir,
local_files_only=local_files_only,
)
.half()
.cuda()
)
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
self.scheduler = DDPMScheduler.from_pretrained(
"runwayml/stable-diffusion-v1-5",
subfolder="scheduler",
cache_dir=cache_dir,
local_files_only=local_files_only,
)
#### load lcm components
self.pipe = DiffusionPipeline.from_pretrained(
"lambdalabs/sd-pokemon-diffusers",
unet=copy.deepcopy(self.unet),
tokenizer=self.tokenizer,
text_encoder=copy.deepcopy(self.text_encoder),
torch_dtype=torch.float16,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
self.pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
self.pipe.to(device="cuda")
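# The pipeline above reuses the fine-tuned UNet, tokenizer, and text encoder (the base
# checkpoint only supplies the remaining components), swaps in an LCM scheduler, and loads
# the LCM-LoRA weights, enabling few-step sampling (around 4 steps) with guidance_scale
# fixed to 1 in the LCM branch of predict() below.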
def predict(
self,
prompt: str = Input(
description="Input Prompt. You can let language model automatically identify keywords, or provide them below.",
default="A beautiful city skyline stamp of Shanghai",
),
keywords: Optional[str] = Input(
description="(Optional) Keywords. Should be separated by / (e.g., keyword1/keyword2/...).",
default=None,
),
positive_prompt: str = Input(
description="(Optional) Positive prompt.",
default=", digital art, very detailed, fantasy, high definition, cinematic light, dnd, trending on artstation",
),
use_lcm: bool = Input(
description="Use Latent Consistent Model.", default=False
),
generate_natural_image: bool = Input(
description="If set to True, the text position and content info will not be incorporated.",
default=False,
),
num_images: int = Input(
description="Number of Output images.", default=1, ge=1, le=4
),
num_inference_steps: int = Input(
description="Number of denoising steps. You may decease the step to 4 when using LCM.",
ge=1,
le=50,
default=20,
),
guidance_scale: float = Input(
description="Scale for classifier-free guidance. The scale is set to 7.5 by default. When using LCM, guidance_scale is set to 1.",
ge=1,
le=20,
default=7.5,
),
temperature: float = Input(
description="Control the diversity of layout planner. Higher value indicates more diversity.",
ge=0.1,
le=2,
default=1.4,
),
) -> ModelOutput:
"""Run a single prediction on the model"""
if positive_prompt and positive_prompt.strip():
prompt += positive_prompt
with torch.no_grad():
user_prompt = prompt
if generate_natural_image:
composed_prompt = user_prompt
prompt = self.tokenizer.encode(user_prompt)
else:
if keywords is None or len(keywords.strip()) == 0:
template = f"Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. All keywords are included in the caption. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {user_prompt}"
else:
keywords = keywords.split("/")
keywords = [i.strip() for i in keywords]
template = f"Given a prompt that will be used to generate an image, plan the layout of visual text for the image. The size of the image is 128x128. Therefore, all properties of the positions should not exceed 128, including the coordinates of top, left, right, and bottom. In addition, we also provide all keywords at random order for reference. You dont need to specify the details of font styles. At each line, the format should be keyword left, top, right, bottom. So let us begin. Prompt: {prompt}. Keywords: {str(keywords)}"
msg = template
conv = get_conversation_template(self.m1_model_path)
conv.append_message(conv.roles[0], msg)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
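# get_conversation_template wraps the instruction in the chat format expected by the
# layout-planner checkpoint; conv.get_prompt() returns the fully formatted prompt string.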
inputs = self.m1_tokenizer([prompt], return_token_type_ids=False)
inputs = {k: torch.tensor(v).to("cuda") for k, v in inputs.items()}
output_ids = self.m1_model.generate(
**inputs,
do_sample=True,
temperature=temperature,
repetition_penalty=1.0,
max_new_tokens=512,
)
if self.m1_model.config.is_encoder_decoder:
output_ids = output_ids[0]
else:
output_ids = output_ids[0][len(inputs["input_ids"][0]) :]
outputs = self.m1_tokenizer.decode(
output_ids,
skip_special_tokens=True,
spaces_between_special_tokens=False,
)
print(f"[{conv.roles[0]}]\n{msg}")
print(f"[{conv.roles[1]}]\n{outputs}")
ocrs = outputs.split("\n")
current_ocr = ocrs
ocr_ids = []
print("user_prompt", user_prompt)
print("current_ocr", current_ocr)
for ocr in current_ocr:
ocr = ocr.strip()
if len(ocr) == 0 or "###" in ocr or ".com" in ocr:
continue
items = ocr.split()
pred = " ".join(items[:-1])
box = items[-1]
l, t, r, b = box.split(",")
l, t, r, b = int(l), int(t), int(r), int(b)
ocr_ids.extend(
["l" + str(l), "t" + str(t), "r" + str(r), "b" + str(b)]
)
char_list = list(pred)
char_list = [f"[{i}]" for i in char_list]
ocr_ids.extend(char_list)
ocr_ids.append(self.tokenizer.eos_token_id)
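# For a hypothetical line "Shanghai 30,40,100,60", the loop above contributes
# ["l30", "t40", "r100", "b60", "[S]", "[h]", ..., "[i]"]; a single EOS token id is
# appended after all lines.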
caption_ids = (
self.tokenizer(user_prompt, truncation=True, return_tensors="pt")
.input_ids[0]
.tolist()
)
try:
ocr_ids = self.tokenizer.encode(ocr_ids)
prompt = caption_ids + ocr_ids
except Exception: # fall back to the caption-only prompt if the layout tokens fail to encode
prompt = caption_ids
user_prompt = self.tokenizer.decode(prompt)
composed_prompt = self.tokenizer.decode(prompt)
prompt = prompt[:77]
while len(prompt) < 77:
prompt.append(self.tokenizer.pad_token_id)
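# CLIP's text encoder has a fixed context length of 77 tokens, so the composed prompt is
# truncated and then right-padded with pad tokens to exactly 77 ids.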
if not use_lcm:
prompts_cond = prompt
prompts_nocond = [self.tokenizer.pad_token_id] * 77
prompts_cond = [prompts_cond] * num_images
prompts_nocond = [prompts_nocond] * num_images
prompts_cond = torch.Tensor(prompts_cond).long().cuda()
prompts_nocond = torch.Tensor(prompts_nocond).long().cuda()
scheduler = self.scheduler
scheduler.set_timesteps(num_inference_steps)
noise = torch.randn((num_images, 4, 64, 64)).to("cuda").half()
input = noise
encoder_hidden_states_cond = self.text_encoder(prompts_cond)[0].half()
encoder_hidden_states_nocond = self.text_encoder(prompts_nocond)[
0
].half()
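# Classifier-free guidance: the guided prediction computed in the loop below is
#   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond),
# using one conditional and one unconditional (pad-only prompt) UNet pass per step.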
for t in tqdm(scheduler.timesteps):
with torch.no_grad(): # classifier free guidance
noise_pred_cond = self.unet(
sample=input,
timestep=t,
encoder_hidden_states=encoder_hidden_states_cond[
:num_images
],
).sample # b, 4, 64, 64
noise_pred_uncond = self.unet(
sample=input,
timestep=t,
encoder_hidden_states=encoder_hidden_states_nocond[
:num_images
],
).sample # b, 4, 64, 64
noisy_residual = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
) # b, 4, 64, 64
input = scheduler.step(noisy_residual, t, input).prev_sample
del noise_pred_cond
del noise_pred_uncond
torch.cuda.empty_cache()
# decode
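# latents are rescaled by 1 / vae.config.scaling_factor (0.18215 for SD v1.5) before decoding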
input = 1 / self.vae.config.scaling_factor * input
images = self.vae.decode(input, return_dict=False)[0]
width, height = 512, 512
results = []
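# Decoded images come out in roughly [-1, 1]; they are mapped to [0, 1], converted to PIL,
# and also pasted into a 2x2 preview grid (the individual images are what gets saved below).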
new_image = Image.new("RGB", (2 * width, 2 * height))
for index, image in enumerate(images.cpu().float()):
image = (image / 2 + 0.5).clamp(0, 1).unsqueeze(0)
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
image = Image.fromarray(
(image * 255).round().astype("uint8")
).convert("RGB")
results.append(image)
row = index // 2
col = index % 2
new_image.paste(image, (col * width, row * height))
else:
generator = torch.Generator(device=self.pipe.device).manual_seed(
random.randint(0, 1000)
)
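# With the LCM-LoRA scheduler, a handful of denoising steps (e.g. 4) is usually enough;
# guidance_scale is fixed to 1 here (classifier-free guidance effectively disabled),
# matching the input description above.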
results = self.pipe(
prompt=user_prompt,
generator=generator,
num_inference_steps=num_inference_steps,
guidance_scale=1,
num_images_per_prompt=num_images,
).images
torch.cuda.empty_cache()
output_paths = []
for i, sample in enumerate(results):
output_path = f"/tmp/out-{i}.png"
sample.save(output_path)
output_paths.append(Path(output_path))
return ModelOutput(
output_images=output_paths,
composed_prompt=composed_prompt,
)
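# Example invocation with the Cog CLI (hypothetical values):
#   cog predict -i prompt="A beautiful city skyline stamp of Shanghai" \
#               -i keywords="Shanghai" -i num_inference_steps=20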