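"""Gradio demo for Orthogonal Adaptation.

Merges two single-concept ED-LoRA checkpoints into a single model and generates
an image of both characters with regionally controllable sampling.
"""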
import gradio as gr
import numpy as np
import random
import torch
import io, json
from PIL import Image
import os.path
from weight_fusion import compose_concepts
from regionally_controlable_sampling import sample_image, build_model, prepare_text
device = "cuda" if torch.cuda.is_available() else "cpu"
power_device = "GPU" if torch.cuda.is_available() else "CPU"
MAX_SEED = 100_000
def generate(region1_concept,
region2_concept,
prompt,
pose_image_name,
region1_prompt,
region2_prompt,
negative_prompt,
region_neg_prompt,
seed,
randomize_seed,
alpha,
sketch_adaptor_weight,
keypose_adaptor_weight
):
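    """Gradio callback: merge the two selected concepts, look up the chosen pose,
    build the regional prompt rewrite, and run regionally controllable sampling."""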
if region1_concept==region2_concept:
raise gr.Error("Please choose two different characters for merging weights.")
if len(pose_image_name)==0:
raise gr.Error("Please select one spatial condition!")
    if len(region1_prompt)==0 or len(region2_prompt)==0:
raise gr.Error("Your regional prompt cannot be empty.")
if len(prompt)==0:
raise gr.Error("Your global prompt cannot be empty.")
if randomize_seed:
seed = random.randint(0, MAX_SEED)
region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
pretrained_model = merge(region1_concept, region2_concept, alpha)
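    # Look up the selected pose image and its two character regions in the pose metadata.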
with open('multi-concept/pose_data/pose.json') as f:
d = json.load(f)
pose_image = {os.path.basename(obj['img_dir']):obj for obj in d}[pose_image_name]
# pose_image = {obj.pop('pose_id'):obj for obj in d}[int(pose_image_id)]
print(pose_image)
keypose_condition = pose_image['img_dir']
region1 = pose_image['region1']
region2 = pose_image['region2']
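    # Prepend each concept's custom tokens and shared quality tags to its regional prompt.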
region_pos_prompt = "high resolution, best quality, highly detailed, sharp focus, sophisticated, Good anatomy, Clear facial features, Proportional body, Detailed clothing, Smooth textures"
region1_prompt = f'<{region1_concept}1> <{region1_concept}2>, {region1_prompt}, {region_pos_prompt}'
region2_prompt = f'<{region2_concept}1> <{region2_concept}2>, {region2_prompt}, {region_pos_prompt}'
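    # Rewrite string format: "<region prompt>-*-<region negative prompt>-*-<region box>" entries joined by "|".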
prompt_rewrite=f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
print(prompt_rewrite)
prompt+=", Disney style photo, High resolution, best quality, highly detailed, expressive,"
result = infer(pretrained_model,
prompt,
prompt_rewrite,
negative_prompt,
seed,
keypose_condition,
keypose_adaptor_weight,
# sketch_condition,
# sketch_adaptor_weight,
)
return result
def merge(concept1, concept2, alpha):
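    """Fuse the ED-LoRA weights of two single-concept models into one base model.

    Merged weights are cached under experiments/multi-concept/, so each character
    pair and merge weight only needs to be composed once.
    """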
device = "cuda" if torch.cuda.is_available() else "cpu"
c1, c2 = sorted([concept1, concept2])
assert c1!=c2
merge_name = c1+'_'+c2
save_path = f'experiments/multi-concept/{merge_name}--{int(alpha*10)}'
if os.path.isdir(save_path):
        print(f'{save_path} already exists. Reusing previously merged weights...')
else:
os.makedirs(save_path)
json_path = os.path.join(save_path,'merge_config.json')
# alpha = 1.8
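        # Merge config: one entry per concept, pointing at its single-concept ED-LoRA checkpoint.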
data = [
{
"lora_path": f"experiments/single-concept/{c1}/models/edlora_model-latest.pth",
"unet_alpha": alpha,
"text_encoder_alpha": alpha,
"concept_name": f"<{c1}1> <{c1}2>"
},
{
"lora_path": f"experiments/single-concept/{c2}/models/edlora_model-latest.pth",
"unet_alpha": alpha,
"text_encoder_alpha": alpha,
"concept_name": f"<{c2}1> <{c2}2>"
}
]
        with open(json_path, 'w', encoding='utf8') as outfile:
            json.dump(data, outfile, indent=4, ensure_ascii=False)
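        # Compose the two concepts into the pretrained base model, optimizing the
        # text encoder and U-Net for the configured number of iterations.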
compose_concepts(
concept_cfg=json_path,
optimize_textenc_iters=500,
optimize_unet_iters=50,
pretrained_model_path="nitrosocke/mo-di-diffusion",
save_path=save_path,
suffix='base',
device=device,
)
print(f'Merged weight for {c1}+{c2} saved in {save_path}!\n\n')
modelbase_path = os.path.join(save_path,'combined_model_base')
assert os.path.isdir(modelbase_path)
# save_path = 'experiments/multi-concept/elsa_moana_weight18/combined_model_base'
return modelbase_path
def infer(pretrained_model,
prompt,
prompt_rewrite,
negative_prompt='',
seed=16141,
keypose_condition=None,
keypose_adaptor_weight=1.0,
sketch_condition=None,
sketch_adaptor_weight=0.0,
region_sketch_adaptor_weight='',
region_keypose_adaptor_weight=''
):
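    """Build the diffusion pipeline from the merged weights and sample one image,
    optionally conditioned on a keypose and/or sketch image."""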
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
pipe = build_model(pretrained_model, device)
if sketch_condition is not None and os.path.exists(sketch_condition):
sketch_condition = Image.open(sketch_condition).convert('L')
width_sketch, height_sketch = sketch_condition.size
print('use sketch condition')
else:
sketch_condition, width_sketch, height_sketch = None, 0, 0
print('skip sketch condition')
if keypose_condition is not None and os.path.exists(keypose_condition):
keypose_condition = Image.open(keypose_condition).convert('RGB')
width_pose, height_pose = keypose_condition.size
print('use pose condition')
else:
keypose_condition, width_pose, height_pose = None, 0, 0
print('skip pose condition')
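    # If both conditions are provided they must have the same resolution; the output image uses that size.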
if width_sketch != 0 and width_pose != 0:
assert width_sketch == width_pose and height_sketch == height_pose, 'conditions should be same size'
width, height = max(width_pose, width_sketch), max(height_pose, height_sketch)
kwargs = {
'sketch_condition': sketch_condition,
'keypose_condition': keypose_condition,
'height': height,
'width': width,
}
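    # Convert the global prompt and per-region rewrite string into the pipeline's regional prompt format.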
prompts = [prompt]
prompts_rewrite = [prompt_rewrite]
input_prompt = [prepare_text(p, p_w, height, width) for p, p_w in zip(prompts, prompts_rewrite)]
save_prompt = input_prompt[0][0]
print(save_prompt)
image = sample_image(
pipe,
input_prompt=input_prompt,
input_neg_prompt=[negative_prompt] * len(input_prompt),
generator=torch.Generator(device).manual_seed(seed),
sketch_adaptor_weight=sketch_adaptor_weight,
region_sketch_adaptor_weight=region_sketch_adaptor_weight,
keypose_adaptor_weight=keypose_adaptor_weight,
region_keypose_adaptor_weight=region_keypose_adaptor_weight,
**kwargs)
return image[0]
def on_select(evt: gr.SelectData): # SelectData is a subclass of EventData
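    """Return the original filename of the gallery image the user selected."""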
return evt.value['image']['orig_name']
examples_context = [
'walking on the busy streets of New York',
'in the forest',
'in the style of cyberpunk'
]
examples_region1 = ['In a casual t-shirt', 'wearing jeans']
examples_region2 = ['smiling, wearing a yellow hoodie', 'In a baseball uniform']
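# Pose gallery entries: (img_id, img_dir) pairs read from the pose metadata file.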
with open('multi-concept/pose_data/pose.json') as f:
d = json.load(f)
pose_image_list = [(obj['img_id'],obj['img_dir']) for obj in d]
css="""
#col-container {
margin: 0 auto;
max-width: 600px;
}
"""
with gr.Blocks(css=css) as demo:
gr.Markdown(f"""
# Orthogonal Adaptation
    Describe your world with **[ 🪄 Text Prompts ]** (a global prompt plus regional prompts) and choose two characters to merge.
    Then select their **[ 👯 Poses ]** (spatial conditions) for regionally controllable sampling and generate a unique image with our model.
    Let your creativity run wild! (Currently running on: {power_device})
""")
with gr.Row():
with gr.Column(elem_id="col-container"):
# gr.Markdown(f"""
# ### 🪄 Global and Region prompts
# """)
# with gr.Group():
with gr.Tab('🪄 Text Prompts'):
prompt = gr.Text(
label="ContextPrompt",
show_label=False,
max_lines=1,
placeholder="Enter your global context prompt",
container=False,
)
with gr.Row():
concept_list = ["Elsa", "Moana", "Woody", "Rapunzel", "Elastigirl",
"Linguini", "Raya", "Hiro", "Mirabel", "Miguel"]
region1_concept = gr.Dropdown(
concept_list,
label="Character 1",
# info="Will add more characters later!"
)
region2_concept = gr.Dropdown(
concept_list,
label="Character 2",
# info="Will add more characters later!"
)
with gr.Row():
region1_prompt = gr.Textbox(
label="Region1 Prompt",
show_label=False,
max_lines=2,
placeholder="Enter your regional prompt for character 1",
container=False,
)
region2_prompt = gr.Textbox(
label="Region2 Prompt",
show_label=False,
max_lines=2,
placeholder="Enter your regional prompt for character 2",
container=False,
)
gr.Examples(
label = 'Global Prompt example',
examples = examples_context,
inputs = [prompt]
)
with gr.Row():
gr.Examples(
label = 'Region1 Prompt example',
examples = examples_region1,
inputs = [region1_prompt]
)
gr.Examples(
label = 'Region2 Prompt example',
examples = examples_region2,
inputs = [region2_prompt]
)
# gr.Markdown(f"""
# ### 👯 Spatial Condition
# """)
# with gr.Group():
with gr.Tab('👯 Poses '):
                gallery = gr.Gallery(label="Select a pose for the characters",
                                     value=[obj[1] for obj in pose_image_list],
                                     elem_id=[obj[0] for obj in pose_image_list],
                                     interactive=False, show_download_button=False,
                                     preview=True, height=400, object_fit="scale-down")
pose_image_name = gr.Textbox(label="You selected: ", interactive=False)
gallery.select(on_select, None, pose_image_name)
run_button = gr.Button("Run", scale=1)
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Context Negative prompt",
max_lines=1,
value = 'saturated, cropped, worst quality, low quality',
visible=False,
)
region_neg_prompt = gr.Text(
label="Regional Negative prompt",
max_lines=1,
value = 'shirtless, nudity, saturated, cropped, worst quality, low quality',
visible=False,
)
with gr.Row():
with gr.Column():
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
alpha = gr.Slider(
label="Merge Weight",
minimum=1.2,
maximum=2.1,
step=0.3,
value=1.8,
)
with gr.Row():
sketch_adaptor_weight = gr.Slider(
label="Sketch Adapter Weight",
minimum = 0,
maximum = 1,
step=0.01,
value=0,
)
keypose_adaptor_weight = gr.Slider(
label="Keypose Adapter Weight",
minimum = 0.1,
maximum = 1,
step= 0.01,
value=1.0,
)
with gr.Column():
result = gr.Image(label="Result", show_label=False)
gr.Markdown(f"""
            *Image generation may take longer the first time you use a new combination of characters. <br />
            This is because the weights for that character pair need to be merged first.*
""")
run_button.click(
fn = generate,
inputs = [region1_concept,
region2_concept,
prompt,
pose_image_name,
region1_prompt,
region2_prompt,
negative_prompt,
region_neg_prompt,
seed,
randomize_seed,
# sketch_condition,
# keypose_condition,
alpha,
sketch_adaptor_weight,
keypose_adaptor_weight
],
outputs = [result]
)
demo.queue().launch(share=True)