|
""" |
|
AnyText: Multilingual Visual Text Generation And Editing |
|
Paper: https://arxiv.org/abs/2311.03054 |
|
Code: https://github.com/tyxsspa/AnyText |
|
Copyright (c) Alibaba, Inc. and its affiliates. |
|
""" |
|
import os |
|
from pathlib import Path |
|
|
|
from iopaint.model.utils import set_seed |
|
from safetensors.torch import load_file |
|
|
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" |
|
import torch |
|
import re |
|
import numpy as np |
|
import cv2 |
|
import einops |
|
from PIL import ImageFont |
|
from iopaint.model.anytext.cldm.model import create_model, load_state_dict |
|
from iopaint.model.anytext.cldm.ddim_hacked import DDIMSampler |
|
from iopaint.model.anytext.utils import ( |
|
check_channels, |
|
draw_glyph, |
|
draw_glyph2, |
|
) |
|
|
|
|
|
# Upper bound on text boxes; appears unused by the code in this file.
BBOX_MAX_NUM = 8
# Token substituted into the prompt for every quoted text line.
PLACE_HOLDER = "*"
# Text lines longer than this are truncated before their glyphs are rendered.
max_chars = 20

# The AnyText SD1.5 model config lives next to this module.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
ANYTEXT_CFG = os.path.join(_MODULE_DIR, "anytext_sd15.yaml")
|
|
|
|
|
def check_limits(tensor):
    """Tell whether ``tensor`` holds values outside the finite float16 range.

    Args:
        tensor: a torch tensor compared element-wise against the float16 limits.

    Returns:
        A 0-d boolean tensor (the ``or`` of the below-min and above-max
        checks), truthy when any element is below ``finfo(float16).min`` or
        above ``finfo(float16).max``.
    """
    fp16 = torch.finfo(torch.float16)
    return (tensor < fp16.min).any() or (tensor > fp16.max).any()
|
|
|
|
|
class AnyTextPipeline:
    """Inference wrapper around an AnyText model for visual text editing.

    The constructor builds the model from ``ANYTEXT_CFG``, loads checkpoint
    weights and creates a DDIM sampler. Calling the instance renders the
    quoted text lines of a prompt into the marked regions of an image.
    """

    def __init__(self, ckpt_path, font_path, device, use_fp16=True):
        """Build the model, load its weights and create the sampler.

        Args:
            ckpt_path: checkpoint file. ``.safetensors`` files are read with
                ``safetensors.torch.load_file``; anything else goes through
                ``load_state_dict``. Weights are loaded with ``strict=False``.
            font_path: TrueType font used to render glyph images (size 60).
            device: torch device the model and all input tensors are put on.
            use_fp16: cast the model and its inputs to half precision.
        """
        self.cfg_path = ANYTEXT_CFG
        self.font_path = font_path
        self.use_fp16 = use_fp16
        self.device = device

        self.font = ImageFont.truetype(font_path, size=60)
        self.model = create_model(
            self.cfg_path,
            device=self.device,
            use_fp16=self.use_fp16,
        )
        if self.use_fp16:
            self.model = self.model.half()
        if Path(ckpt_path).suffix == ".safetensors":
            state_dict = load_file(ckpt_path, device="cpu")
        else:
            state_dict = load_state_dict(ckpt_path, location="cpu")
        # strict=False: checkpoint keys that don't match the model are ignored.
        self.model.load_state_dict(state_dict, strict=False)
        self.model = self.model.eval().to(self.device)
        self.ddim_sampler = DDIMSampler(self.model, device=self.device)

    def __call__(
        self,
        prompt: str,
        negative_prompt: str,
        image: np.ndarray,
        masked_image: np.ndarray,
        num_inference_steps: int,
        strength: float,
        guidance_scale: float,
        height: int,
        width: int,
        seed: int,
        sort_priority: str = "y",
        callback=None,
    ):
        """Draw the quoted text of ``prompt`` into the marked regions of ``image``.

        Args:
            prompt: generation prompt. Every ``"..."`` span (curly quotes are
                normalised first) is a text line to draw and is replaced by
                ``PLACE_HOLDER`` in the prompt given to the model.
            negative_prompt: prompt for the unconditional guidance branch.
            image: reference image, given as an ndarray, a torch tensor, or a
                file path read with OpenCV (converted BGR -> RGB). Values are
                clipped to [1, 255].
            masked_image: position image, accepting the same kinds as
                ``image``. Text goes where the first channel of
                ``255 - masked_image`` exceeds 254 (black pixels for uint8
                input). Each connected region is one text line.
            num_inference_steps: number of DDIM steps.
            strength: control scale applied to all 13 control outputs.
            guidance_scale: classifier-free guidance scale.
            height: output height. Only used by the generation branch; in
                editing mode the size of ``image`` is used.
            width: output width. Same caveat as ``height``.
            seed: random seed passed to ``set_seed``.
            sort_priority: order in which regions are matched to text lines:
                ``"y"`` top-down, ``"x"`` left-right.
            callback: forwarded to ``DDIMSampler.sample``.

        Returns:
            A 3-tuple ``(results, rst_code, rst_info)``:

            - ``results``: a list of uint8 HWC images, or ``None`` on error.
            - ``rst_code``: 0 normal, 1 warning, -1 error.
            - ``rst_info``: warning/error text (newline-joined warnings), or
              ``""``.

        Raises:
            RuntimeError: fewer marked regions than text lines in the prompt.
            ValueError: ``sort_priority`` is neither ``"x"`` nor ``"y"``.
            AssertionError: an input image is unreadable or of an unknown type.
        """
        set_seed(seed)
        warnings = []

        # Only text editing is wired up; the generation branch is kept for
        # parity with the upstream AnyText implementation.
        mode = "text-editing"
        revise_pos = False
        img_count = 1
        w, h = width, height
        eta = 0.0

        prompt, texts = self.modify_prompt(prompt)
        if prompt is None and texts is None:
            # Always return a 3-tuple so callers can unpack it uniformly.
            return (
                None,
                -1,
                "You have input Chinese prompt but the translator is not loaded!",
            )
        n_lines = len(texts)

        if mode in ["text-generation", "gen"]:
            edit_image = np.ones((h, w, 3)) * 127.5
        elif mode in ["text-editing", "edit"]:
            if masked_image is None or image is None:
                return (
                    None,
                    -1,
                    "Reference image and position image are needed for text editing!",
                )
            if isinstance(image, str):
                image = self._imread_rgb(image, "ori_image")
            elif isinstance(image, torch.Tensor):
                image = image.cpu().numpy()
            else:
                assert isinstance(
                    image, np.ndarray
                ), f"Unknown format of ori_image: {type(image)}"
            edit_image = image.clip(1, 255)
            edit_image = check_channels(edit_image)
        h, w = edit_image.shape[:2]

        # --- Split the position mask into one region per text line ---
        pos_mask = self._binarize_positions(masked_image, h, w)
        pos_imgs = self.separate_pos_imgs(pos_mask, sort_priority)
        if len(pos_imgs) == 0:
            pos_imgs = [np.zeros((h, w, 1))]
        if len(pos_imgs) < n_lines:
            # A prompt without quoted text yields the single line " ", which
            # is allowed to have no dedicated region.
            if not (n_lines == 1 and texts[0] == " "):
                raise RuntimeError(
                    f"{n_lines} text line to draw from prompt, not enough mask area({len(pos_imgs)}) on images"
                )
        elif len(pos_imgs) > n_lines:
            warnings.append(
                f"Warning: found {len(pos_imgs)} positions that > needed {n_lines} from prompt."
            )

        pre_pos = []
        poly_list = []
        for input_pos in pos_imgs:
            if input_pos.mean() != 0:
                if input_pos.ndim == 2:
                    input_pos = input_pos[..., np.newaxis]
                poly, pos_img = self.find_polygon(input_pos)
                pre_pos.append(pos_img / 255.0)
                poly_list.append(poly)
            else:
                pre_pos.append(np.zeros((h, w, 1)))
                poly_list.append(None)
        # Union of all regions, clipped to [0, 1]; used as the control hint and
        # to blank the regions out of the image before encoding.
        np_hint = np.sum(pre_pos, axis=0).clip(0, 1)

        # --- Per-line glyph/position conditioning ---
        info = {
            "glyphs": [],
            "gly_line": [],
            "positions": [],
            "n_lines": [len(texts)] * img_count,
        }
        gly_scale = 2
        for i, text in enumerate(texts):
            if len(text) > max_chars:
                warnings.append(
                    f'"{text}" length > max_chars: {max_chars}, will be cut off...'
                )
                text = text[:max_chars]
            if pre_pos[i].mean() != 0:
                gly_line = draw_glyph(self.font, text)
                glyphs = draw_glyph2(
                    self.font,
                    text,
                    poly_list[i],
                    scale=gly_scale,
                    width=w,
                    height=h,
                    add_space=False,
                )
                if revise_pos:
                    # Replace the user region with the min-area rectangle
                    # around the rendered glyphs. NOTE(review): np_hint was
                    # computed earlier and is not refreshed here (as upstream).
                    revised = self._revise_position(glyphs, pre_pos[i])
                    if revised is None:
                        warnings.append(
                            f"Fail to revise position {i} to bounding rect, remain position unchanged..."
                        )
                    else:
                        pre_pos[i] = revised
            else:
                glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
                gly_line = np.zeros((80, 512, 1))
            info["glyphs"].append(self.arr2tensor(glyphs, img_count))
            info["gly_line"].append(self.arr2tensor(gly_line, img_count))
            info["positions"].append(self.arr2tensor(pre_pos[i], img_count))

        # --- Encode the image with text regions blanked out ---
        masked_img = ((edit_image.astype(np.float32) / 127.5) - 1.0) * (1 - np_hint)
        masked_img = np.transpose(masked_img, (2, 0, 1))
        masked_img = torch.from_numpy(masked_img.copy()).float().to(self.device)
        if self.use_fp16:
            masked_img = masked_img.half()
        encoder_posterior = self.model.encode_first_stage(masked_img[None, ...])
        masked_x = self.model.get_first_stage_encoding(encoder_posterior).detach()
        if self.use_fp16:
            masked_x = masked_x.half()
        info["masked_x"] = torch.cat([masked_x] * img_count, dim=0)

        # --- Sample ---
        hint = self.arr2tensor(np_hint, img_count)
        cond = self.model.get_learned_conditioning(
            dict(
                c_concat=[hint],
                c_crossattn=[[prompt] * img_count],
                text_info=info,
            )
        )
        un_cond = self.model.get_learned_conditioning(
            dict(
                c_concat=[hint],
                c_crossattn=[[negative_prompt] * img_count],
                text_info=info,
            )
        )
        shape = (4, h // 8, w // 8)
        self.model.control_scales = [strength] * 13
        samples, _ = self.ddim_sampler.sample(
            num_inference_steps,
            img_count,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=guidance_scale,
            unconditional_conditioning=un_cond,
            callback=callback,
        )
        if self.use_fp16:
            samples = samples.half()
        x_samples = self.model.decode_first_stage(samples)
        # Decoder output is mapped from [-1, 1] to uint8 [0, 255], NHWC.
        x_samples = (
            (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
            .cpu()
            .numpy()
            .clip(0, 255)
            .astype(np.uint8)
        )
        results = [x_samples[i] for i in range(img_count)]

        str_warning = "\n".join(warnings)
        rst_code = 1 if str_warning else 0
        return results, rst_code, str_warning

    @staticmethod
    def _imread_rgb(path, what):
        """Read ``path`` with OpenCV and return it as RGB, asserting it loaded."""
        img = cv2.imread(path)
        # Check before indexing: cv2.imread returns None for unreadable files.
        assert img is not None, f"Can't read {what} image from{path}!"
        return img[..., ::-1]

    def _binarize_positions(self, masked_image, h, w):
        """Turn ``masked_image`` into a single-channel mask where 255 marks text positions.

        ``h`` and ``w`` size the empty mask used when ``masked_image`` is None.
        """
        if masked_image is None:
            pos_imgs = np.zeros((h, w, 1))
        elif isinstance(masked_image, str):
            pos_imgs = 255 - self._imread_rgb(masked_image, "draw_pos")
        elif isinstance(masked_image, torch.Tensor):
            # NOTE(review): tensors are not inverted like paths/ndarrays are —
            # confirm callers pass an already-inverted tensor.
            pos_imgs = masked_image.cpu().numpy()
        else:
            assert isinstance(
                masked_image, np.ndarray
            ), f"Unknown format of draw_pos: {type(masked_image)}"
            pos_imgs = 255 - masked_image
        pos_imgs = pos_imgs[..., 0:1]
        pos_imgs = cv2.convertScaleAbs(pos_imgs)
        _, pos_imgs = cv2.threshold(pos_imgs, 254, 255, cv2.THRESH_BINARY)
        return pos_imgs

    @staticmethod
    def _revise_position(glyphs, pos):
        """Return the min-area-rect region around ``glyphs`` scaled to ``pos``, or None.

        None is returned when the closed glyph mask does not form exactly one
        external contour.
        """
        resize_gly = cv2.resize(glyphs, (pos.shape[1], pos.shape[0]))
        new_pos = cv2.morphologyEx(
            (resize_gly * 255).astype(np.uint8),
            cv2.MORPH_CLOSE,
            kernel=np.ones(
                (resize_gly.shape[0] // 10, resize_gly.shape[1] // 10),
                dtype=np.uint8,
            ),
            iterations=1,
        )
        if new_pos.ndim == 2:
            new_pos = new_pos[..., np.newaxis]
        contours, _ = cv2.findContours(
            new_pos, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        if len(contours) != 1:
            return None
        rect = cv2.minAreaRect(contours[0])
        # np.int0 was removed in NumPy 2.0; intp is the same integer type.
        poly = cv2.boxPoints(rect).astype(np.intp)
        return cv2.drawContours(new_pos, [poly], -1, 255, -1) / 255.0

    def modify_prompt(self, prompt):
        """Pull the quoted text lines out of ``prompt``.

        Curly double quotes are normalised to ``"`` first. Each ``"..."`` span
        is collected and its first occurrence in the prompt is replaced by
        ``f" {PLACE_HOLDER} "``.

        Returns:
            ``(prompt, strs)``: the rewritten prompt and the list of quoted
            strings. ``strs`` is ``[" "]`` when the prompt has no quotes.
        """
        prompt = prompt.replace("“", '"').replace("”", '"')
        strs = re.findall(r'"(.*?)"', prompt)
        if not strs:
            strs = [" "]
        else:
            for s in strs:
                prompt = prompt.replace(f'"{s}"', f" {PLACE_HOLDER} ", 1)
        return prompt, strs

    def separate_pos_imgs(self, img, sort_priority, gap=102):
        """Split a binary mask into one full-size mask per connected component.

        Args:
            img: single-channel 8-bit mask, as accepted by
                ``cv2.connectedComponentsWithStats``.
            sort_priority: ``"y"`` sorts top-down then left-right; ``"x"``
                sorts left-right then top-down.
            gap: centroid coordinates are bucketed by ``// gap`` before
                sorting, so nearly aligned regions compare as equal on the
                primary axis.

        Returns:
            A list of arrays shaped like ``img`` with the component set to 255.

        Raises:
            ValueError: ``sort_priority`` is not ``"x"`` or ``"y"``.
        """
        # Centroids are (x, y); pick which axis is the primary sort key.
        axis_order = {"y": (1, 0), "x": (0, 1)}
        if sort_priority not in axis_order:
            raise ValueError(
                f"sort_priority must be 'x' or 'y', got {sort_priority!r}"
            )
        fir, sec = axis_order[sort_priority]

        num_labels, labels, _, centroids = cv2.connectedComponentsWithStats(img)
        components = []
        # Label 0 is the background.
        for label in range(1, num_labels):
            component = np.zeros_like(img)
            component[labels == label] = 255
            components.append((component, centroids[label]))
        components.sort(key=lambda c: (c[1][fir] // gap, c[1][sec] // gap))
        return [c[0] for c in components]

    def find_polygon(self, image, min_rect=False):
        """Approximate the largest external contour of ``image`` with a polygon.

        The polygon is filled into ``image`` with value 255; this mutates the
        argument in place.

        Args:
            image: 8-bit mask accepted by ``cv2.findContours``; must contain at
                least one contour.
            min_rect: use the minimum-area rotated rectangle instead of
                ``cv2.approxPolyDP`` (epsilon = 1% of the contour perimeter).

        Returns:
            ``(poly, image)``, where ``poly`` is an ``(n, 2)`` integer point
            array and ``image`` is the input with the polygon filled in.
        """
        contours, _ = cv2.findContours(
            image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
        )
        max_contour = max(contours, key=cv2.contourArea)
        if min_rect:
            rect = cv2.minAreaRect(max_contour)
            # np.int0 was removed in NumPy 2.0; intp is the same integer type.
            poly = cv2.boxPoints(rect).astype(np.intp)
        else:
            epsilon = 0.01 * cv2.arcLength(max_contour, True)
            poly = cv2.approxPolyDP(max_contour, epsilon, True)
            n, _, xy = poly.shape
            poly = poly.reshape(n, xy)
        cv2.drawContours(image, [poly], -1, 255, -1)
        return poly, image

    def arr2tensor(self, arr, bs):
        """Convert an HWC array to a ``(bs, C, H, W)`` float tensor on ``self.device``.

        The tensor is cast to half precision when ``self.use_fp16`` is set.
        """
        arr = np.transpose(arr, (2, 0, 1))
        _arr = torch.from_numpy(arr.copy()).float().to(self.device)
        if self.use_fp16:
            _arr = _arr.half()
        return torch.stack([_arr] * bs, dim=0)
|
|