|
import torch |
|
from diffusers import AutoPipelineForImage2Image |
|
from PIL import Image, ImageDraw, ImageFont |
|
import requests |
|
from io import BytesIO |
|
import gradio as gr |
|
import gc |
|
import textwrap |
|
|
|
|
|
print(f"Is CUDA available: {torch.cuda.is_available()}") |
|
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") |
|
|
|
|
|
def image_to_template(generated_image, logo, button_text, punchline, theme_color): |
|
template_width = 1080 |
|
button_font_size = 20 |
|
punchline_font_size = 60 |
|
decoration_height = 20 |
|
margin = 40 |
|
|
|
punchline = textwrap.wrap(punchline, width=35) |
|
n_of_lines_punchline = len(punchline) |
|
|
|
generated_image = generated_image.convert("RGBA") |
|
logo = logo.convert("RGBA") |
|
|
|
|
|
image_width = template_width // 2 |
|
image_height = image_width * generated_image.height // generated_image.width |
|
image_shape = (image_width, image_height) |
|
|
|
|
|
logo_width = image_width // 3 |
|
logo_height = logo_width * logo.height // logo.width |
|
logo_shape = (logo_width, logo_height) |
|
|
|
|
|
button_font = ImageFont.truetype("./assets/Montserrat-Bold.ttf", button_font_size) |
|
punchline_font = ImageFont.truetype("./assets/Montserrat-Bold.ttf", punchline_font_size) |
|
|
|
|
|
button_width = template_width // 3 |
|
button_height = button_font_size * 3 |
|
|
|
|
|
template_height = ( |
|
image_height |
|
+ logo_height |
|
+ button_height |
|
+ n_of_lines_punchline * punchline_font_size |
|
+ (5 * margin) |
|
+ (2 * decoration_height) |
|
) |
|
|
|
|
|
logo_pos = ((template_width - logo_width) // 2, margin + decoration_height) |
|
image_pos = ( |
|
(template_width - image_width) // 2, |
|
logo_pos[1] + logo_height + margin, |
|
) |
|
|
|
|
|
top_decoration_pos = [ |
|
margin, |
|
-decoration_height // 2, |
|
template_width - margin, |
|
decoration_height // 2, |
|
] |
|
bottom_decoration_pos = [ |
|
margin, |
|
template_height - decoration_height // 2, |
|
template_width - margin, |
|
template_height + decoration_height // 2, |
|
] |
|
|
|
|
|
generated_image.thumbnail(image_shape, Image.ANTIALIAS) |
|
logo.thumbnail(logo_shape, Image.ANTIALIAS) |
|
background = Image.new("RGBA", (template_width, template_height), "WHITE") |
|
|
|
mask = Image.new("L", generated_image.size, 0) |
|
draw = ImageDraw.Draw(mask) |
|
draw.rounded_rectangle((0, 0) + generated_image.size, 20, fill=255) |
|
generated_image.putalpha(mask) |
|
|
|
background.paste(logo, logo_pos, logo) |
|
background.paste(generated_image, image_pos, generated_image) |
|
|
|
draw = ImageDraw.Draw(background) |
|
|
|
draw.rounded_rectangle(bottom_decoration_pos, radius=20, fill=theme_color) |
|
draw.rounded_rectangle(top_decoration_pos, radius=20, fill=theme_color) |
|
|
|
text_heights = [] |
|
for line in punchline: |
|
text_width, text_height = draw.textsize(line, font=punchline_font) |
|
punchline_pos = ( |
|
(template_width - text_width) // 2, |
|
image_pos[1] + generated_image.height + margin + sum(text_heights), |
|
) |
|
draw.text(punchline_pos, line, fill=theme_color, font=punchline_font) |
|
text_heights.append(text_height) |
|
|
|
|
|
button_text_width, button_text_height = draw.textsize(button_text, font=button_font) |
|
button_shape = [ |
|
((template_width - button_width) // 2, punchline_pos[1] + text_height + margin), |
|
( |
|
(template_width + button_width) // 2, |
|
punchline_pos[1] + text_height + margin + button_height, |
|
), |
|
] |
|
draw.rounded_rectangle(button_shape, radius=20, fill=theme_color) |
|
|
|
button_text_pos = ( |
|
(template_width - button_text_width) // 2, |
|
button_shape[0][1] + (button_height - button_text_height) // 2, |
|
) |
|
draw.text(button_text_pos, button_text, fill="white", font=button_font) |
|
|
|
return background |
|
|
|
|
|
def generate_template( |
|
initial_image, logo, prompt, button_text, punchline, image_color, theme_color |
|
): |
|
pipeline = AutoPipelineForImage2Image.from_pretrained( |
|
"./models/kandinsky-2-2-decoder", |
|
torch_dtype=torch.float16, |
|
use_safetensors=True, |
|
) |
|
|
|
|
|
pipeline.enable_model_cpu_offload() |
|
|
|
prompt = f"{prompt}, include the color {image_color}" |
|
negative_prompt = "low quality, bad quality, blurry, unprofessional" |
|
|
|
generated_image = pipeline( |
|
prompt=prompt, |
|
negative_prompt=negative_prompt, |
|
image=initial_image, |
|
height=256, |
|
width=256, |
|
).images[0] |
|
|
|
template_image = image_to_template( |
|
generated_image, logo, button_text, punchline, theme_color |
|
) |
|
|
|
|
|
del pipeline |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
return template_image |
|
|
|
|
|
|
|
iface = gr.Interface( |
|
fn=generate_template, |
|
inputs=[ |
|
gr.Image(type="pil", label="Initial Image"), |
|
gr.Image(type="pil", label="Logo"), |
|
gr.Textbox(label="Prompt"), |
|
gr.Textbox(label="Button Text"), |
|
gr.Textbox(label="Punchline"), |
|
gr.ColorPicker(label="Image Color"), |
|
gr.ColorPicker(label="Theme Color"), |
|
], |
|
outputs=[gr.Image(type="pil")], |
|
title="Ad Template Generation Using Diffusion Models Demo", |
|
description="Generate ad template based on your inputs using a trained model.", |
|
concurrency_limit=2, |
|
examples=[ |
|
[ |
|
"./assets/city_image.jpg", |
|
"./assets/logo.png", |
|
"Big bank building finance", |
|
"Discover More!", |
|
"We Maximize Risk-Adusted Returns for Our Customers", |
|
"#00FF00", |
|
"#0000FF", |
|
] |
|
], |
|
) |
|
|
|
|
|
iface.launch(debug=True) |
|
|