import gradio as gr
import requests
import os
from PIL import Image
from io import BytesIO
from tqdm import tqdm
import time
import cv2
import numpy as np
import webcolors
import json
import re
from gradio_client import Client
import ast
import spaces
from profanityfilter import ProfanityFilter
import torch
from diffusers import DiffusionPipeline, AutoencoderKL, DPMSolverMultistepScheduler
# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
# Optimized kernels
# torch._inductor.config.conv_1x1_as_mm = True
# torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.epilogue_fusion = False
# torch._inductor.config.coordinate_descent_check_all_directions = True
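# Build the SDXL text-to-image pipeline: fp16-safe VAE, fp16 base weights, a
# DPMSolverMultistep scheduler, and memory savers (model CPU offload, channels-last
# UNet, VAE tiling, attention slicing) so generation fits on a single GPU.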
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
text2img_pipe = DiffusionPipeline.from_pretrained(
    model_id,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to(device)
text2img_pipe.scheduler = DPMSolverMultistepScheduler.from_config(text2img_pipe.scheduler.config)
text2img_pipe.enable_model_cpu_offload()
text2img_pipe.unet.to(memory_format=torch.channels_last)  # in-place operation
# pipeline_img2img.enable_sequential_cpu_offload()
text2img_pipe.enable_vae_tiling()
text2img_pipe.enable_attention_slicing()
# Text2img inference optimizations
# text2img_pipe.fuse_qkv_projections()
# text2img_pipe.unet.to(memory_format=torch.channels_last)
# text2img_pipe.vae.to(memory_format=torch.channels_last)
# text2img_pipe.unet = torch.compile(text2img_pipe.unet, mode="max-autotune", fullgraph=True)
# text2img_pipe.vae.decode = torch.compile(text2img_pipe.vae.decode, mode="max-autotune", fullgraph=True)
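# Attach the Tabi LoRA adapter to the base pipeline; its style is triggered by the
# "mm-tabi" token that construct_prompt() places in every generated prompt.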
text2img_pipe.load_lora_weights('sessex/tabi-0-LoRA')
# Load merged models
# from peft import PeftModel
# from diffusers import UNet2DConditionModel
# base_unet = UNet2DConditionModel.from_pretrained(
# model_id, subfolder="unet", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
# ).to(device)
# tabi_id = "sessex/tabi-0-peft-model"
# margiela_id = "sessex/margiela_style_peft_model-new"
# model = PeftModel.from_pretrained(base_unet, tabi_id, use_safetensors=True, subfolder="tabi", adapter_name="tabi")
# model.load_adapter(margiela_id, use_safetensors=True, subfolder="margiela", adapter_name="margiela")
# model.add_weighted_adapter(
# adapters=["tabi", "margiela"],
# weights=[1.0, 0.4],
# combination_type="ties_svd",
# adapter_name="tabi-margiela",
# density=0.5
# )
# model.set_adapters("tabi-margiela")
# model = model.to(dtype=torch.float16, device="cuda")
# merged_pipe = DiffusionPipeline.from_pretrained(
# model_id, unet=model, variant="fp16", torch_dtype=torch.float16,
# ).to(device)
# merged_pipe.enable_model_cpu_offload()
# merged_pipe.unet.to(memory_format=torch.channels_last) # in-place operation
# merged_pipe.enable_vae_tiling()
# merged_pipe.enable_attention_slicing()
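# Helper for the "Resized Image to Return" output: centers the render near the top of a
# 1200x1500 white canvas (presumably to leave room for watermarking or branding downstream).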
def resize_and_place(image, new_width, new_height, output_path):
    """Resize the generated image and paste it near the top of a 1200x1500 white canvas."""
    # Resize the image
    resized_image = image.resize((new_width, new_height))
    # Create a new white background image
    background = Image.new('RGB', (1200, 1500), color='white')
    # Calculate the position to paste the resized image onto the white background
    x_offset = (1200 - resized_image.width) // 2
    y_offset = 48
    # Paste the resized image onto the white background at the specified position
    background.paste(resized_image, (x_offset, y_offset))
    # Save the new image
    background.save(output_path)
    return background
# Initialize the profanity filter
pf = ProfanityFilter()
def filter_inappropriate(input_text):
    # Filter out inappropriate words
    pf.censor_char = ' '
    filtered_text = pf.censor(input_text)
    return filtered_text.strip()
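# @spaces.GPU allocates a GPU for the duration of the call when this Space runs on ZeroGPU hardware.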
@spaces.GPU(enable_queue=True)
def text2img_inference(prompt, neg_prompt, repo):
    gr.Info('Image generation request sent')
    with torch.no_grad():
        image = text2img_pipe(
            prompt=prompt,
            negative_prompt=neg_prompt,
            width=1024,
            height=1024,
            num_inference_steps=20,
            guidance_scale=7.5
        ).images[0]
    return image
# if repo == '3.1':
# gr.Info('Merged model diffusion in progress')
# image = merged_pipe(
# prompt,
# negative_prompt=neg_prompt,
# num_inference_steps=30,
# ).images[0]
# else:
# gr.Info('Standalone model diffusion in progress')
# image = text2img_pipe(
# prompt=prompt,
# negative_prompt=neg_prompt,
# # width=1024,
# # height=1024,
# num_inference_steps=20,
# guidance_scale=7.5
# ).images[0]
# find the closest color name to rgb value
def closest_color(rgb_color):
    min_colors = {}
    for key, name in webcolors.CSS3_HEX_TO_NAMES.items():
        r_c, g_c, b_c = webcolors.hex_to_rgb(key)
        rd = (r_c - rgb_color[0]) ** 2
        gd = (g_c - rgb_color[1]) ** 2
        bd = (b_c - rgb_color[2]) ** 2
        min_colors[(rd + gd + bd)] = name
    return min_colors[min(min_colors.keys())]
def get_dominant_colors(img_filepath):
    # Load the image from file path and force 3 channels so reshape(-1, 3) also works for RGBA or grayscale inputs
    img_data = Image.open(img_filepath).convert('RGB')
    # Convert the image to a NumPy array
    img = np.array(img_data)
    # k-means clustering to create palette of most dominant n_colors
    pixels = np.float32(img.reshape(-1, 3))
    n_colors = 2
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 200, .1)
    flags = cv2.KMEANS_RANDOM_CENTERS
    _, labels, palette = cv2.kmeans(pixels, n_colors, None, criteria, 10, flags)
    # get names of dominant colors
    dominant_colors = []
    for color in palette:
        color_name = closest_color(color)
        dominant_colors.append(color_name)
    return dominant_colors
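# Note: the dominant-color helpers above are not used in the current generation path;
# the call in get_image_keywords() is commented out and an empty string is passed as colors.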
# # Precompute color dictionary
# css3_hex_to_rgb = {key: webcolors.hex_to_rgb(key) for key in webcolors.CSS3_HEX_TO_NAMES}
# # find the closest color name to rgb value
# def closest_color(rgb_color):
# min_distance = float('inf')
# closest_name = None
# for key, value in css3_hex_to_rgb.items():
# r_c, g_c, b_c = value
# rd = (r_c - rgb_color[0]) ** 2
# gd = (g_c - rgb_color[1]) ** 2
# bd = (b_c - rgb_color[2]) ** 2
# distance = rd + gd + bd
# if distance < min_distance:
# min_distance = distance
# closest_name = key
# return webcolors.CSS3_HEX_TO_NAMES[closest_name]
# def get_dominant_colors(img_filepath):
# # Load the image from file path
# img_data = Image.open(img_filepath)
# # Convert the image to a NumPy array
# img = np.array(img_data)
# # k-means clustering to create palette of most dominant n_colors
# pixels = np.float32(img.reshape(-1, 3))
# n_colors = 5
# criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 200, .1)
# flags = cv2.KMEANS_RANDOM_CENTERS
# _, labels, palette = cv2.kmeans(pixels, n_colors, None, criteria, 10, flags)
# # get names of dominant colors
# dominant_colors = [closest_color(color) for color in palette]
# return dominant_colors
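# Image captioning is delegated to remote Spaces via gradio_client: Kosmos-2 for
# detailed grounded captions, or Moondream for a short color/pattern description.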
def get_caption(image):
    kosmos2_client = Client("https://ydshieh-kosmos-2.hf.space/")
    kosmos2_result = kosmos2_client.predict(
        image,       # str (filepath or URL to image) in 'Test Image' Image component
        "Detailed",  # str in 'Description Type' Radio component
        fn_index=4
    )
    print(f"KOSMOS2 RETURNS: {kosmos2_result}")
    print(f"OBJECTS: {kosmos2_result[2]}")
    # Convert the string to a list of tuples
    identified_objects_list = ast.literal_eval(kosmos2_result[2])
    # Extract words from the list of tuples
    words_list = []
    for item in identified_objects_list:
        words_list.append(item[0])
    identified_objects = ", ".join(words_list)
    print(identified_objects)
    with open(kosmos2_result[1], 'r') as f:
        data = json.load(f)
    reconstructed_sentence = []
    for sublist in data:
        reconstructed_sentence.append(sublist[0])
    full_sentence = ' '.join(reconstructed_sentence)
    # print(full_sentence)
    # Match the expected format: "Describe this image in detail:" followed by optional space and the description
    pattern = r'^Describe this image in detail:\s*(.*)$'
    # Apply the regex pattern to extract the description text
    match = re.search(pattern, full_sentence)
    if match:
        description = match.group(1)
        print(description)
    else:
        # Fall back to the full sentence so the function never returns an undefined name
        description = full_sentence
        print("Unable to locate valid description.")
    # Find the last occurrence of "."
    # last_period_index = full_sentence.rfind('.')
    # Truncate the string up to the last period
    # truncated_caption = full_sentence[:last_period_index + 1]
    # print(truncated_caption)
    # print(f"\n—\nIMAGE CAPTION: {truncated_caption}")
    return description
    # return identified_objects
def get_caption_from_MD(image_in):
    client = Client("https://vikhyatk-moondream1.hf.space/")
    result = client.predict(
        image_in,  # filepath in 'image' Image component
        "What colors and patterns appear in this photo?",
        # "Describe the colors, patterns, aesthetic, artistic style, and objects in this photo",  # str in 'Question' Textbox component
        api_name="/answer_question"
    )
    print(result)
    return result
def get_image_keywords(image, captioner):
    # get img2text description
    caption = get_caption(image) if captioner == 'kosmos' else get_caption_from_MD(image)
    # get colors
    # colors_list = get_dominant_colors(image)
    # colors = ", ".join(colors_list)
    return caption, ''
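# Zephyr-7B acts as the prompt-writing agent: given the image caption and the user's
# keywords, it composes a single SDXL prompt built around the "mm-tabi" trigger word.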
from transformers import pipeline

pipe = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta", torch_dtype=torch.bfloat16, device_map="auto")
@spaces.GPU(enable_queue=True)
def construct_prompt(image_caption, image_colors, user_input):
    agent_maker_sys = f"""
You are an AI whose job is to help users create their own custom shoe image which will reflect the colors, characteristics, or aesthetics from an image described by users.
In particular, you need to respond succinctly and write a prompt for an image generation model. The response must include the word "mm-tabi", which will trigger the style of shoe.
The response should avoid any descriptions of a man or woman and should not include any articles of clothing or accessories from the Caption.
The response should always start with "surreal photo of mm-tabi boot with split toe".
The response should always end with "still life in the style of retrofuturism, unconventional, dreamy, fantasy, digital video distortion, lens aberration, highly detailed, hd, 8k".
The response should only use one or two aspects from the Caption provided by the user that could easily be applied to a still life scene or a characteristic of the shoe, such as a color, texture, or object.
For example, if a user says,
"Keywords: California dogs sunshine shopping
/n Caption: The photo features a woman wearing a blue sweater with a red and white design. The sweater is a prominent feature in the image, and it is the main focus of the scene. The background is plain, with no other colors or patterns visible. The woman is standing in front of a building, which serves as a backdrop for the photo."
, provide immediately an image prompt that describes a still life photo of a shoe corresponding to the keywords, color, and objects or stylistic elements from the caption provided.
Immediately STOP after that. It should be in this format:
"surreal photo of mm-tabi boot with split toe, surrounded by California dogs sunshine shopping, still life in the style of retrofuturism, unconventional, dreamy, fantasy, digital video distortion, lens aberration, highly detailed, hd, 8k, blue"
If a user says,
"Keywords: Bunny
/n Caption: The photo features a young woman wearing a black sweatshirt with a red and white pattern. She is standing in a large, empty room, which appears to be a mall or a similar public space. The room has a white ceiling and is decorated with various colors and patterns, creating a visually interesting environment. There is also a handbag visible in the scene, placed close to the woman."
, provide immediately an image prompt that describes a still life photo of a shoe corresponding to the keywords, color, and objects or stylistic elements from the caption provided.
Immediately STOP after that. It should be in this format:
"surreal photo of mm-tabi boot with split toe, surrounded by bunny, still life in the style of retrofuturism, unconventional, dreamy, fantasy, digital video distortion, lens aberration, highly detailed, hd, 8k, red, white"
If a user says,
"Keywords: Sun
/n Caption: The photo features a woman wearing a striped shirt, which has a combination of black, white, and gray colors. She is also wearing glasses, and her smile adds a positive touch to the image. Additionally, she is holding a cell phone in her hand, which is being photographed. The background of the photo is plain, with no visible patterns or colors, allowing the focus to be on the woman and her attire.
, provide immediately an image prompt that describes a still life photo of a shoe corresponding to the keywords, color, and objects or stylistic elements from the caption provided.
Immediately STOP after that. It should be in this format:
"surreal photo of mm-tabi boot with split toe, surrounded by sun, still life in the style of retrofuturism, unconventional, dreamy, fantasy, digital video distortion, lens aberration, highly detailed, hd, 8k"
Here's another example, if a user says,
"Keywords: Chitose Abe
/n Caption: The photo features a young man wearing a black jacket and a white baseball cap. He is smiling and posing for the camera, with a tan and black jacket and a white cap. The man is carrying a bag, which is visible in the image. The background of the photo is white, and there is a person standing behind the young man, possibly a friend or a passerby.
, provide immediately an image prompt that describes a still life photo of a shoe corresponding to the keywords, color, and objects or stylistic elements from the caption provided.
Immediately STOP after that. It should be in this format:
"surreal photo of mm-tabi boot with split toe, surrounded by Chitose-Abe inspired, still life in the style of retrofuturism, unconventional, dreamy, fantasy, digital video distortion, lens aberration, highly detailed, hd, 8k, black, white"
Here's another example, if a user says,
"Keywords: Painted
/n Caption: The photo features a young woman with black hair, wearing a black shirt. She is holding a cupcake with green frosting and red candy on top. The cupcake itself has a green frosting and red candy, which adds a pop of color to the scene. The overall image showcases a combination of black, green, and red colors, along with the woman’s smiling expression, creating a visually appealing and vibrant scene.
, provide immediately an image prompt that describes a still life photo of a shoe corresponding to the keywords, color, and objects or stylistic elements from the caption provided.
Immediately STOP after that. It should be in this format:
"surreal photo of mm-tabi boot with split toe, painted, surrounded by cupcake, still life in the style of retrofuturism, unconventional, dreamy, fantasy, digital video distortion, lens aberration, highly detailed, hd, 8k, green, red"
Here's another example, if a user says,
"Keywords: Sneaker
/n Caption: The photo features a woman wearing a black top with a white bottom. The top has a black collar, and the woman is posing in front of a large building. The building's interior is decorated with white tiles, which create a contrasting background for the woman's outfit.
, provide immediately an image prompt that describes a still life photo of a shoe corresponding to the keywords, color, and objects or stylistic elements from the caption provided.
Immediately STOP after that. It should be in this format:
"surreal photo of mm-tabi sneaker with split toe, still life in the style of retrofuturism, unconventional, dreamy, fantasy, digital video distortion, lens aberration, highly detailed, hd, 8k, black"
Here's another example, if a user says,
"Keywords: Ruby black
/n Caption: The photo features a woman wearing a dark blue sweater with a black and white checkered pattern. She is posing in front of a large, white building, which could be a hotel or a mall. The background is plain, with no visible patterns or colors. The woman is holding a handbag, which is also black and white in design."
, provide immediately an image prompt that describes a still life photo of a shoe corresponding to the keywords, color, and objects or stylistic elements from the caption provided.
Immediately STOP after that. It should be in this format:
"surreal photo of mm-tabi boot with split toe, dark blue, surrounded by ruby black, still life in the style of retrofuturism, unconventional, dreamy, fantasy, digital video distortion, lens aberration, highly detailed, hd, 8k"""
instruction = f"""
<|system|>
{agent_maker_sys}</s>
<|user|>
"""
prompt = f"{instruction.strip()}\n Keywords: {user_input} \n Caption: {image_caption}</s>"
print(f"PROMPT: \n Keywords: {user_input} \n Caption: {image_caption}")
outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
pattern = r'\<\|system\|\>(.*?)\<\|assistant\|\>'
cleaned_text = re.sub(pattern, '', outputs[0]["generated_text"], flags=re.DOTALL)
return cleaned_text.lstrip("\n")
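# End-to-end flow: caption the uploaded photo, profanity-filter the keywords, have Zephyr
# compose the "mm-tabi" prompt, render with SDXL + LoRA, and also return a framed 1200x1500 copy.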
def generate_image(prompt, input_image, captioner, neg_prompt, model_id):
    # generate keywords from image
    gr.Info('Starting to generate caption for input image')
    img_caption, img_colors = get_image_keywords(input_image, captioner)
    # process + filter user input
    # remove inappropriate language
    gr.Info('Processing keywords for inappropriate language')
    user_input = filter_inappropriate(prompt)
    # construct prompt (keywords, user input, trigger words)
    # full_prompt = f"still life photo of TOK boot, {user_input}, {img_caption}, Surreal, Unreal, Digital, different universe, Dreamy, Fantasy, Otherworldly, Unconventional, Unexpected, Irreverent, Pimped, in the style of surrealism, {img_colors}, highly detailed, hd, 8k"
    full_prompt = construct_prompt(img_caption, img_colors, user_input)
    if model_id == '3.1':
        full_prompt = full_prompt + ", in the style of maison-margiela"
    print(f"FULL PROMPT: {full_prompt}")
    # text2img generation with full prompt construction
    gr.Info('Starting image generation on constructed prompt')
    image = text2img_inference(full_prompt, neg_prompt, model_id)
    watermarkable_image = resize_and_place(image, 1024, 1024, 'custom_tabi.jpg')
    return image, full_prompt, img_caption, watermarkable_image
def skip_img_upload(image_caption, keywords, neg_prompt):
    # Build the final prompt directly from a user-provided caption, skipping the image captioning step
    full_prompt = construct_prompt(image_caption, '', keywords)
    print(f"FULL PROMPT: {full_prompt}")
    image = text2img_inference(full_prompt, neg_prompt, '')
    return full_prompt, image
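# Gradio UI: the left column takes keywords, an optional image, and advanced options
# (LoRA model, negative prompt, caption model); the right column shows the identified
# keywords, the editable final prompt, the generated Tabi, and the resized copy.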
with gr.Blocks() as gradio_app:
    gr.Markdown(
        """
        # Generate a custom Maison Margiela Tabi
        Enter keywords and upload a photo.
        """)
    with gr.Row():
        with gr.Column():
            prompt = gr.Text(label="User Keywords")
            input_image = gr.Image(label="Input Image", type='filepath')
            generate_btn = gr.Button("Generate")
            image_caption = gr.Text(label="Image Caption")
            with gr.Accordion(open=False, label="Advanced Options"):
                model_id = gr.Dropdown(
                    label="LoRA Model",
                    info="Change between LoRA models, which have been trained on different image datasets",
                    choices=[
                        ('Model 0', 'sessex/tabi-0-LoRA'),
                        # ('Model 1 (Ecomm)', 'sessex/tabi_LoRA'),
                        # ('Model 2 (Archive)', 'sessex/mm-tabi-whitebg_LoRA'),
                        # ('Model 3.1 (Style merged with Ecomm Tabi)', '3.1'),
                        # ('Model 3.0 (Style merged with Model 0)', '3.1')
                    ],
                    value='sessex/tabi-0-LoRA'
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="low quality",
                    value="nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry",
                )
                captioner = gr.Dropdown(
                    label="Caption Model",
                    info="To test which captioning model produces a better description of the input image",
                    choices=[('Kosmos2', 'kosmos'), ('Moondream', 'moondream')],
                    value='moondream'
                )
                # include_colors = gr.Checkbox(
                #     label="Include Colors",
                #     description="To test if adding colors to final prompt improves association to input image",
                #     value=True
                # )
        with gr.Column():
            construct_btn = gr.Button("Generate from Caption (skip image upload)")
            keywords = gr.Text(label="Identified Keywords from Input Image", interactive=False)
            final_prompt = gr.Text(label="Final Prompt", interactive=True)
            retry_btn = gr.Button("Retry with Edited Final Prompt")
            output_image = gr.Image(label="Generated Tabi")
            watermarkable_image = gr.Image(label="Resized Image to Return")
    generate_btn.click(generate_image, [prompt, input_image, captioner, negative_prompt, model_id], [output_image, final_prompt, keywords, watermarkable_image])
    retry_btn.click(text2img_inference, [final_prompt, negative_prompt, model_id], [output_image])
    construct_btn.click(skip_img_upload, [image_caption, prompt, negative_prompt], [final_prompt, output_image])

gradio_app.launch(debug=True)