from __future__ import annotations
import os
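# taming-transformers and the region-aware guided_diffusion fork are not published on PyPI,
# so they are installed straight from GitHub when the app starts.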
os.system("pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers")
os.system("pip install -e git+https://github.com/alvanli/RDM-Region-Aware-Diffusion-Model.git@main#egg=guided_diffusion")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"  # disable the hf_transfer download backend for Hugging Face Hub downloads
import math
import random
import gradio as gr
import torch
from PIL import Image, ImageOps
from run_edit import run_model
from cool_models import make_models
help_text = ""  # extra help text rendered under the controls (currently empty)
def main():
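    # Load all models once at startup so every request reuses them.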
    segmodel, model, diffusion, ldm, bert, clip_model, model_params = make_models()
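    # Fills the UI with a bundled example image (flower1.jpg) and runs one edit so the outputs are pre-populated.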
    def load_sample():
        SAMPLE_IMAGE = "./flower1.jpg"
        input_image = Image.open(SAMPLE_IMAGE)
        from_text = "a flower"
        instruction = "a sunflower"
        negative_prompt = ""
        seed = 42
        guidance_scale = 5.0
        clip_guidance_scale = 150
        cutn = 16
        l2_sim_lambda = 10_000
        edited_image_1 = run_model(
            segmodel, model, diffusion, ldm, bert, clip_model, model_params,
            from_text, instruction, negative_prompt, input_image.convert('RGB'), seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
        )
        return [
            input_image, from_text, instruction, negative_prompt, seed, guidance_scale,
            clip_guidance_scale, cutn, l2_sim_lambda, edited_image_1
        ]
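    # Main callback: runs the region-aware edit with the current UI values and returns the seed that was used plus the edited image.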
    def generate(
        input_image: Image.Image,
        from_text: str,
        instruction: str,
        negative_prompt: str,
        randomize_seed: bool,
        seed: int,
        guidance_scale: float,
        clip_guidance_scale: float,
        cutn: int,
        l2_sim_lambda: float
    ):
        seed = random.randint(0, 100000) if randomize_seed else seed
        if instruction == "":
            # Nothing to edit: return the input image unchanged.
            return [seed, input_image]
        torch.manual_seed(seed)  # seed PyTorch's global RNG; run_model also receives the seed explicitly
        edited_image_1 = run_model(
            segmodel, model, diffusion, ldm, bert, clip_model, model_params,
            from_text, instruction, negative_prompt, input_image.convert('RGB'), seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
        )
        return [seed, edited_image_1]
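    # Restores the default control values and clears the edited image.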
    def reset():
        return [
            "Randomize Seed", 42, None, 5.0,
            150, 16, 10000
        ]
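    # Build the UI: prompt fields and buttons on top, input/output images in the middle, sampling controls below.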
    with gr.Blocks() as demo:
        gr.Markdown("""
#### RDM: Region-Aware Diffusion for Zero-shot Text-driven Image Editing
Original GitHub repo: https://github.com/haha-lisa/RDM-Region-Aware-Diffusion-Model <br/>
Instructions: <br/>
- In the "From Text" field, describe the object in the image that you want to modify.
- In the "Edit Instruction" field, describe what that region should be turned into.
""")
        with gr.Row():
            with gr.Column(scale=1, min_width=100):
                generate_button = gr.Button("Generate")
            with gr.Column(scale=1, min_width=100):
                load_button = gr.Button("Load Example")
            with gr.Column(scale=1, min_width=100):
                reset_button = gr.Button("Reset")
            with gr.Column(scale=3):
                from_text = gr.Textbox(lines=1, label="From Text", interactive=True)
                instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
                negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", interactive=True)
        with gr.Row():
            input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            edited_image_1 = gr.Image(label="Edited Image", type="pil", interactive=False)
            # edited_image_2 = gr.Image(label="Edited Image", type="pil", interactive=False)
            input_image.style(height=512, width=512)
            edited_image_1.style(height=512, width=512)
            # edited_image_2.style(height=512, width=512)
        with gr.Row():
            # steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
            seed = gr.Number(value=1371, precision=0, label="Seed", interactive=True)
            guidance_scale = gr.Number(value=5.0, precision=1, label="Guidance Scale", interactive=True)
            clip_guidance_scale = gr.Number(value=150, precision=1, label="CLIP Guidance Scale", interactive=True)
            cutn = gr.Number(value=16, precision=0, label="Number of Cuts", interactive=True)  # integer number of CLIP cutouts
            l2_sim_lambda = gr.Number(value=10000, precision=1, label="L2 Similarity to Original Image")
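            # type="index" makes the radio pass 0 ("Fix Seed") or 1 ("Randomize Seed") to the callback.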
            randomize_seed = gr.Radio(
                ["Fix Seed", "Randomize Seed"],
                value="Randomize Seed",
                type="index",
                show_label=False,
                interactive=True,
            )
            # use_ddim = gr.Checkbox(label="Use 50-step DDIM?", value=True)
            # use_ddpm = gr.Checkbox(label="Use 50-step DDPM?", value=True)
        gr.Markdown(help_text)
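        # Wire the buttons to their callbacks.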
        generate_button.click(
            fn=generate,
            inputs=[
                input_image, from_text, instruction, negative_prompt, randomize_seed,
                seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
            ],
            outputs=[seed, edited_image_1],
        )
        load_button.click(
            fn=load_sample,
            inputs=[],
            outputs=[input_image, from_text, instruction, negative_prompt, seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda, edited_image_1],
        )
        reset_button.click(
            fn=reset,
            inputs=[],
            outputs=[
                randomize_seed, seed, edited_image_1, guidance_scale,
                clip_guidance_scale, cutn, l2_sim_lambda
            ],
        )
    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")
if __name__ == "__main__":
    main()