import os
import json
import random

import numpy as np
import torch

COLORS = {
    'brown': [165, 42, 42],
    'red': [255, 0, 0],
    'pink': [253, 108, 158],
    'orange': [255, 165, 0],
    'yellow': [255, 255, 0],
    'purple': [128, 0, 128],
    'green': [0, 128, 0],
    'blue': [0, 0, 255],
    'white': [255, 255, 255],
    'gray': [128, 128, 128],
    'black': [0, 0, 0],
}


def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def hex_to_rgb(hex_string, return_nearest_color=False, device='cuda'):
    r"""
    Convert a hex triplet to an RGB triplet.
    """
    # Remove '#' symbol if present.
    hex_string = hex_string.lstrip('#')
    # Convert hex values to integers.
    red = int(hex_string[0:2], 16)
    green = int(hex_string[2:4], 16)
    blue = int(hex_string[4:6], 16)
    # Normalize to [0, 1] and shape as (1, 3, 1, 1) for broadcasting.
    rgb = torch.FloatTensor((red, green, blue))[None, :, None, None] / 255.
    if return_nearest_color:
        nearest_color = find_nearest_color(rgb)
        return rgb.to(device), nearest_color
    return rgb.to(device)


def find_nearest_color(rgb):
    r"""
    Find the nearest neighbor color given the RGB value.
    """
    if isinstance(rgb, (list, tuple)):
        rgb = torch.FloatTensor(rgb)[None, :, None, None] / 255.
    color_distance = torch.FloatTensor([np.linalg.norm(
        rgb - torch.FloatTensor(COLORS[color])[None, :, None, None] / 255.)
        for color in COLORS.keys()])
    nearest_color = list(COLORS.keys())[torch.argmin(color_distance).item()]
    return nearest_color
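

def _demo_color_utils():
    r"""
    Illustrative example, not part of the original utilities: round-trip a
    hex color through hex_to_rgb() / find_nearest_color() on CPU. '#ff0000'
    normalizes to a (1, 3, 1, 1) tensor and maps exactly onto the 'red'
    entry of COLORS; the raw 'pink' triplet maps back to 'pink'.
    """
    rgb, name = hex_to_rgb('#ff0000', return_nearest_color=True, device='cpu')
    assert name == 'red' and tuple(rgb.shape) == (1, 3, 1, 1)
    assert find_nearest_color([253, 108, 158]) == 'pink'
    print('color utils OK:', name, rgb.flatten().tolist())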


def font2style(font, device='cuda'):
    r"""
    Convert the font name to the style name.
    """
    return {
        'mirza': 'Claud Monet, impressionism, oil on canvas',
        'roboto': 'Ukiyoe',
        'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
        'sofia': 'Pop Art, masterpiece, andy warhol',
        'slabo': 'Vincent Van Gogh',
        'inconsolata': 'Pixel Art, 8 bits, 16 bits',
        'ubuntu': 'Rembrandt',
        'Monoton': 'neon art, colorful light, highly details, octane render',
        'Akronim': 'Abstract Cubism, Pablo Picasso',
    }[font]


def parse_json(json_str, device):
    r"""
    Convert the rich-text JSON (a dict with an 'ops' list of spans) to
    region-based attributes.
    """
    # Initialize region-based attributes.
    base_text_prompt = ''
    style_text_prompts = []
    footnote_text_prompts = []
    footnote_target_tokens = []
    color_text_prompts = []
    color_rgbs = []
    color_names = []
    size_text_prompts_and_sizes = []
    # Parse the attributes from JSON.
    prev_style = None
    prev_color_rgb = None
    use_grad_guidance = False
    for span in json_str['ops']:
        text_prompt = span['insert'].rstrip('\n')
        base_text_prompt += span['insert'].rstrip('\n')
        if text_prompt == ' ':
            continue
        if 'attributes' in span:
            # Font attribute -> style region; consecutive spans with the same
            # font are merged into a single style prompt.
            if 'font' in span['attributes']:
                style = font2style(span['attributes']['font'])
                if prev_style == style:
                    prev_text_prompt = style_text_prompts[-1].split(
                        'in the style of')[0]
                    style_text_prompts[-1] = prev_text_prompt + \
                        ' ' + text_prompt + f' in the style of {style}'
                else:
                    style_text_prompts.append(
                        text_prompt + f' in the style of {style}')
                prev_style = style
            else:
                prev_style = None
            # Link attribute -> footnote prompt describing the linked span.
            if 'link' in span['attributes']:
                footnote_text_prompts.append(span['attributes']['link'])
                footnote_target_tokens.append(text_prompt)
            # Size attribute -> token re-weighting, e.g. '30px' -> 10.0;
            # a struck-through span gets a negative weight.
            font_size = 1
            if 'size' in span['attributes']:
                font_size = float(span['attributes']['size'][:-2]) / 3.
                if 'strike' in span['attributes']:
                    font_size = -font_size
            # Color attribute -> gradient guidance towards the picked RGB;
            # consecutive spans sharing the same color are merged into one
            # color prompt.
            if 'color' in span['attributes']:
                use_grad_guidance = True
                color_rgb, nearest_color = hex_to_rgb(
                    span['attributes']['color'], True, device=device)
                if prev_color_rgb is not None and torch.equal(color_rgb, prev_color_rgb):
                    color_text_prompts[-1] = color_text_prompts[-1] + \
                        ' ' + text_prompt
                else:
                    color_rgbs.append(color_rgb)
                    color_names.append(nearest_color)
                    color_text_prompts.append(text_prompt)
                prev_color_rgb = color_rgb
            else:
                prev_color_rgb = None
            if font_size != 1:
                size_text_prompts_and_sizes.append([text_prompt, font_size])
    return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
        color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance
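

def _demo_parse_json(device='cpu'):
    r"""
    Illustrative example, not part of the original utilities: a minimal
    Quill-style delta made up for this demo, using only the attribute keys
    that parse_json() inspects ('font', 'color', 'size', 'strike', 'link').
    """
    rich_text = {'ops': [
        {'insert': 'cat wearing '},
        {'insert': 'red sunglasses', 'attributes': {'color': '#ff0000'}},
        {'insert': ' near the beach', 'attributes': {'font': 'slabo'}},
    ]}
    parsed = parse_json(rich_text, device=device)
    (base_text_prompt, style_text_prompts, _footnotes, _footnote_targets,
     color_text_prompts, color_names, _color_rgbs, _sizes,
     use_grad_guidance) = parsed
    print('base prompt:', base_text_prompt)
    print('style prompts:', style_text_prompts)
    print('color prompts:', color_text_prompts, color_names, use_grad_guidance)
    return parsed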


def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
                               footnote_target_tokens, color_text_prompts, color_names):
    r"""
    Algorithm 1 in the paper: collect per-region prompts and the base-prompt
    token ids they cover.
    """
    region_text_prompts = []
    region_target_token_ids = []
    base_tokens = model.tokenizer._tokenize(base_text_prompt)
    # Process the style text prompts.
    for text_prompt in style_text_prompts:
        region_text_prompts.append(text_prompt)
        region_target_token_ids.append([])
        style_tokens = model.tokenizer._tokenize(
            text_prompt.split('in the style of')[0])
        for style_token in style_tokens:
            region_target_token_ids[-1].append(
                base_tokens.index(style_token) + 1)
    # Process the complementary (footnote) text prompts.
    for footnote_text_prompt, text_prompt in zip(footnote_text_prompts, footnote_target_tokens):
        region_target_token_ids.append([])
        region_text_prompts.append(footnote_text_prompt)
        style_tokens = model.tokenizer._tokenize(text_prompt)
        for style_token in style_tokens:
            region_target_token_ids[-1].append(
                base_tokens.index(style_token) + 1)
    # Process the color text prompts.
    for color_text_prompt, color_name in zip(color_text_prompts, color_names):
        region_target_token_ids.append([])
        region_text_prompts.append(color_name + ' ' + color_text_prompt)
        style_tokens = model.tokenizer._tokenize(color_text_prompt)
        for style_token in style_tokens:
            region_target_token_ids[-1].append(
                base_tokens.index(style_token) + 1)
    # The remaining tokens without any attributes form the last region.
    region_text_prompts.append(base_text_prompt)
    region_target_token_ids_all = [
        id for ids in region_target_token_ids for id in ids]
    target_token_ids_rest = [id for id in range(
        1, len(base_tokens) + 1) if id not in region_target_token_ids_all]
    region_target_token_ids.append(target_token_ids_rest)
    region_target_token_ids = [torch.LongTensor(
        obj_token_id) for obj_token_id in region_target_token_ids]
    return region_text_prompts, region_target_token_ids, base_tokens
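

def _demo_region_input(parsed=None):
    r"""
    Illustrative example, not part of the original utilities: feed the parsed
    demo prompts through get_region_diffusion_input(), with a whitespace
    tokenizer standing in for the diffusion model's CLIP tokenizer; only the
    private _tokenize() method and .device are used by the helpers here.
    """
    class _MockTokenizer:
        def _tokenize(self, text):
            return text.lower().split()

    class _MockModel:
        tokenizer = _MockTokenizer()
        device = 'cpu'

    model = _MockModel()
    if parsed is None:
        parsed = _demo_parse_json(device=model.device)
    (base_text_prompt, style_text_prompts, footnote_text_prompts,
     footnote_target_tokens, color_text_prompts, color_names, _color_rgbs,
     _sizes, _use_grad_guidance) = parsed
    region_text_prompts, region_target_token_ids, base_tokens = \
        get_region_diffusion_input(model, base_text_prompt, style_text_prompts,
                                   footnote_text_prompts, footnote_target_tokens,
                                   color_text_prompts, color_names)
    print('regions:', region_text_prompts)
    print('token ids per region:',
          [ids.tolist() for ids in region_target_token_ids])
    return model, base_tokens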


def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
    r"""
    Control the token impact using font sizes.
    """
    word_pos = []
    font_sizes = []
    for text_prompt, font_size in size_text_prompts_and_sizes:
        size_tokens = model.tokenizer._tokenize(text_prompt)
        for size_token in size_tokens:
            word_pos.append(base_tokens.index(size_token) + 1)
            font_sizes.append(font_size)
    if len(word_pos) > 0:
        word_pos = torch.LongTensor(word_pos).to(model.device)
        font_sizes = torch.FloatTensor(font_sizes).to(model.device)
    else:
        word_pos = None
        font_sizes = None
    text_format_dict = {
        'word_pos': word_pos,
        'font_size': font_sizes,
    }
    return text_format_dict


def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
                                guidance_start_step=999, color_guidance_weight=1):
    r"""
    Control the token colors using gradient guidance towards the target RGBs.
    """
    color_target_token_ids = []
    for text_prompt in color_text_prompts:
        color_target_token_ids.append([])
        color_tokens = model.tokenizer._tokenize(text_prompt)
        for color_token in color_tokens:
            color_target_token_ids[-1].append(base_tokens.index(color_token) + 1)
    color_target_token_ids_all = [
        id for ids in color_target_token_ids for id in ids]
    color_target_token_ids_rest = [id for id in range(
        1, len(base_tokens) + 1) if id not in color_target_token_ids_all]
    color_target_token_ids.append(color_target_token_ids_rest)
    color_target_token_ids = [torch.LongTensor(
        obj_token_id) for obj_token_id in color_target_token_ids]
    text_format_dict['target_RGB'] = color_rgbs
    text_format_dict['guidance_start_step'] = guidance_start_step
    text_format_dict['color_guidance_weight'] = color_guidance_weight
    return text_format_dict, color_target_token_ids
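

if __name__ == '__main__':
    # Smoke-test the CPU-friendly helpers above. Everything below is an
    # illustrative sketch on mocked inputs, not the real region-based
    # diffusion pipeline.
    seed_everything(0)
    _demo_color_utils()
    parsed = _demo_parse_json()
    (_, _, _, _, color_text_prompts, _, color_rgbs,
     size_text_prompts_and_sizes, _) = parsed
    model, base_tokens = _demo_region_input(parsed)
    # Font-size re-weighting and color gradient-guidance inputs for the same
    # example; with no resized spans, 'word_pos' and 'font_size' stay None.
    text_format_dict = get_attention_control_input(
        model, base_tokens, size_text_prompts_and_sizes)
    text_format_dict, color_target_token_ids = get_gradient_guidance_input(
        model, base_tokens, color_text_prompts, color_rgbs, text_format_dict)
    print('text_format_dict keys:', sorted(text_format_dict.keys()))
    print('color token ids per region:',
          [ids.tolist() for ids in color_target_token_ids])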