# ortha / app.py
# ujin-song's picture
# Update app.py
# 59a2ef6 verified
# raw
# history blame
# 13.9 kB
import gradio as gr
import numpy as np
import random
import torch
import io, json
from PIL import Image
import os.path
from weight_fusion import compose_concepts
from regionally_controlable_sampling import sample_image, build_model, prepare_text
# Pick the torch device string and the label shown in the page header from a
# single CUDA availability probe.
device, power_device = ("cuda", "GPU") if torch.cuda.is_available() else ("cpu", "CPU")

# Upper bound (inclusive) for the seed slider and for randomly drawn seeds.
MAX_SEED = 100_000
def generate(region1_concept,
             region2_concept,
             prompt,
             pose_image_name,
             region1_prompt,
             region2_prompt,
             negative_prompt,
             region_neg_prompt,
             seed,
             randomize_seed,
             sketch_adaptor_weight,
             keypose_adaptor_weight
             ):
    """Validate the UI inputs, merge the two character models and sample one image.

    Args:
        region1_concept: Name of the character for region 1 (lower-cased here
            before it is used in prompt tokens and passed to ``merge``).
        region2_concept: Name of the character for region 2.
        prompt: Global context prompt; a fixed style suffix is appended.
        pose_image_name: Basename of an ``img_dir`` entry in ``pose.json``.
        region1_prompt: Prompt text for region 1.
        region2_prompt: Prompt text for region 2.
        negative_prompt: Global negative prompt forwarded to ``infer``.
        region_neg_prompt: Negative prompt embedded in both regional prompts.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, draw a seed in ``[0, MAX_SEED]`` instead.
        sketch_adaptor_weight: Accepted for UI compatibility; currently unused
            because the sketch condition is not forwarded to ``infer``.
        keypose_adaptor_weight: Weight forwarded to ``infer`` for the pose.

    Returns:
        Whatever ``infer`` returns for the assembled prompts.

    Raises:
        gr.Error: If a character, pose or prompt is missing, the two characters
            are identical, or the pose name is not listed in ``pose.json``.
    """
    # A dropdown left unselected yields None; treat that like an invalid pair
    # instead of crashing later on ``.lower()``.
    if not region1_concept or not region2_concept or region1_concept == region2_concept:
        raise gr.Error("Please choose two different characters for merging weights.")
    if not pose_image_name:
        raise gr.Error("Please select one spatial condition!")
    # Both regional prompts are required (the second one used to go unchecked).
    if not region1_prompt or not region2_prompt:
        raise gr.Error("Your regional prompt cannot be empty.")
    if not prompt:
        raise gr.Error("Your global prompt cannot be empty.")

    # torch.Generator.manual_seed expects an int, so coerce the slider value.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    # Resolve the pose before the (potentially slow) weight merge so that a bad
    # selection fails fast with a user-facing message.
    with open('multi-concept/pose_data/pose.json') as f:
        pose_entries = json.load(f)
    poses_by_name = {os.path.basename(obj['img_dir']): obj for obj in pose_entries}
    pose_image = poses_by_name.get(pose_image_name)
    if pose_image is None:
        raise gr.Error(f"Unknown spatial condition: {pose_image_name}")
    print(pose_image)

    region1_concept, region2_concept = region1_concept.lower(), region2_concept.lower()
    pretrained_model = merge(region1_concept, region2_concept)

    keypose_condition = pose_image['img_dir']
    region1 = pose_image['region1']
    region2 = pose_image['region2']

    region_pos_prompt = "high resolution, best quality, highly detailed, sharp focus, expressive, 8k uhd, detailed, sophisticated"
    region1_prompt = f'<{region1_concept}1> <{region1_concept}2>, {region1_prompt}, {region_pos_prompt}'
    region2_prompt = f'<{region2_concept}1> <{region2_concept}2>, {region2_prompt}, {region_pos_prompt}'
    # Per-region format: "<prompt>-*-<negative prompt>-*-<region>", regions joined by "|".
    prompt_rewrite = f"{region1_prompt}-*-{region_neg_prompt}-*-{region1}|{region2_prompt}-*-{region_neg_prompt}-*-{region2}"
    print(prompt_rewrite)

    prompt += ", Disney style photo, High resolution"

    result = infer(pretrained_model,
                   prompt,
                   prompt_rewrite,
                   negative_prompt,
                   seed,
                   keypose_condition,
                   keypose_adaptor_weight,
                   # sketch_condition,
                   # sketch_adaptor_weight,
                   )
    return result
def merge(concept1, concept2):
    """Return the path of the merged model for two concepts, composing it if needed.

    The two names are sorted so that the same pair always maps to the same
    directory ``experiments/multi-concept/<c1>_<c2>``. The merged model is
    expected in its ``combined_model_base`` sub-directory; when that directory
    already exists it is reused, otherwise a merge config is written and
    ``compose_concepts`` is run to produce it.

    Args:
        concept1: Name of the first concept (used in file paths and tokens).
        concept2: Name of the second concept; must differ from ``concept1``.

    Returns:
        Path to the ``combined_model_base`` directory.

    Raises:
        ValueError: If both concepts are the same.
        FileNotFoundError: If composing finished without producing the
            ``combined_model_base`` directory.
    """
    c1, c2 = sorted([concept1, concept2])
    if c1 == c2:
        raise ValueError(f'merge() needs two different concepts, got {c1!r} twice')

    merge_name = c1 + '_' + c2
    save_path = f'experiments/multi-concept/{merge_name}'
    modelbase_path = os.path.join(save_path, 'combined_model_base')

    # Key the cache on the final artefact rather than on save_path: a run that
    # was interrupted after creating save_path would otherwise never be retried.
    if os.path.isdir(modelbase_path):
        print(f'{save_path} already exists. Collecting merged weights from existing weights...')
        return modelbase_path

    os.makedirs(save_path, exist_ok=True)
    json_path = os.path.join(save_path, 'merge_config.json')
    alpha = 1.8
    data = [
        {
            "lora_path": f"experiments/single-concept/{name}/models/edlora_model-latest.pth",
            "unet_alpha": alpha,
            "text_encoder_alpha": alpha,
            "concept_name": f"<{name}1> <{name}2>"
        }
        for name in (c1, c2)
    ]
    with io.open(json_path, 'w', encoding='utf8') as outfile:
        json.dump(data, outfile, indent=4, ensure_ascii=False)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    compose_concepts(
        concept_cfg=json_path,
        optimize_textenc_iters=500,
        optimize_unet_iters=50,
        pretrained_model_path="nitrosocke/mo-di-diffusion",
        save_path=save_path,
        suffix='base',
        device=device,
    )
    print(f'Merged weight for {c1}+{c2} saved in {save_path}!\n\n')

    if not os.path.isdir(modelbase_path):
        raise FileNotFoundError(f'{modelbase_path} was not created by compose_concepts')
    return modelbase_path
def infer(pretrained_model,
          prompt,
          prompt_rewrite,
          negative_prompt='',
          seed=16141,
          keypose_condition=None,
          keypose_adaptor_weight=1.0,
          sketch_condition=None,
          sketch_adaptor_weight=0.0,
          region_sketch_adaptor_weight='',
          region_keypose_adaptor_weight=''
          ):
    """Build the pipeline for ``pretrained_model`` and sample a single image.

    ``keypose_condition`` and ``sketch_condition`` are file paths; each one is
    loaded only if it is not None and exists on disk, otherwise it is skipped.
    The output size is taken from whichever condition image was loaded (both
    must match when both are present).

    Returns:
        The first element of what ``sample_image`` returns.
    """
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pipe = build_model(pretrained_model, run_device)

    # Sketch condition: loaded as single-channel ('L') when available.
    sketch_img, sketch_w, sketch_h = None, 0, 0
    if sketch_condition is not None and os.path.exists(sketch_condition):
        sketch_img = Image.open(sketch_condition).convert('L')
        sketch_w, sketch_h = sketch_img.size
        print('use sketch condition')
    else:
        print('skip sketch condition')

    # Keypose condition: loaded as RGB when available.
    pose_img, pose_w, pose_h = None, 0, 0
    if keypose_condition is not None and os.path.exists(keypose_condition):
        pose_img = Image.open(keypose_condition).convert('RGB')
        pose_w, pose_h = pose_img.size
        print('use pose condition')
    else:
        print('skip pose condition')

    if sketch_w and pose_w:
        assert sketch_w == pose_w and sketch_h == pose_h, 'conditions should be same size'
    width = max(pose_w, sketch_w)
    height = max(pose_h, sketch_h)

    # One global prompt paired with one regional rewrite.
    input_prompt = [prepare_text(prompt, prompt_rewrite, height, width)]
    print(input_prompt[0][0])

    image = sample_image(
        pipe,
        input_prompt=input_prompt,
        input_neg_prompt=[negative_prompt for _ in input_prompt],
        generator=torch.Generator(run_device).manual_seed(seed),
        sketch_adaptor_weight=sketch_adaptor_weight,
        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
        keypose_adaptor_weight=keypose_adaptor_weight,
        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
        sketch_condition=sketch_img,
        keypose_condition=pose_img,
        height=height,
        width=width,
    )
    return image[0]
def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Return the original file name of the image carried by the select event."""
    selected = evt.value
    image_info = selected['image']
    return image_info['orig_name']
# Example prompts offered under the prompt inputs in the UI.
examples_context = [
    'walking on the busy streets of New York',
    'in the forest',
    'in the style of cyberpunk',
]
examples_region1 = [
    'In a casual t-shirt',
    'wearing jeans',
]
examples_region2 = [
    'smiling, wearing a blue hoodie',
]

# Load the pose catalogue once at import time and keep (img_id, img_dir)
# pairs for the pose gallery.
with open('multi-concept/pose_data/pose.json') as f:
    d = json.load(f)
pose_image_list = [
    (entry['img_id'], entry['img_dir'])
    for entry in d
]
# Custom CSS handed to gr.Blocks: centres the element with id "col-container"
# and caps its width.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 600px;\n"
    "}\n"
)
# Build the Gradio UI: prompt inputs, pose gallery, advanced settings and the
# result image, then wire the Run button to generate() and launch the app.
# NOTE(review): this listing carries no indentation, so the nesting of the
# `with` blocks below (which defines the page layout) must be confirmed
# against the original file.
# NOTE(review): 'πŸͺ„' / 'πŸ‘―' in the labels below look like mis-decoded emoji
# in this copy - confirm the file encoding.
with gr.Blocks(css=css) as demo:
# Page header; the f-string interpolates power_device ("GPU" or "CPU").
gr.Markdown(f"""
# Orthogonal Adaptation
Describe your world with a **πŸͺ„ text prompt (global and local)** and choose two characters to merge.
Select their **πŸ‘― poses (spatial conditions)** for regionally controllable sampling to generate a unique image using our model.
Let your creativity run wild! (Currently running on : {power_device} )
""")
with gr.Row():
# Column styled by the "#col-container" rule in `css`.
with gr.Column(elem_id="col-container"):
# gr.Markdown(f"""
# ### πŸͺ„ Global and Region prompts
# """)
# with gr.Group():
# --- Prompt tab: global prompt, character pickers and regional prompts ---
with gr.Tab('πŸͺ„ Global and Region prompts'):
# Global context prompt (passed to generate() as `prompt`).
prompt = gr.Text(
label="ContextPrompt",
show_label=False,
max_lines=1,
placeholder="Enter your global context prompt",
container=False,
)
with gr.Row():
# Selectable characters; generate() lower-cases the chosen names.
concept_list = ["Elsa", "Moana", "Woody", "Rapunzel", "Elastigirl",
"Linguini", "Raya", "Hiro", "Mirabel", "Miguel"]
region1_concept = gr.Dropdown(
concept_list,
label="Character 1",
# info="Will add more characters later!"
)
region2_concept = gr.Dropdown(
concept_list,
label="Character 2",
# info="Will add more characters later!"
)
with gr.Row():
# Free-text prompts describing each character's region.
region1_prompt = gr.Textbox(
label="Region1 Prompt",
show_label=False,
max_lines=2,
placeholder="Enter your regional prompt for character 1",
container=False,
)
region2_prompt = gr.Textbox(
label="Region2 Prompt",
show_label=False,
max_lines=2,
placeholder="Enter your regional prompt for character 2",
container=False,
)
# Clickable examples that fill the global prompt.
gr.Examples(
label = 'Global Prompt example',
examples = examples_context,
inputs = [prompt]
)
with gr.Row():
# Clickable examples for the two regional prompts. Region 1 passes a flat
# list of strings; region 2 wraps its list once more (a single row).
gr.Examples(
label = 'Region1 Prompt example',
examples = examples_region1,
inputs = [region1_prompt]
)
gr.Examples(
label = 'Region2 Prompt example',
examples = [examples_region2],
inputs = [region2_prompt]
)
# gr.Markdown(f"""
# ### πŸ‘― Spatial Condition
# """)
# with gr.Group():
# --- Pose tab: gallery of pose images built from pose_image_list ---
with gr.Tab('πŸ‘― Spatial Condition '):
# value uses the img_dir of each (img_id, img_dir) pair.
# NOTE(review): elem_id is given a list of img_id values here; confirm that
# gr.Gallery accepts a list (elem_id is normally a single string id).
gallery = gr.Gallery(label = "Select pose for characters",
value = [obj[1]for obj in pose_image_list],
elem_id = [obj[0]for obj in pose_image_list],
interactive=False, show_download_button=False,
preview=True, height = 400, object_fit="scale-down")
# Read-only textbox holding the file name returned by on_select; this is
# the `pose_image_name` input of generate().
pose_image_name = gr.Textbox(label="You selected: ", interactive=False)
gallery.select(on_select, None, pose_image_name)
run_button = gr.Button("Run", scale=1)
# --- Advanced settings (collapsed by default) ---
with gr.Accordion("Advanced Settings", open=False):
# Both negative prompts are created with visible=False, so only their
# preset values are sent to generate().
negative_prompt = gr.Text(
label="Context Negative prompt",
max_lines=1,
value = 'saturated, cropped, worst quality, low quality',
visible=False,
)
region_neg_prompt = gr.Text(
label="Regional Negative prompt",
max_lines=1,
value = 'shirtless, nudity, saturated, cropped, worst quality, low quality',
visible=False,
)
# Manual seed; generate() replaces it with a random one while
# "Randomize seed" is ticked (the default).
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
# The sketch weight is passed to generate(), whose call to infer() currently
# has the sketch arguments commented out.
sketch_adaptor_weight = gr.Slider(
label="Sketch Adapter Weight",
minimum = 0,
maximum = 1,
step=0.01,
value=0,
)
keypose_adaptor_weight = gr.Slider(
label="Keypose Adapter Weight",
minimum = 0.1,
maximum = 1,
step= 0.01,
value=1.0,
)
# --- Output column: generated image plus a note about first-run latency ---
with gr.Column():
result = gr.Image(label="Result", show_label=False)
gr.Markdown(f"""
*Image generation may take longer for the first time you use a new combination of characters. <br />
This is because the model needs to load weights for each concept involved.*
""")
# The order of `inputs` must match the positional parameters of generate().
run_button.click(
fn = generate,
inputs = [region1_concept,
region2_concept,
prompt,
pose_image_name,
region1_prompt,
region2_prompt,
negative_prompt,
region_neg_prompt,
seed,
randomize_seed,
# sketch_condition,
# keypose_condition,
sketch_adaptor_weight,
keypose_adaptor_weight
],
outputs = [result]
)
# Enable the request queue and start the app; share=True asks Gradio for a
# public share link.
demo.queue().launch(share=True)